Close #15: Implement Feature Selection functions (chi2, snr)

Project: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/commit/fad2941f
Tree: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/tree/fad2941f
Diff: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/diff/fad2941f

Branch: refs/heads/master
Commit: fad2941fdb0309dd6fbf22f2f936cbf0003b1c4a
Parents: 518e232
Author: amaya <amaya...@users.noreply.github.com>
Authored: Tue Dec 13 16:46:07 2016 +0900
Committer: myui <yuin...@gmail.com>
Committed: Tue Dec 13 16:46:07 2016 +0900

----------------------------------------------------------------------
 .../hivemall/ftvec/selection/ChiSquareUDF.java  | 173 +++++++++
 .../ftvec/selection/SignalNoiseRatioUDAF.java   | 370 +++++++++++++++++++
 .../hivemall/tools/array/SelectKBestUDF.java    | 163 ++++++++
 .../tools/matrix/TransposeAndDotUDAF.java       | 222 +++++++++++
 .../java/hivemall/utils/hadoop/HiveUtils.java   |  22 +-
 .../hivemall/utils/hadoop/WritableUtils.java    |  16 +
 .../java/hivemall/utils/lang/Preconditions.java |  30 ++
 .../java/hivemall/utils/math/StatsUtils.java    |  91 +++++
 .../ftvec/selection/ChiSquareUDFTest.java       |  82 ++++
 .../selection/SignalNoiseRatioUDAFTest.java     | 342 +++++++++++++++++
 .../tools/array/SelectKBeatUDFTest.java         |  69 ++++
 .../tools/matrix/TransposeAndDotUDAFTest.java   |  59 +++
 .../hivemall/utils/lang/PreconditionsTest.java  |  37 ++
 docs/gitbook/SUMMARY.md                         |   2 +
 .../gitbook/ft_engineering/feature_selection.md | 155 ++++++++
 docs/gitbook/ft_engineering/quantify.md         |   2 +-
 resources/ddl/define-all-as-permanent.hive      |  20 +
 resources/ddl/define-all.hive                   |  20 +
 resources/ddl/define-all.spark                  |  50 ++-
 resources/ddl/define-udfs.td.hql                |   4 +
 .../apache/spark/sql/hive/GroupedDataEx.scala   |  29 +-
 .../org/apache/spark/sql/hive/HivemallOps.scala |  18 +
 .../spark/sql/hive/HivemallOpsSuite.scala       | 105 +++++-
 .../spark/sql/hive/HivemallGroupedDataset.scala |  26 ++
 .../org/apache/spark/sql/hive/HivemallOps.scala |  20 +
 .../spark/sql/hive/HivemallOpsSuite.scala       | 100 +++++
 26 files changed, 2203 insertions(+), 24 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/fad2941f/core/src/main/java/hivemall/ftvec/selection/ChiSquareUDF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/ftvec/selection/ChiSquareUDF.java b/core/src/main/java/hivemall/ftvec/selection/ChiSquareUDF.java
new file mode 100644
index 0000000..9ada4e5
--- /dev/null
+++ b/core/src/main/java/hivemall/ftvec/selection/ChiSquareUDF.java
@@ -0,0 +1,173 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.ftvec.selection;
+
+import hivemall.utils.hadoop.HiveUtils;
+import hivemall.utils.hadoop.WritableUtils;
+import hivemall.utils.lang.Preconditions;
+import hivemall.utils.math.StatsUtils;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Map;
+
+import org.apache.hadoop.hive.ql.exec.Description;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentLengthException;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
+import org.apache.hadoop.hive.ql.metadata.HiveException;
+import org.apache.hadoop.hive.ql.udf.UDFType;
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDF;
+import org.apache.hadoop.hive.serde2.io.DoubleWritable;
+import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
+import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
+
+@Description(name = "chi2",
+        value = "_FUNC_(array<array<number>> observed, array<array<number>> 
expected)"
+                + " - Returns chi2_val and p_val of each columns as 
<array<double>, array<double>>")
+@UDFType(deterministic = true, stateful = false)
+public final class ChiSquareUDF extends GenericUDF {
+
+    private ListObjectInspector observedOI;
+    private ListObjectInspector observedRowOI;
+    private PrimitiveObjectInspector observedElOI;
+    private ListObjectInspector expectedOI;
+    private ListObjectInspector expectedRowOI;
+    private PrimitiveObjectInspector expectedElOI;
+
+    private int nFeatures = -1;
+    private double[] observedRow = null; // to reuse
+    private double[] expectedRow = null; // to reuse
+    private double[][] observed = null; // shape = (#features, #classes)
+    private double[][] expected = null; // shape = (#features, #classes)
+    
+    private List<DoubleWritable>[] result;
+
+    @SuppressWarnings("unchecked")
+    @Override
+    public ObjectInspector initialize(ObjectInspector[] OIs) throws UDFArgumentException {
+        if (OIs.length != 2) {
+            throw new UDFArgumentLengthException("Specify two arguments: " + OIs.length);
+        }
+        if (!HiveUtils.isNumberListListOI(OIs[0])) {
+            throw new UDFArgumentTypeException(0,
+                "Only array<array<number>> type argument is acceptable but " + 
OIs[0].getTypeName()
+                        + " was passed as `observed`");
+        }
+        if (!HiveUtils.isNumberListListOI(OIs[1])) {
+            throw new UDFArgumentTypeException(1,
+                "Only array<array<number>> type argument is acceptable but " + 
OIs[1].getTypeName()
+                        + " was passed as `expected`");
+        }
+
+        this.observedOI = HiveUtils.asListOI(OIs[0]);
+        this.observedRowOI = HiveUtils.asListOI(observedOI.getListElementObjectInspector());
+        this.observedElOI = HiveUtils.asDoubleCompatibleOI(observedRowOI.getListElementObjectInspector());
+        this.expectedOI = HiveUtils.asListOI(OIs[1]);
+        this.expectedRowOI = HiveUtils.asListOI(expectedOI.getListElementObjectInspector());
+        this.expectedElOI = HiveUtils.asDoubleCompatibleOI(expectedRowOI.getListElementObjectInspector());
+        this.result = new List[2];
+        
+        List<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>();
+        fieldOIs.add(ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector));
+        fieldOIs.add(ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector));
+        
+        return ObjectInspectorFactory.getStandardStructObjectInspector(
+            Arrays.asList("chi2", "pvalue"), fieldOIs);
+    }
+
+    @Override
+    public List<DoubleWritable>[] evaluate(DeferredObject[] dObj) throws HiveException {
+        List<?> observedObj = observedOI.getList(dObj[0].get()); // shape = (#classes, #features)
+        List<?> expectedObj = expectedOI.getList(dObj[1].get()); // shape = (#classes, #features)
+
+        if (observedObj == null || expectedObj == null) {
+            return null;
+        }
+
+        final int nClasses = observedObj.size();
+        Preconditions.checkArgument(nClasses == expectedObj.size(), UDFArgumentException.class);
+
+        // explode and transpose matrix
+        for (int i = 0; i < nClasses; i++) {
+            Object observedObjRow = observedObj.get(i);
+            Object expectedObjRow = expectedObj.get(i);
+
+            Preconditions.checkNotNull(observedObjRow, UDFArgumentException.class);
+            Preconditions.checkNotNull(expectedObjRow, UDFArgumentException.class);
+
+            if (observedRow == null) {
+                observedRow = HiveUtils.asDoubleArray(observedObjRow, observedRowOI, observedElOI,
+                    false);
+                expectedRow = HiveUtils.asDoubleArray(expectedObjRow, expectedRowOI, expectedElOI,
+                    false);
+                nFeatures = observedRow.length;
+                observed = new double[nFeatures][nClasses];
+                expected = new double[nFeatures][nClasses];
+            } else {
+                HiveUtils.toDoubleArray(observedObjRow, observedRowOI, observedElOI, observedRow,
+                    false);
+                HiveUtils.toDoubleArray(expectedObjRow, expectedRowOI, expectedElOI, expectedRow,
+                    false);
+            }
+
+            for (int j = 0; j < nFeatures; j++) {
+                observed[j][i] = observedRow[j];
+                expected[j][i] = expectedRow[j];
+            }
+        }
+
+        Map.Entry<double[], double[]> chi2 = StatsUtils.chiSquare(observed, expected);
+        
+        result[0] = WritableUtils.toWritableList(chi2.getKey(), result[0]);
+        result[1] = WritableUtils.toWritableList(chi2.getValue(), result[1]);
+        return result;
+    }
+
+    @Override
+    public void close() throws IOException {
+        // help GC
+        this.observedRow = null;
+        this.expectedRow = null;
+        this.observed = null;
+        this.expected = null;
+        this.result = null;
+    }
+
+    @Override
+    public String getDisplayString(String[] children) {
+        final StringBuilder sb = new StringBuilder();
+        sb.append("chi2");
+        sb.append("(");
+        if (children.length > 0) {
+            sb.append(children[0]);
+            for (int i = 1; i < children.length; i++) {
+                sb.append(", ");
+                sb.append(children[i]);
+            }
+        }
+        sb.append(")");
+        return sb.toString();
+    }
+}
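
A minimal standalone sketch of the statistic ChiSquareUDF computes per feature
(via StatsUtils), using the same commons-math3 distribution class the patch
imports; the class name and input values below are illustrative, not part of
the patch:

    import org.apache.commons.math3.distribution.ChiSquaredDistribution;

    public class Chi2Sketch {
        public static void main(String[] args) {
            // one feature observed across 3 classes vs. its expected counts
            double[] observed = {10.d, 20.d, 30.d};
            double[] expected = {15.d, 20.d, 25.d};
            double chi2 = 0.d;
            for (int i = 0; i < observed.length; i++) {
                double dev = observed[i] - expected[i];
                chi2 += dev * dev / expected[i]; // normalized squared deviation
            }
            // p value with #classes - 1 degrees of freedom, as in StatsUtils.chiSquareTest
            double p = 1.d - new ChiSquaredDistribution(observed.length - 1.d)
                .cumulativeProbability(chi2);
            System.out.println("chi2 = " + chi2 + ", p = " + p);
        }
    }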

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/fad2941f/core/src/main/java/hivemall/ftvec/selection/SignalNoiseRatioUDAF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/ftvec/selection/SignalNoiseRatioUDAF.java b/core/src/main/java/hivemall/ftvec/selection/SignalNoiseRatioUDAF.java
new file mode 100644
index 0000000..da0de59
--- /dev/null
+++ b/core/src/main/java/hivemall/ftvec/selection/SignalNoiseRatioUDAF.java
@@ -0,0 +1,370 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.ftvec.selection;
+
+import hivemall.utils.hadoop.HiveUtils;
+import hivemall.utils.hadoop.WritableUtils;
+import hivemall.utils.lang.Preconditions;
+import hivemall.utils.lang.SizeOf;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+
+import javax.annotation.Nonnull;
+
+import org.apache.hadoop.hive.ql.exec.Description;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentLengthException;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
+import org.apache.hadoop.hive.ql.metadata.HiveException;
+import org.apache.hadoop.hive.ql.parse.SemanticException;
+import org.apache.hadoop.hive.ql.udf.generic.AbstractGenericUDAFResolver;
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFParameterInfo;
+import org.apache.hadoop.hive.serde2.io.DoubleWritable;
+import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
+import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.StructField;
+import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.DoubleObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.LongObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
+
+@Description(name = "snr", value = "_FUNC_(array<number> features, array<int> 
one-hot class label)"
+        + " - Returns Signal Noise Ratio for each feature as array<double>")
+public class SignalNoiseRatioUDAF extends AbstractGenericUDAFResolver {
+
+    @Override
+    public GenericUDAFEvaluator getEvaluator(GenericUDAFParameterInfo info)
+            throws SemanticException {
+        final ObjectInspector[] OIs = info.getParameterObjectInspectors();
+
+        if (OIs.length != 2) {
+            throw new UDFArgumentLengthException("Specify two arguments: " + OIs.length);
+        }
+        if (!HiveUtils.isNumberListOI(OIs[0])) {
+            throw new UDFArgumentTypeException(0,
+                "Only array<number> type argument is acceptable but " + 
OIs[0].getTypeName()
+                        + " was passed as `features`");
+        }
+        if (!HiveUtils.isListOI(OIs[1])
+                || !HiveUtils.isIntegerOI(((ListObjectInspector) OIs[1]).getListElementObjectInspector())) {
+            throw new UDFArgumentTypeException(1,
+                "Only array<int> type argument is acceptable but " + OIs[1].getTypeName()
+                        + " was passed as `labels`");
+        }
+
+        return new SignalNoiseRatioUDAFEvaluator();
+    }
+
+    static class SignalNoiseRatioUDAFEvaluator extends GenericUDAFEvaluator {
+        // PARTIAL1 and COMPLETE
+        private ListObjectInspector featuresOI;
+        private PrimitiveObjectInspector featureOI;
+        private ListObjectInspector labelsOI;
+        private PrimitiveObjectInspector labelOI;
+
+        // PARTIAL2 and FINAL
+        private StructObjectInspector structOI;
+        private StructField countsField, meansField, variancesField;
+        private ListObjectInspector countsOI;
+        private LongObjectInspector countOI;
+        private ListObjectInspector meansOI;
+        private ListObjectInspector meanListOI;
+        private DoubleObjectInspector meanElemOI;
+        private ListObjectInspector variancesOI;
+        private ListObjectInspector varianceListOI;
+        private DoubleObjectInspector varianceElemOI;
+
+        @AggregationType(estimable = true)
+        static class SignalNoiseRatioAggregationBuffer extends AbstractAggregationBuffer {
+            long[] counts;
+            double[][] means;
+            double[][] variances;
+
+            @Override
+            public int estimate() {
+                return counts == null ? 0 : SizeOf.LONG * counts.length + SizeOf.DOUBLE
+                        * means.length * means[0].length + SizeOf.DOUBLE * variances.length
+                        * variances[0].length;
+            }
+
+            public void init(int nClasses, int nFeatures) {
+                this.counts = new long[nClasses];
+                this.means = new double[nClasses][nFeatures];
+                this.variances = new double[nClasses][nFeatures];
+            }
+
+            public void reset() {
+                if (counts != null) {
+                    Arrays.fill(counts, 0);
+                    for (double[] mean : means) {
+                        Arrays.fill(mean, 0.d);
+                    }
+                    for (double[] variance : variances) {
+                        Arrays.fill(variance, 0.d);
+                    }
+                }
+            }
+        }
+
+        @Override
+        public ObjectInspector init(Mode mode, ObjectInspector[] OIs) throws HiveException {
+            super.init(mode, OIs);
+
+            // initialize input
+            if (mode == Mode.PARTIAL1 || mode == Mode.COMPLETE) {// from original data
+                this.featuresOI = HiveUtils.asListOI(OIs[0]);
+                this.featureOI = HiveUtils.asDoubleCompatibleOI(featuresOI.getListElementObjectInspector());
+                this.labelsOI = HiveUtils.asListOI(OIs[1]);
+                this.labelOI = HiveUtils.asIntegerOI(labelsOI.getListElementObjectInspector());
+            } else {// from partial aggregation
+                this.structOI = (StructObjectInspector) OIs[0];
+                this.countsField = structOI.getStructFieldRef("counts");
+                this.countsOI = HiveUtils.asListOI(countsField.getFieldObjectInspector());
+                this.countOI = HiveUtils.asLongOI(countsOI.getListElementObjectInspector());
+                this.meansField = structOI.getStructFieldRef("means");
+                this.meansOI = HiveUtils.asListOI(meansField.getFieldObjectInspector());
+                this.meanListOI = HiveUtils.asListOI(meansOI.getListElementObjectInspector());
+                this.meanElemOI = HiveUtils.asDoubleOI(meanListOI.getListElementObjectInspector());
+                this.variancesField = structOI.getStructFieldRef("variances");
+                this.variancesOI = HiveUtils.asListOI(variancesField.getFieldObjectInspector());
+                this.varianceListOI = HiveUtils.asListOI(variancesOI.getListElementObjectInspector());
+                this.varianceElemOI = HiveUtils.asDoubleOI(varianceListOI.getListElementObjectInspector());
+            }
+
+            // initialize output
+            if (mode == Mode.PARTIAL1 || mode == Mode.PARTIAL2) {// terminatePartial
+                List<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>();
+                fieldOIs.add(ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableLongObjectInspector));
+                fieldOIs.add(ObjectInspectorFactory.getStandardListObjectInspector(ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector)));
+                fieldOIs.add(ObjectInspectorFactory.getStandardListObjectInspector(ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector)));
+                return ObjectInspectorFactory.getStandardStructObjectInspector(
+                    Arrays.asList("counts", "means", "variances"), fieldOIs);
+            } else {// terminate
+                return ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector);
+            }
+        }
+
+        @Override
+        public AbstractAggregationBuffer getNewAggregationBuffer() throws HiveException {
+            SignalNoiseRatioAggregationBuffer myAgg = new SignalNoiseRatioAggregationBuffer();
+            reset(myAgg);
+            return myAgg;
+        }
+
+        @Override
+        public void reset(@SuppressWarnings("deprecation") AggregationBuffer agg)
+                throws HiveException {
+            SignalNoiseRatioAggregationBuffer myAgg = (SignalNoiseRatioAggregationBuffer) agg;
+            myAgg.reset();
+        }
+
+        @Override
+        public void iterate(@SuppressWarnings("deprecation") AggregationBuffer agg,
+                Object[] parameters) throws HiveException {
+            final Object featuresObj = parameters[0];
+            final Object labelsObj = parameters[1];
+
+            Preconditions.checkNotNull(featuresObj);
+            Preconditions.checkNotNull(labelsObj);
+
+            final SignalNoiseRatioAggregationBuffer myAgg = (SignalNoiseRatioAggregationBuffer) agg;
+
+            final List<?> labels = labelsOI.getList(labelsObj);
+            final int nClasses = labels.size();
+            Preconditions.checkArgument(nClasses >= 2, UDFArgumentException.class);
+
+            final List<?> features = featuresOI.getList(featuresObj);
+            final int nFeatures = features.size();
+            Preconditions.checkArgument(nFeatures >= 1, UDFArgumentException.class);
+
+            if (myAgg.counts == null) {
+                myAgg.init(nClasses, nFeatures);
+            } else {
+                Preconditions.checkArgument(nClasses == myAgg.counts.length,
+                    UDFArgumentException.class);
+                Preconditions.checkArgument(nFeatures == myAgg.means[0].length,
+                    UDFArgumentException.class);
+            }
+
+            // incrementally calculates means and variance
+            // http://i.stanford.edu/pub/cstr/reports/cs/tr/79/773/CS-TR-79-773.pdf
+            final int clazz = hotIndex(labels, labelOI);
+            final long n = myAgg.counts[clazz];
+            myAgg.counts[clazz]++;
+            for (int i = 0; i < nFeatures; i++) {
+                final double x = PrimitiveObjectInspectorUtils.getDouble(features.get(i), featureOI);
+                final double meanN = myAgg.means[clazz][i];
+                final double varianceN = myAgg.variances[clazz][i];
+                myAgg.means[clazz][i] = (n * meanN + x) / (n + 1.d);
+                myAgg.variances[clazz][i] = (n * varianceN + (x - meanN)
+                        * (x - myAgg.means[clazz][i]))
+                        / (n + 1.d);
+            }
+        }
+
+        private static int hotIndex(@Nonnull List<?> labels, PrimitiveObjectInspector labelOI)
+                throws UDFArgumentException {
+            final int nClasses = labels.size();
+
+            int clazz = -1;
+            for (int i = 0; i < nClasses; i++) {
+                final int label = PrimitiveObjectInspectorUtils.getInt(labels.get(i), labelOI);
+                if (label == 1) {// assumes one-hot encoding
+                    if (clazz != -1) {
+                        throw new UDFArgumentException(
+                            "Specify one-hot vectorized array. Multiple hot 
elements found.");
+                    }
+                    clazz = i;
+                } else {
+                    if (label != 0) {
+                        throw new UDFArgumentException(
+                            "Assumed one-hot encoding (0/1) but found an 
invalid label: " + label);
+                    }
+                }
+            }
+            if (clazz == -1) {
+                throw new UDFArgumentException(
+                    "Specify one-hot vectorized array for label. Hot element 
not found.");
+            }
+            return clazz;
+        }
+
+        @Override
+        public void merge(@SuppressWarnings("deprecation") AggregationBuffer agg, Object other)
+                throws HiveException {
+            if (other == null) {
+                return;
+            }
+
+            final SignalNoiseRatioAggregationBuffer myAgg = (SignalNoiseRatioAggregationBuffer) agg;
+
+            final List<?> counts = countsOI.getList(structOI.getStructFieldData(other, countsField));
+            final List<?> means = meansOI.getList(structOI.getStructFieldData(other, meansField));
+            final List<?> variances = variancesOI.getList(structOI.getStructFieldData(other,
+                variancesField));
+
+            final int nClasses = counts.size();
+            final int nFeatures = meanListOI.getListLength(means.get(0));
+            if (myAgg.counts == null) {
+                myAgg.init(nClasses, nFeatures);
+            }
+
+            for (int i = 0; i < nClasses; i++) {
+                final long n = myAgg.counts[i];
+                final long cnt = PrimitiveObjectInspectorUtils.getLong(counts.get(i), countOI);
+
+                // no need to merge class `i`
+                if (cnt == 0) {
+                    continue;
+                }
+
+                final List<?> mean = meanListOI.getList(means.get(i));
+                final List<?> variance = varianceListOI.getList(variances.get(i));
+
+                myAgg.counts[i] += cnt;
+                for (int j = 0; j < nFeatures; j++) {
+                    final double meanN = myAgg.means[i][j];
+                    final double meanM = PrimitiveObjectInspectorUtils.getDouble(mean.get(j),
+                        meanElemOI);
+                    final double varianceN = myAgg.variances[i][j];
+                    final double varianceM = PrimitiveObjectInspectorUtils.getDouble(
+                        variance.get(j), varianceElemOI);
+
+                    if (n == 0) {// only assign `other` into `myAgg`
+                        myAgg.means[i][j] = meanM;
+                        myAgg.variances[i][j] = varianceM;
+                    } else {
+                        // merge by Chan's method
+                        // http://i.stanford.edu/pub/cstr/reports/cs/tr/79/773/CS-TR-79-773.pdf
+                        myAgg.means[i][j] = (n * meanN + cnt * meanM) / (double) (n + cnt);
+                        myAgg.variances[i][j] = (varianceN * (n - 1) + varianceM * (cnt - 1) + Math.pow(
+                            meanN - meanM, 2) * n * cnt / (n + cnt))
+                                / (n + cnt - 1);
+                    }
+                }
+            }
+        }
+
+        @Override
+        public Object terminatePartial(@SuppressWarnings("deprecation") AggregationBuffer agg)
+                throws HiveException {
+            final SignalNoiseRatioAggregationBuffer myAgg = (SignalNoiseRatioAggregationBuffer) agg;
+
+            final Object[] partialResult = new Object[3];
+            partialResult[0] = WritableUtils.toWritableList(myAgg.counts);
+            final List<List<DoubleWritable>> means = new ArrayList<List<DoubleWritable>>();
+            for (double[] mean : myAgg.means) {
+                means.add(WritableUtils.toWritableList(mean));
+            }
+            partialResult[1] = means;
+            final List<List<DoubleWritable>> variances = new ArrayList<List<DoubleWritable>>();
+            for (double[] variance : myAgg.variances) {
+                variances.add(WritableUtils.toWritableList(variance));
+            }
+            partialResult[2] = variances;
+            return partialResult;
+        }
+
+        @Override
+        public Object terminate(@SuppressWarnings("deprecation") AggregationBuffer agg)
+                throws HiveException {
+            final SignalNoiseRatioAggregationBuffer myAgg = (SignalNoiseRatioAggregationBuffer) agg;
+
+            final int nClasses = myAgg.counts.length;
+            final int nFeatures = myAgg.means[0].length;
+
+            // compute SNR among classes for each feature
+            final double[] result = new double[nFeatures];
+            final double[] sds = new double[nClasses]; // for memorization
+            for (int i = 0; i < nFeatures; i++) {
+                sds[0] = Math.sqrt(myAgg.variances[0][i]);
+                for (int j = 1; j < nClasses; j++) {
+                    sds[j] = Math.sqrt(myAgg.variances[j][i]);
+                    // `counts[j] == 0` means no feature entry belongs to class `j`; skip the entry
+                    if (myAgg.counts[j] == 0) {
+                        continue;
+                    }
+                    for (int k = 0; k < j; k++) {
+                        // avoid comparing classes that have only a single entry
+                        if (myAgg.counts[k] == 0 || (myAgg.counts[j] == 1 && myAgg.counts[k] == 1)) {
+                            continue;
+                        }
+
+                        // SUM(snr) GROUP BY feature
+                        final double snr = Math.abs(myAgg.means[j][i] - myAgg.means[k][i])
+                                / (sds[j] + sds[k]);
+                        // a NaN (mean diff and both sds are zero, i.e., all related values are equal)
+                        // means feature `i` is meaningless between classes `j` and `k`; skip the entry
+                        if (!Double.isNaN(snr)) {
+                            result[i] += snr; // accept `Infinity`
+                        }
+                    }
+                }
+            }
+
+            return WritableUtils.toWritableList(result);
+        }
+    }
+}
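
The numerical core of the UDAF is the incremental mean/variance update in
iterate() and the pairwise merge in merge(), both following the Chan et al.
technical report linked in the comments. A hedged sketch of the per-cell
update, extracted as a plain method with illustrative names:

    // Running-moment update for a single (class, feature) cell; `n` is the
    // number of rows folded into the cell so far. Mirrors iterate().
    static void update(long n, double[] cell /* {mean, variance} */, double x) {
        double mean = cell[0];
        double variance = cell[1];
        double newMean = (n * mean + x) / (n + 1.d);
        cell[0] = newMean;
        cell[1] = (n * variance + (x - mean) * (x - newMean)) / (n + 1.d);
    }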

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/fad2941f/core/src/main/java/hivemall/tools/array/SelectKBestUDF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/tools/array/SelectKBestUDF.java b/core/src/main/java/hivemall/tools/array/SelectKBestUDF.java
new file mode 100644
index 0000000..b363166
--- /dev/null
+++ b/core/src/main/java/hivemall/tools/array/SelectKBestUDF.java
@@ -0,0 +1,163 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.tools.array;
+
+import hivemall.utils.hadoop.HiveUtils;
+import hivemall.utils.lang.Preconditions;
+
+import java.io.IOException;
+import java.util.AbstractMap;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.Comparator;
+import java.util.List;
+import java.util.Map;
+
+import org.apache.hadoop.hive.ql.exec.Description;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentLengthException;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
+import org.apache.hadoop.hive.ql.metadata.HiveException;
+import org.apache.hadoop.hive.ql.udf.UDFType;
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDF;
+import org.apache.hadoop.hive.serde2.io.DoubleWritable;
+import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
+import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
+
+@Description(name = "select_k_best",
+        value = "_FUNC_(array<number> array, const array<number> importance, 
const int k)"
+                + " - Returns selected top-k elements as array<double>")
+@UDFType(deterministic = true, stateful = false)
+public final class SelectKBestUDF extends GenericUDF {
+
+    private ListObjectInspector featuresOI;
+    private PrimitiveObjectInspector featureOI;
+    private ListObjectInspector importanceListOI;
+    private PrimitiveObjectInspector importanceElemOI;
+
+    private int _k;
+    private List<DoubleWritable> _result;
+    private int[] _topKIndices;
+
+    @Override
+    public ObjectInspector initialize(ObjectInspector[] OIs) throws UDFArgumentException {
+        if (OIs.length != 3) {
+            throw new UDFArgumentLengthException("Specify three arguments: " + OIs.length);
+        }
+
+        if (!HiveUtils.isNumberListOI(OIs[0])) {
+            throw new UDFArgumentTypeException(0,
+                "Only array<number> type argument is acceptable but " + 
OIs[0].getTypeName()
+                        + " was passed as `features`");
+        }
+        if (!HiveUtils.isNumberListOI(OIs[1])) {
+            throw new UDFArgumentTypeException(1,
+                "Only array<number> type argument is acceptable but " + 
OIs[1].getTypeName()
+                        + " was passed as `importance_list`");
+        }
+        if (!HiveUtils.isIntegerOI(OIs[2])) {
+            throw new UDFArgumentTypeException(2, "Only int type argument is acceptable but "
+                    + OIs[2].getTypeName() + " was passed as `k`");
+        }
+
+        this.featuresOI = HiveUtils.asListOI(OIs[0]);
+        this.featureOI = HiveUtils.asDoubleCompatibleOI(featuresOI.getListElementObjectInspector());
+        this.importanceListOI = HiveUtils.asListOI(OIs[1]);
+        this.importanceElemOI = HiveUtils.asDoubleCompatibleOI(importanceListOI.getListElementObjectInspector());
+
+        this._k = HiveUtils.getConstInt(OIs[2]);
+        Preconditions.checkArgument(_k >= 1, UDFArgumentException.class);
+        final DoubleWritable[] array = new DoubleWritable[_k];
+        for (int i = 0; i < array.length; i++) {
+            array[i] = new DoubleWritable();
+        }
+        this._result = Arrays.asList(array);
+
+        return ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector);
+    }
+
+    @Override
+    public List<DoubleWritable> evaluate(DeferredObject[] dObj) throws HiveException {
+        final double[] features = HiveUtils.asDoubleArray(dObj[0].get(), featuresOI, featureOI);
+        final double[] importanceList = HiveUtils.asDoubleArray(dObj[1].get(), importanceListOI,
+            importanceElemOI);
+
+        Preconditions.checkNotNull(features, UDFArgumentException.class);
+        Preconditions.checkNotNull(importanceList, UDFArgumentException.class);
+        Preconditions.checkArgument(features.length == importanceList.length,
+            UDFArgumentException.class);
+        Preconditions.checkArgument(features.length >= _k, UDFArgumentException.class);
+
+        int[] topKIndices = _topKIndices;
+        if (topKIndices == null) {
+            final List<Map.Entry<Integer, Double>> list = new ArrayList<Map.Entry<Integer, Double>>();
+            for (int i = 0; i < importanceList.length; i++) {
+                list.add(new AbstractMap.SimpleEntry<Integer, Double>(i, importanceList[i]));
+            }
+            Collections.sort(list, new Comparator<Map.Entry<Integer, Double>>() {
+                @Override
+                public int compare(Map.Entry<Integer, Double> o1, Map.Entry<Integer, Double> o2) {
+                    return Double.compare(o2.getValue(), o1.getValue()); // descending; consistent for ties
+                }
+            });
+
+            topKIndices = new int[_k];
+            for (int i = 0; i < topKIndices.length; i++) {
+                topKIndices[i] = list.get(i).getKey();
+            }
+            this._topKIndices = topKIndices;
+        }
+
+        final List<DoubleWritable> result = _result;
+        for (int i = 0; i < topKIndices.length; i++) {
+            int idx = topKIndices[i];
+            DoubleWritable d = result.get(i);
+            double f = features[idx];
+            d.set(f);
+        }
+        return result;
+    }
+
+    @Override
+    public void close() throws IOException {
+        // help GC
+        this._result = null;
+        this._topKIndices = null;
+    }
+
+    @Override
+    public String getDisplayString(String[] children) {
+        final StringBuilder sb = new StringBuilder();
+        sb.append("select_k_best");
+        sb.append("(");
+        if (children.length > 0) {
+            sb.append(children[0]);
+            for (int i = 1; i < children.length; i++) {
+                sb.append(", ");
+                sb.append(children[i]);
+            }
+        }
+        sb.append(")");
+        return sb.toString();
+    }
+}
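
Because `importance` and `k` are declared const, the UDF computes the sorted
index order once and caches it in `_topKIndices` for all subsequent rows. A
hedged standalone sketch of that top-k index selection; the names below are
illustrative:

    import java.util.ArrayList;
    import java.util.Collections;
    import java.util.Comparator;
    import java.util.List;

    public class TopKIndicesSketch {
        // returns the indices of the k largest importance scores, best first
        static int[] topK(final double[] importance, int k) {
            final List<Integer> indices = new ArrayList<Integer>();
            for (int i = 0; i < importance.length; i++) {
                indices.add(i);
            }
            Collections.sort(indices, new Comparator<Integer>() {
                @Override
                public int compare(Integer a, Integer b) {
                    return Double.compare(importance[b], importance[a]); // descending
                }
            });
            final int[] result = new int[k];
            for (int i = 0; i < k; i++) {
                result[i] = indices.get(i);
            }
            return result;
        }
    }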

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/fad2941f/core/src/main/java/hivemall/tools/matrix/TransposeAndDotUDAF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/tools/matrix/TransposeAndDotUDAF.java b/core/src/main/java/hivemall/tools/matrix/TransposeAndDotUDAF.java
new file mode 100644
index 0000000..440bbe6
--- /dev/null
+++ b/core/src/main/java/hivemall/tools/matrix/TransposeAndDotUDAF.java
@@ -0,0 +1,222 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.tools.matrix;
+
+import hivemall.utils.hadoop.HiveUtils;
+import hivemall.utils.hadoop.WritableUtils;
+import hivemall.utils.lang.Preconditions;
+import hivemall.utils.lang.SizeOf;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+
+import org.apache.hadoop.hive.ql.exec.Description;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentLengthException;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
+import org.apache.hadoop.hive.ql.metadata.HiveException;
+import org.apache.hadoop.hive.ql.parse.SemanticException;
+import org.apache.hadoop.hive.ql.udf.generic.AbstractGenericUDAFResolver;
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFParameterInfo;
+import org.apache.hadoop.hive.serde2.io.DoubleWritable;
+import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
+import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.DoubleObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
+
+@Description(
+        name = "transpose_and_dot",
+        value = "_FUNC_(array<number> matrix0_row, array<number> matrix1_row)"
+                + " - Returns dot(matrix0.T, matrix1) as array<array<double>>, 
shape = (matrix0.#cols, matrix1.#cols)")
+public final class TransposeAndDotUDAF extends AbstractGenericUDAFResolver {
+
+    @Override
+    public GenericUDAFEvaluator getEvaluator(GenericUDAFParameterInfo info)
+            throws SemanticException {
+        ObjectInspector[] OIs = info.getParameterObjectInspectors();
+
+        if (OIs.length != 2) {
+            throw new UDFArgumentLengthException("Specify two arguments.");
+        }
+
+        if (!HiveUtils.isNumberListOI(OIs[0])) {
+            throw new UDFArgumentTypeException(0,
+                "Only array<number> type argument is acceptable but " + 
OIs[0].getTypeName()
+                        + " was passed as `matrix0_row`");
+        }
+
+        if (!HiveUtils.isNumberListOI(OIs[1])) {
+            throw new UDFArgumentTypeException(1,
+                "Only array<number> type argument is acceptable but " + 
OIs[1].getTypeName()
+                        + " was passed as `matrix1_row`");
+        }
+
+        return new TransposeAndDotUDAFEvaluator();
+    }
+
+    static final class TransposeAndDotUDAFEvaluator extends GenericUDAFEvaluator {
+        // PARTIAL1 and COMPLETE
+        private ListObjectInspector matrix0RowOI;
+        private PrimitiveObjectInspector matrix0ElOI;
+        private ListObjectInspector matrix1RowOI;
+        private PrimitiveObjectInspector matrix1ElOI;
+
+        // PARTIAL2 and FINAL
+        private ListObjectInspector aggMatrixOI;
+        private ListObjectInspector aggMatrixRowOI;
+        private DoubleObjectInspector aggMatrixElOI;
+
+        private double[] matrix0Row;
+        private double[] matrix1Row;
+
+        @AggregationType(estimable = true)
+        static class TransposeAndDotAggregationBuffer extends AbstractAggregationBuffer {
+            double[][] aggMatrix;
+
+            @Override
+            public int estimate() {
+                return aggMatrix != null ? aggMatrix.length * aggMatrix[0].length * SizeOf.DOUBLE
+                        : 0;
+            }
+
+            public void init(int n, int m) {
+                this.aggMatrix = new double[n][m];
+            }
+
+            public void reset() {
+                if (aggMatrix != null) {
+                    for (double[] row : aggMatrix) {
+                        Arrays.fill(row, 0.d);
+                    }
+                }
+            }
+        }
+
+        @Override
+        public ObjectInspector init(Mode mode, ObjectInspector[] OIs) throws HiveException {
+            super.init(mode, OIs);
+
+            if (mode == Mode.PARTIAL1 || mode == Mode.COMPLETE) {
+                this.matrix0RowOI = HiveUtils.asListOI(OIs[0]);
+                this.matrix0ElOI = HiveUtils.asDoubleCompatibleOI(matrix0RowOI.getListElementObjectInspector());
+                this.matrix1RowOI = HiveUtils.asListOI(OIs[1]);
+                this.matrix1ElOI = HiveUtils.asDoubleCompatibleOI(matrix1RowOI.getListElementObjectInspector());
+            } else {
+                this.aggMatrixOI = HiveUtils.asListOI(OIs[0]);
+                this.aggMatrixRowOI = HiveUtils.asListOI(aggMatrixOI.getListElementObjectInspector());
+                this.aggMatrixElOI = HiveUtils.asDoubleOI(aggMatrixRowOI.getListElementObjectInspector());
+            }
+
+            return ObjectInspectorFactory.getStandardListObjectInspector(ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector));
+        }
+
+        @Override
+        public AbstractAggregationBuffer getNewAggregationBuffer() throws HiveException {
+            TransposeAndDotAggregationBuffer myAgg = new TransposeAndDotAggregationBuffer();
+            reset(myAgg);
+            return myAgg;
+        }
+
+        @Override
+        public void reset(@SuppressWarnings("deprecation") AggregationBuffer agg)
+                throws HiveException {
+            TransposeAndDotAggregationBuffer myAgg = (TransposeAndDotAggregationBuffer) agg;
+            myAgg.reset();
+        }
+
+        @Override
+        public void iterate(@SuppressWarnings("deprecation") AggregationBuffer agg,
+                Object[] parameters) throws HiveException {
+            final Object matrix0RowObj = parameters[0];
+            final Object matrix1RowObj = parameters[1];
+
+            Preconditions.checkNotNull(matrix0RowObj, UDFArgumentException.class);
+            Preconditions.checkNotNull(matrix1RowObj, UDFArgumentException.class);
+
+            final TransposeAndDotAggregationBuffer myAgg = (TransposeAndDotAggregationBuffer) agg;
+
+            if (matrix0Row == null) {
+                matrix0Row = new double[matrix0RowOI.getListLength(matrix0RowObj)];
+            }
+            if (matrix1Row == null) {
+                matrix1Row = new double[matrix1RowOI.getListLength(matrix1RowObj)];
+            }
+
+            HiveUtils.toDoubleArray(matrix0RowObj, matrix0RowOI, matrix0ElOI, matrix0Row, false);
+            HiveUtils.toDoubleArray(matrix1RowObj, matrix1RowOI, matrix1ElOI, matrix1Row, false);
+
+            if (myAgg.aggMatrix == null) {
+                myAgg.init(matrix0Row.length, matrix1Row.length);
+            }
+
+            for (int i = 0; i < matrix0Row.length; i++) {
+                for (int j = 0; j < matrix1Row.length; j++) {
+                    myAgg.aggMatrix[i][j] += matrix0Row[i] * matrix1Row[j];
+                }
+            }
+        }
+
+        @Override
+        public void merge(@SuppressWarnings("deprecation") AggregationBuffer agg, Object other)
+                throws HiveException {
+            if (other == null) {
+                return;
+            }
+
+            final TransposeAndDotAggregationBuffer myAgg = (TransposeAndDotAggregationBuffer) agg;
+
+            final List<?> matrix = aggMatrixOI.getList(other);
+            final int n = matrix.size();
+            final double[] row = new double[aggMatrixRowOI.getListLength(matrix.get(0))];
+            for (int i = 0; i < n; i++) {
+                HiveUtils.toDoubleArray(matrix.get(i), aggMatrixRowOI, aggMatrixElOI, row, false);
+
+                if (myAgg.aggMatrix == null) {
+                    myAgg.init(n, row.length);
+                }
+
+                for (int j = 0; j < row.length; j++) {
+                    myAgg.aggMatrix[i][j] += row[j];
+                }
+            }
+        }
+
+        @Override
+        public Object terminatePartial(@SuppressWarnings("deprecation") AggregationBuffer agg)
+                throws HiveException {
+            return terminate(agg);
+        }
+
+        @Override
+        public Object terminate(@SuppressWarnings("deprecation") AggregationBuffer agg)
+                throws HiveException {
+            final TransposeAndDotAggregationBuffer myAgg = (TransposeAndDotAggregationBuffer) agg;
+
+            final List<List<DoubleWritable>> result = new ArrayList<List<DoubleWritable>>();
+            for (double[] row : myAgg.aggMatrix) {
+                result.add(WritableUtils.toWritableList(row));
+            }
+            return result;
+        }
+    }
+}
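
What the UDAF accumulates row by row is the sum of outer products of paired
rows, which equals dot(matrix0.T, matrix1). A hedged sketch with illustrative
names, covering the whole aggregation in one method:

    // Summing the outer product of each row pair yields matrix0.T x matrix1,
    // with shape (matrix0 #cols, matrix1 #cols), matching aggMatrix above.
    static double[][] transposeAndDot(double[][] matrix0, double[][] matrix1) {
        final double[][] acc = new double[matrix0[0].length][matrix1[0].length];
        for (int row = 0; row < matrix0.length; row++) { // one iterate() per row pair
            for (int i = 0; i < matrix0[row].length; i++) {
                for (int j = 0; j < matrix1[row].length; j++) {
                    acc[i][j] += matrix0[row][i] * matrix1[row][j];
                }
            }
        }
        return acc;
    }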

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/fad2941f/core/src/main/java/hivemall/utils/hadoop/HiveUtils.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/utils/hadoop/HiveUtils.java b/core/src/main/java/hivemall/utils/hadoop/HiveUtils.java
index d8b1aef..8188b7a 100644
--- a/core/src/main/java/hivemall/utils/hadoop/HiveUtils.java
+++ b/core/src/main/java/hivemall/utils/hadoop/HiveUtils.java
@@ -59,6 +59,7 @@ import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
 import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
 import org.apache.hadoop.hive.serde2.objectinspector.primitive.BinaryObjectInspector;
 import org.apache.hadoop.hive.serde2.objectinspector.primitive.BooleanObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.DoubleObjectInspector;
 import org.apache.hadoop.hive.serde2.objectinspector.primitive.IntObjectInspector;
 import org.apache.hadoop.hive.serde2.objectinspector.primitive.LongObjectInspector;
 import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
@@ -200,8 +201,7 @@ public final class HiveUtils {
         return BOOLEAN_TYPE_NAME.equals(typeName);
     }
 
-    public static boolean isNumberOI(@Nonnull final ObjectInspector argOI)
-            throws UDFArgumentTypeException {
+    public static boolean isNumberOI(@Nonnull final ObjectInspector argOI) {
         if (argOI.getCategory() != Category.PRIMITIVE) {
             return false;
         }
@@ -246,6 +246,16 @@ public final class HiveUtils {
         return oi.getCategory() == Category.MAP;
     }
 
+    public static boolean isNumberListOI(@Nonnull final ObjectInspector oi) {
+        return isListOI(oi)
+                && isNumberOI(((ListObjectInspector) oi).getListElementObjectInspector());
+    }
+
+    public static boolean isNumberListListOI(@Nonnull final ObjectInspector oi) {
+        return isListOI(oi)
+                && isNumberListOI(((ListObjectInspector) oi).getListElementObjectInspector());
+    }
+
     public static boolean isPrimitiveTypeInfo(@Nonnull TypeInfo typeInfo) {
         return typeInfo.getCategory() == ObjectInspector.Category.PRIMITIVE;
     }
@@ -687,6 +697,14 @@ public final class HiveUtils {
         return (LongObjectInspector) argOI;
     }
 
+    public static DoubleObjectInspector asDoubleOI(@Nonnull final ObjectInspector argOI)
+            throws UDFArgumentException {
+        if (!DOUBLE_TYPE_NAME.equals(argOI.getTypeName())) {
+            throw new UDFArgumentException("Argument type must be DOUBLE: " + argOI.getTypeName());
+        }
+        return (DoubleObjectInspector) argOI;
+    }
+
     public static PrimitiveObjectInspector asIntCompatibleOI(@Nonnull final ObjectInspector argOI)
             throws UDFArgumentTypeException {
         if (argOI.getCategory() != Category.PRIMITIVE) {

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/fad2941f/core/src/main/java/hivemall/utils/hadoop/WritableUtils.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/utils/hadoop/WritableUtils.java b/core/src/main/java/hivemall/utils/hadoop/WritableUtils.java
index a4f2691..a9c7390 100644
--- a/core/src/main/java/hivemall/utils/hadoop/WritableUtils.java
+++ b/core/src/main/java/hivemall/utils/hadoop/WritableUtils.java
@@ -25,7 +25,9 @@ import java.util.List;
 
 import javax.annotation.CheckForNull;
 import javax.annotation.Nonnull;
+import javax.annotation.Nullable;
 
+import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
 import org.apache.hadoop.hive.serde2.io.DoubleWritable;
 import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils;
 import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils.ObjectInspectorCopyOption;
@@ -142,6 +144,20 @@ public final class WritableUtils {
         return list;
     }
 
+    @Nonnull
+    public static List<DoubleWritable> toWritableList(@Nonnull final double[] src,
+            @Nullable List<DoubleWritable> list) throws UDFArgumentException {
+        if (list == null) {
+            return toWritableList(src);
+        }
+
+        Preconditions.checkArgument(src.length == list.size(), UDFArgumentException.class);
+        for (int i = 0; i < src.length; i++) {
+            list.set(i, new DoubleWritable(src[i]));
+        }
+        return list;
+    }
+
     public static Text val(final String v) {
         return new Text(v);
     }
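
The new overload lets a caller hand back the list returned by a previous call
so it can be refilled instead of reallocated per row, as ChiSquareUDF does
with its cached `result` arrays. A hedged sketch of the call pattern; the
class and field names are illustrative:

    import java.util.List;
    import hivemall.utils.hadoop.WritableUtils;
    import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
    import org.apache.hadoop.hive.serde2.io.DoubleWritable;

    public class ReuseSketch {
        private List<DoubleWritable> cached; // reused across evaluate() calls

        List<DoubleWritable> fill(double[] scores) throws UDFArgumentException {
            // first call allocates; later calls refill the same list instance
            this.cached = WritableUtils.toWritableList(scores, cached);
            return cached;
        }
    }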

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/fad2941f/core/src/main/java/hivemall/utils/lang/Preconditions.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/utils/lang/Preconditions.java b/core/src/main/java/hivemall/utils/lang/Preconditions.java
index 4fa2bdd..9f76bd6 100644
--- a/core/src/main/java/hivemall/utils/lang/Preconditions.java
+++ b/core/src/main/java/hivemall/utils/lang/Preconditions.java
@@ -18,6 +18,7 @@
  */
 package hivemall.utils.lang;
 
+import javax.annotation.Nonnull;
 import javax.annotation.Nullable;
 
 public final class Preconditions {
@@ -31,6 +32,21 @@ public final class Preconditions {
         return reference;
     }
 
+    public static <T, E extends Throwable> T checkNotNull(@Nullable T reference,
+            @Nonnull Class<E> clazz) throws E {
+        if (reference == null) {
+            final E throwable;
+            try {
+                throwable = clazz.newInstance();
+            } catch (InstantiationException | IllegalAccessException e) {
+                throw new IllegalStateException(
+                    "Failed to instantiate a class: " + clazz.getName(), e);
+            }
+            throw throwable;
+        }
+        return reference;
+    }
+
     public static <T> T checkNotNull(T reference, @Nullable Object errorMessage) {
         if (reference == null) {
             throw new NullPointerException(String.valueOf(errorMessage));
@@ -44,6 +60,20 @@ public final class Preconditions {
         }
     }
 
+    public static <E extends Throwable> void checkArgument(boolean expression,
+            @Nonnull Class<E> clazz) throws E {
+        if (!expression) {
+            final E throwable;
+            try {
+                throwable = clazz.newInstance();
+            } catch (InstantiationException | IllegalAccessException e) {
+                throw new IllegalStateException(
+                    "Failed to instantiate a class: " + clazz.getName(), e);
+            }
+            throw throwable;
+        }
+    }
+
     public static void checkArgument(boolean expression, @Nullable Object errorMessage) {
         if (!expression) {
             throw new IllegalArgumentException(String.valueOf(errorMessage));
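
These overloads let validation code throw a caller-chosen exception type
instead of NullPointerException or IllegalArgumentException; the supplied
class needs a public no-arg constructor for newInstance() to succeed. A
hedged usage sketch mirroring the call sites in ChiSquareUDF:

    import hivemall.utils.lang.Preconditions;
    import org.apache.hadoop.hive.ql.exec.UDFArgumentException;

    public class CheckSketch {
        static void validate(Object row, int nClasses) throws UDFArgumentException {
            // fails with UDFArgumentException, which Hive reports to the user
            Preconditions.checkNotNull(row, UDFArgumentException.class);
            Preconditions.checkArgument(nClasses >= 2, UDFArgumentException.class);
        }
    }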

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/fad2941f/core/src/main/java/hivemall/utils/math/StatsUtils.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/utils/math/StatsUtils.java b/core/src/main/java/hivemall/utils/math/StatsUtils.java
index 812f619..599bf51 100644
--- a/core/src/main/java/hivemall/utils/math/StatsUtils.java
+++ b/core/src/main/java/hivemall/utils/math/StatsUtils.java
@@ -22,11 +22,19 @@ import hivemall.utils.lang.Preconditions;
 
 import javax.annotation.Nonnull;
 
+import org.apache.commons.math3.distribution.ChiSquaredDistribution;
+import org.apache.commons.math3.exception.DimensionMismatchException;
+import org.apache.commons.math3.exception.NotPositiveException;
 import org.apache.commons.math3.linear.DecompositionSolver;
 import org.apache.commons.math3.linear.LUDecomposition;
 import org.apache.commons.math3.linear.RealMatrix;
 import org.apache.commons.math3.linear.RealVector;
 import org.apache.commons.math3.linear.SingularValueDecomposition;
+import org.apache.commons.math3.util.FastMath;
+import org.apache.commons.math3.util.MathArrays;
+
+import java.util.AbstractMap;
+import java.util.Map;
 
 public final class StatsUtils {
 
@@ -189,4 +197,87 @@ public final class StatsUtils {
         return 1.d - numerator / denominator;
     }
 
+    /**
+     * @param observed non-negative vector of observed counts
+     * @param expected positive vector of expected counts
+     * @return chi2 value
+     */
+    public static double chiSquare(@Nonnull final double[] observed,
+            @Nonnull final double[] expected) {
+        if (observed.length < 2) {
+            throw new DimensionMismatchException(observed.length, 2);
+        }
+        if (expected.length != observed.length) {
+            throw new DimensionMismatchException(observed.length, expected.length);
+        }
+        MathArrays.checkPositive(expected);
+        for (double d : observed) {
+            if (d < 0.d) {
+                throw new NotPositiveException(d);
+            }
+        }
+
+        double sumObserved = 0.d;
+        double sumExpected = 0.d;
+        for (int i = 0; i < observed.length; i++) {
+            sumObserved += observed[i];
+            sumExpected += expected[i];
+        }
+        double ratio = 1.d;
+        boolean rescale = false;
+        if (FastMath.abs(sumObserved - sumExpected) > 10e-6) {
+            ratio = sumObserved / sumExpected;
+            rescale = true;
+        }
+        double sumSq = 0.d;
+        for (int i = 0; i < observed.length; i++) {
+            if (rescale) {
+                final double dev = observed[i] - ratio * expected[i];
+                sumSq += dev * dev / (ratio * expected[i]);
+            } else {
+                final double dev = observed[i] - expected[i];
+                sumSq += dev * dev / expected[i];
+            }
+        }
+        return sumSq;
+    }
+
+    /**
+     * @param observed the observed counts; must be a non-negative vector
+     * @param expected the expected counts; must be a positive vector
+     * @return the p-value
+     */
+    public static double chiSquareTest(@Nonnull final double[] observed,
+            @Nonnull final double[] expected) {
+        final ChiSquaredDistribution distribution = new ChiSquaredDistribution(
+            expected.length - 1.d);
+        return 1.d - distribution.cumulativeProbability(chiSquare(observed, expected));
+    }
+
+    /**
+     * Computes chi-square statistics and p-values for multiple rows at once, which is more
+     * efficient than testing each row individually because the chi-squared distribution is
+     * instantiated only once.
+     *
+     * @param observeds the observed counts; must be a non-negative matrix
+     * @param expecteds the expected counts; must be a positive matrix
+     * @return a pair of (chi-square values[], p-values[])
+     */
+    public static Map.Entry<double[], double[]> chiSquare(@Nonnull final double[][] observeds,
+            @Nonnull final double[][] expecteds) {
+        Preconditions.checkArgument(observeds.length == expecteds.length);
+
+        final int len = expecteds.length;
+        final int lenOfEach = expecteds[0].length;
+
+        final ChiSquaredDistribution distribution = new ChiSquaredDistribution(lenOfEach - 1.d);
+
+        final double[] chi2s = new double[len];
+        final double[] ps = new double[len];
+        for (int i = 0; i < len; i++) {
+            chi2s[i] = chiSquare(observeds[i], expecteds[i]);
+            ps[i] = 1.d - distribution.cumulativeProbability(chi2s[i]);
+        }
+
+        return new AbstractMap.SimpleEntry<double[], double[]>(chi2s, ps);
+    }
 }
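
For orientation, a minimal sketch of calling the batch form above; the input values are made up for illustration:

    // Two rows of observed/expected counts; each row yields one chi-square
    // statistic and one p-value with (row length - 1) degrees of freedom.
    final double[][] observeds = new double[][] { {10.d, 20.d, 30.d}, {5.d, 15.d, 40.d}};
    final double[][] expecteds = new double[][] { {20.d, 20.d, 20.d}, {20.d, 20.d, 20.d}};
    final Map.Entry<double[], double[]> result = StatsUtils.chiSquare(observeds, expecteds);
    final double[] chi2s = result.getKey();   // one chi-square statistic per row
    final double[] ps = result.getValue();    // the corresponding p-values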

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/fad2941f/core/src/test/java/hivemall/ftvec/selection/ChiSquareUDFTest.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/hivemall/ftvec/selection/ChiSquareUDFTest.java b/core/src/test/java/hivemall/ftvec/selection/ChiSquareUDFTest.java
new file mode 100644
index 0000000..fd742bb
--- /dev/null
+++ b/core/src/test/java/hivemall/ftvec/selection/ChiSquareUDFTest.java
@@ -0,0 +1,82 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.ftvec.selection;
+
+import hivemall.utils.hadoop.WritableUtils;
+
+import java.util.ArrayList;
+import java.util.List;
+
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDF;
+import org.apache.hadoop.hive.serde2.io.DoubleWritable;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
+import org.junit.Assert;
+import org.junit.Test;
+
+public class ChiSquareUDFTest {
+
+    @Test
+    public void testIris() throws Exception {
+        final ChiSquareUDF chi2 = new ChiSquareUDF();
+        final List<List<DoubleWritable>> observed = new ArrayList<List<DoubleWritable>>();
+        final List<List<DoubleWritable>> expected = new ArrayList<List<DoubleWritable>>();
+        final GenericUDF.DeferredObject[] dObjs = new GenericUDF.DeferredObject[] {
+                new GenericUDF.DeferredJavaObject(observed),
+                new GenericUDF.DeferredJavaObject(expected)};
+
+        final double[][] matrix0 = new double[][] {
+                {250.29999999999998, 170.90000000000003, 73.2, 12.199999999999996},
+                {296.8, 138.50000000000003, 212.99999999999997, 66.3},
+                {329.3999999999999, 148.7, 277.59999999999997, 101.29999999999998}};
+        final double[][] matrix1 = new double[][] {
+                {292.1666753739119, 152.70000455081467, 187.93333893418327, 59.93333511948589},
+                {292.1666753739119, 152.70000455081467, 187.93333893418327, 59.93333511948589},
+                {292.1666753739119, 152.70000455081467, 187.93333893418327, 59.93333511948589}};
+
+        for (double[] row : matrix0) {
+            observed.add(WritableUtils.toWritableList(row));
+        }
+        for (double[] row : matrix1) {
+            expected.add(WritableUtils.toWritableList(row));
+        }
+
+        chi2.initialize(new ObjectInspector[] {
+                ObjectInspectorFactory.getStandardListObjectInspector(ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector)),
+                ObjectInspectorFactory.getStandardListObjectInspector(ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector))});
+        final List<DoubleWritable>[] result = chi2.evaluate(dObjs);
+        final double[] result0 = new double[matrix0[0].length];
+        final double[] result1 = new double[matrix0[0].length];
+        for (int i = 0; i < result0.length; i++) {
+            result0[i] = result[0].get(i).get();
+            result1[i] = result[1].get(i).get();
+        }
+
+        // compare the results with those of scikit-learn
+        final double[] answer0 = new double[] {10.81782088, 3.59449902, 116.16984746, 67.24482759};
+        final double[] answer1 = new double[] {4.47651499e-03, 1.65754167e-01, 5.94344354e-26,
+                2.50017968e-15};
+
+        Assert.assertArrayEquals(answer0, result0, 1e-5);
+        Assert.assertArrayEquals(answer1, result1, 1e-5);
+        chi2.close();
+    }
+
+}

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/fad2941f/core/src/test/java/hivemall/ftvec/selection/SignalNoiseRatioUDAFTest.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/hivemall/ftvec/selection/SignalNoiseRatioUDAFTest.java b/core/src/test/java/hivemall/ftvec/selection/SignalNoiseRatioUDAFTest.java
new file mode 100644
index 0000000..79570e3
--- /dev/null
+++ b/core/src/test/java/hivemall/ftvec/selection/SignalNoiseRatioUDAFTest.java
@@ -0,0 +1,342 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.ftvec.selection;
+
+import hivemall.utils.hadoop.WritableUtils;
+
+import java.util.ArrayList;
+import java.util.List;
+
+import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
+import org.apache.hadoop.hive.ql.udf.generic.SimpleGenericUDAFParameterInfo;
+import org.apache.hadoop.hive.serde2.io.DoubleWritable;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
+import org.apache.hadoop.io.IntWritable;
+import org.junit.Assert;
+import org.junit.Test;
+
+public class SignalNoiseRatioUDAFTest {
+
+    @Test
+    public void snrBinaryClass() throws Exception {
+        // this test is based on a *subset* of the iris data set
+        final SignalNoiseRatioUDAF snr = new SignalNoiseRatioUDAF();
+        final ObjectInspector[] OIs = new ObjectInspector[] {
+                ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector),
+                ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableIntObjectInspector)};
+        final SignalNoiseRatioUDAF.SignalNoiseRatioUDAFEvaluator evaluator = (SignalNoiseRatioUDAF.SignalNoiseRatioUDAFEvaluator) snr.getEvaluator(new SimpleGenericUDAFParameterInfo(
+            OIs, false, false));
+        evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, OIs);
+        final SignalNoiseRatioUDAF.SignalNoiseRatioUDAFEvaluator.SignalNoiseRatioAggregationBuffer agg = (SignalNoiseRatioUDAF.SignalNoiseRatioUDAFEvaluator.SignalNoiseRatioAggregationBuffer) evaluator.getNewAggregationBuffer();
+        evaluator.reset(agg);
+
+        final double[][] features = new double[][] { {5.1, 3.5, 1.4, 0.2}, {4.9, 3.d, 1.4, 0.2},
+                {4.7, 3.2, 1.3, 0.2}, {7.d, 3.2, 4.7, 1.4}, {6.4, 3.2, 4.5, 1.5},
+                {6.9, 3.1, 4.9, 1.5}};
+
+        final int[][] labels = new int[][] { {1, 0}, {1, 0}, {1, 0}, {0, 1}, {0, 1}, {0, 1}};
+
+        for (int i = 0; i < features.length; i++) {
+            final List<IntWritable> labelList = new ArrayList<IntWritable>();
+            for (int label : labels[i]) {
+                labelList.add(new IntWritable(label));
+            }
+            evaluator.iterate(agg, new Object[] {WritableUtils.toWritableList(features[i]),
+                    labelList});
+        }
+
+        @SuppressWarnings("unchecked")
+        final List<DoubleWritable> resultObj = (List<DoubleWritable>) evaluator.terminate(agg);
+        final int size = resultObj.size();
+        final double[] result = new double[size];
+        for (int i = 0; i < size; i++) {
+            result[i] = resultObj.get(i).get();
+        }
+
+        // compare with the results computed by numpy
+        final double[] answer = new double[] {4.38425236, 0.26390002, 15.83984511, 26.87005769};
+
+        Assert.assertArrayEquals(answer, result, 1e-5);
+    }
+
+    @Test
+    public void snrMultipleClassNormalCase() throws Exception {
+        // this test is based on a *subset* of the iris data set
+        final SignalNoiseRatioUDAF snr = new SignalNoiseRatioUDAF();
+        final ObjectInspector[] OIs = new ObjectInspector[] {
+                ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector),
+                ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableIntObjectInspector)};
+        final SignalNoiseRatioUDAF.SignalNoiseRatioUDAFEvaluator evaluator = (SignalNoiseRatioUDAF.SignalNoiseRatioUDAFEvaluator) snr.getEvaluator(new SimpleGenericUDAFParameterInfo(
+            OIs, false, false));
+        evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, OIs);
+        final SignalNoiseRatioUDAF.SignalNoiseRatioUDAFEvaluator.SignalNoiseRatioAggregationBuffer agg = (SignalNoiseRatioUDAF.SignalNoiseRatioUDAFEvaluator.SignalNoiseRatioAggregationBuffer) evaluator.getNewAggregationBuffer();
+        evaluator.reset(agg);
+
+        final double[][] features = new double[][] { {5.1, 3.5, 1.4, 0.2}, {4.9, 3.d, 1.4, 0.2},
+                {7.d, 3.2, 4.7, 1.4}, {6.4, 3.2, 4.5, 1.5}, {6.3, 3.3, 6.d, 2.5},
+                {5.8, 2.7, 5.1, 1.9}};
+
+        final int[][] labels = new int[][] { {1, 0, 0}, {1, 0, 0}, {0, 1, 0}, {0, 1, 0}, {0, 0, 1},
+                {0, 0, 1}};
+
+        for (int i = 0; i < features.length; i++) {
+            final List<IntWritable> labelList = new ArrayList<IntWritable>();
+            for (int label : labels[i]) {
+                labelList.add(new IntWritable(label));
+            }
+            evaluator.iterate(agg, new Object[] {WritableUtils.toWritableList(features[i]),
+                    labelList});
+        }
+
+        @SuppressWarnings("unchecked")
+        final List<DoubleWritable> resultObj = (List<DoubleWritable>) evaluator.terminate(agg);
+        final int size = resultObj.size();
+        final double[] result = new double[size];
+        for (int i = 0; i < size; i++) {
+            result[i] = resultObj.get(i).get();
+        }
+
+        // compare with the results computed by scikit-learn
+        final double[] answer = new double[] {8.43181818, 1.32121212, 42.94949495, 33.80952381};
+
+        Assert.assertArrayEquals(answer, result, 1e-5);
+    }
+
+    @Test
+    public void snrMultipleClassCornerCase0() throws Exception {
+        final SignalNoiseRatioUDAF snr = new SignalNoiseRatioUDAF();
+        final ObjectInspector[] OIs = new ObjectInspector[] {
+                ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector),
+                ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableIntObjectInspector)};
+        final SignalNoiseRatioUDAF.SignalNoiseRatioUDAFEvaluator evaluator = (SignalNoiseRatioUDAF.SignalNoiseRatioUDAFEvaluator) snr.getEvaluator(new SimpleGenericUDAFParameterInfo(
+            OIs, false, false));
+        evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, OIs);
+        final SignalNoiseRatioUDAF.SignalNoiseRatioUDAFEvaluator.SignalNoiseRatioAggregationBuffer agg = (SignalNoiseRatioUDAF.SignalNoiseRatioUDAFEvaluator.SignalNoiseRatioAggregationBuffer) evaluator.getNewAggregationBuffer();
+        evaluator.reset(agg);
+
+        // feature [0] takes the same value in every row of class c0 and class c1
+        // feature [1] takes the same value in every row of class c1 and class c2
+        // feature [2] is identical across all rows
+        // feature [3] differs in every row
+        final double[][] features = new double[][] { {3.5, 1.4, 0.3, 5.1}, {3.5, 1.5, 0.3, 5.2},
+                {3.5, 4.5, 0.3, 7.d}, {3.5, 4.5, 0.3, 6.4}, {3.3, 4.5, 0.3, 6.3}};
+
+        final int[][] labels = new int[][] { {1, 0, 0}, {1, 0, 0}, // class `0`
+                {0, 1, 0}, {0, 1, 0}, // class `1`
+                {0, 0, 1}}; // class `2`, only a single entry
+
+        for (int i = 0; i < features.length; i++) {
+            final List<IntWritable> labelList = new ArrayList<IntWritable>();
+            for (int label : labels[i]) {
+                labelList.add(new IntWritable(label));
+            }
+            evaluator.iterate(agg, new Object[] {WritableUtils.toWritableList(features[i]),
+                    labelList});
+        }
+
+        @SuppressWarnings("unchecked")
+        final List<DoubleWritable> resultObj = (List<DoubleWritable>) evaluator.terminate(agg);
+        final int size = resultObj.size();
+        final double[] result = new double[size];
+        for (int i = 0; i < size; i++) {
+            result[i] = resultObj.get(i).get();
+        }
+
+        final double[] answer = new double[] {Double.POSITIVE_INFINITY, 121.99999999999989, 0.d,
+                28.761904761904734};
+
+        Assert.assertArrayEquals(answer, result, 1e-5);
+    }
+
+    @Test
+    public void snrMultipleClassCornerCase1() throws Exception {
+        final SignalNoiseRatioUDAF snr = new SignalNoiseRatioUDAF();
+        final ObjectInspector[] OIs = new ObjectInspector[] {
+                ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector),
+                ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableIntObjectInspector)};
+        final SignalNoiseRatioUDAF.SignalNoiseRatioUDAFEvaluator evaluator = (SignalNoiseRatioUDAF.SignalNoiseRatioUDAFEvaluator) snr.getEvaluator(new SimpleGenericUDAFParameterInfo(
+            OIs, false, false));
+        evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, OIs);
+        final SignalNoiseRatioUDAF.SignalNoiseRatioUDAFEvaluator.SignalNoiseRatioAggregationBuffer agg = (SignalNoiseRatioUDAF.SignalNoiseRatioUDAFEvaluator.SignalNoiseRatioAggregationBuffer) evaluator.getNewAggregationBuffer();
+        evaluator.reset(agg);
+
+        final double[][] features = new double[][] { {5.1, 3.5, 1.4, 0.2}, {4.9, 3.d, 1.4, 0.2},
+                {7.d, 3.2, 4.7, 1.4}, {6.3, 3.3, 6.d, 2.5}, {6.4, 3.2, 4.5, 1.5}};
+
+        // multiple classes have only a single entry
+        final int[][] labels = new int[][] { {1, 0, 0}, {1, 0, 0}, {1, 0, 0}, // class `0`
+                {0, 1, 0}, // class `1`, only a single entry
+                {0, 0, 1}}; // class `2`, only a single entry
+
+        for (int i = 0; i < features.length; i++) {
+            final List<IntWritable> labelList = new ArrayList<IntWritable>();
+            for (int label : labels[i]) {
+                labelList.add(new IntWritable(label));
+            }
+            evaluator.iterate(agg, new Object[] {WritableUtils.toWritableList(features[i]),
+                    labelList});
+        }
+
+        @SuppressWarnings("unchecked")
+        final List<DoubleWritable> resultObj = (List<DoubleWritable>) evaluator.terminate(agg);
+        final List<Double> result = new ArrayList<Double>();
+        for (DoubleWritable dw : resultObj) {
+            result.add(dw.get());
+        }
+
+        Assert.assertFalse(result.contains(Double.POSITIVE_INFINITY));
+    }
+
+    @Test
+    public void snrMultipleClassCornerCase2() throws Exception {
+        final SignalNoiseRatioUDAF snr = new SignalNoiseRatioUDAF();
+        final ObjectInspector[] OIs = new ObjectInspector[] {
+                ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector),
+                ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableIntObjectInspector)};
+        final SignalNoiseRatioUDAF.SignalNoiseRatioUDAFEvaluator evaluator = (SignalNoiseRatioUDAF.SignalNoiseRatioUDAFEvaluator) snr.getEvaluator(new SimpleGenericUDAFParameterInfo(
+            OIs, false, false));
+        evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, OIs);
+        final SignalNoiseRatioUDAF.SignalNoiseRatioUDAFEvaluator.SignalNoiseRatioAggregationBuffer agg = (SignalNoiseRatioUDAF.SignalNoiseRatioUDAFEvaluator.SignalNoiseRatioAggregationBuffer) evaluator.getNewAggregationBuffer();
+        evaluator.reset(agg);
+
+        // feature [0] takes the same value in every row
+        // feature [1] takes the same value within each class
+        final double[][] features = new double[][] { {1.d, 1.d, 1.4, 0.2}, {1.d, 1.d, 1.4, 0.2},
+                {1.d, 2.d, 4.7, 1.4}, {1.d, 2.d, 4.5, 1.5}, {1.d, 3.d, 6.d, 2.5},
+                {1.d, 3.d, 5.1, 1.9}};
+
+        final int[][] labels = new int[][] { {1, 0, 0}, {1, 0, 0}, {0, 1, 0}, {0, 1, 0}, {0, 0, 1},
+                {0, 0, 1}};
+
+        for (int i = 0; i < features.length; i++) {
+            final List<IntWritable> labelList = new ArrayList<IntWritable>();
+            for (int label : labels[i]) {
+                labelList.add(new IntWritable(label));
+            }
+            evaluator.iterate(agg, new Object[] {WritableUtils.toWritableList(features[i]),
+                    labelList});
+        }
+
+        @SuppressWarnings("unchecked")
+        final List<DoubleWritable> resultObj = (List<DoubleWritable>) evaluator.terminate(agg);
+        final int size = resultObj.size();
+        final double[] result = new double[size];
+        for (int i = 0; i < size; i++) {
+            result[i] = resultObj.get(i).get();
+        }
+
+        final double[] answer = new double[] {0.d, Double.POSITIVE_INFINITY, 42.94949495,
+                33.80952381};
+
+        Assert.assertArrayEquals(answer, result, 1e-5);
+    }
+
+    @Test(expected = UDFArgumentException.class)
+    public void shouldFail0() throws Exception {
+        final SignalNoiseRatioUDAF snr = new SignalNoiseRatioUDAF();
+
+        final ObjectInspector[] OIs = new ObjectInspector[] {
+                ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector),
+                ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableIntObjectInspector)};
+        final SignalNoiseRatioUDAF.SignalNoiseRatioUDAFEvaluator evaluator = (SignalNoiseRatioUDAF.SignalNoiseRatioUDAFEvaluator) snr.getEvaluator(new SimpleGenericUDAFParameterInfo(
+            OIs, false, false));
+        evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, OIs);
+        final SignalNoiseRatioUDAF.SignalNoiseRatioUDAFEvaluator.SignalNoiseRatioAggregationBuffer agg = (SignalNoiseRatioUDAF.SignalNoiseRatioUDAFEvaluator.SignalNoiseRatioAggregationBuffer) evaluator.getNewAggregationBuffer();
+        evaluator.reset(agg);
+
+        final double[][] featuress = new double[][] { {5.1, 3.5, 1.4, 0.2}, {4.9, 3.d, 1.4, 0.2},
+                {7.d, 3.2, 4.7, 1.4}, {6.4, 3.2, 4.5, 1.5}, {6.3, 3.3, 6.d, 2.5},
+                {5.8, 2.7, 5.1, 1.9}};
+
+        final int[][] labelss = new int[][] { {0, 0, 0}, // causes UDFArgumentException
+                {1, 0, 0}, {0, 1, 0}, {0, 1, 0}, {0, 0, 1}, {0, 0, 1}};
+
+        for (int i = 0; i < featuress.length; i++) {
+            final List<IntWritable> labels = new ArrayList<IntWritable>();
+            for (int label : labelss[i]) {
+                labels.add(new IntWritable(label));
+            }
+            evaluator.iterate(agg,
+                new Object[] {WritableUtils.toWritableList(featuress[i]), labels});
+        }
+    }
+
+    @Test(expected = UDFArgumentException.class)
+    public void shouldFail1() throws Exception {
+        final SignalNoiseRatioUDAF snr = new SignalNoiseRatioUDAF();
+
+        final ObjectInspector[] OIs = new ObjectInspector[] {
+                ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector),
+                ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableIntObjectInspector)};
+        final SignalNoiseRatioUDAF.SignalNoiseRatioUDAFEvaluator evaluator = (SignalNoiseRatioUDAF.SignalNoiseRatioUDAFEvaluator) snr.getEvaluator(new SimpleGenericUDAFParameterInfo(
+            OIs, false, false));
+        evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, OIs);
+        final SignalNoiseRatioUDAF.SignalNoiseRatioUDAFEvaluator.SignalNoiseRatioAggregationBuffer agg = (SignalNoiseRatioUDAF.SignalNoiseRatioUDAFEvaluator.SignalNoiseRatioAggregationBuffer) evaluator.getNewAggregationBuffer();
+        evaluator.reset(agg);
+
+        final double[][] featuress = new double[][] { {5.1, 3.5, 1.4, 0.2},
+                {4.9, 3.d, 1.4}, // wrong length; causes UDFArgumentException
+                {7.d, 3.2, 4.7, 1.4}, {6.4, 3.2, 4.5, 1.5}, {6.3, 3.3, 6.d, 2.5},
+                {5.8, 2.7, 5.1, 1.9}};
+
+        final int[][] labelss = new int[][] { {1, 0, 0}, {1, 0, 0}, {0, 1, 0}, {0, 1, 0},
+                {0, 0, 1}, {0, 0, 1}};
+
+        for (int i = 0; i < featuress.length; i++) {
+            final List<IntWritable> labels = new ArrayList<IntWritable>();
+            for (int label : labelss[i]) {
+                labels.add(new IntWritable(label));
+            }
+            evaluator.iterate(agg,
+                new Object[] {WritableUtils.toWritableList(featuress[i]), labels});
+        }
+    }
+
+    @Test(expected = UDFArgumentException.class)
+    public void shouldFail2() throws Exception {
+        final SignalNoiseRatioUDAF snr = new SignalNoiseRatioUDAF();
+
+        final ObjectInspector[] OIs = new ObjectInspector[] {
+                ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector),
+                ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableIntObjectInspector)};
+        final SignalNoiseRatioUDAF.SignalNoiseRatioUDAFEvaluator evaluator = (SignalNoiseRatioUDAF.SignalNoiseRatioUDAFEvaluator) snr.getEvaluator(new SimpleGenericUDAFParameterInfo(
+            OIs, false, false));
+        evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, OIs);
+        final SignalNoiseRatioUDAF.SignalNoiseRatioUDAFEvaluator.SignalNoiseRatioAggregationBuffer agg = (SignalNoiseRatioUDAF.SignalNoiseRatioUDAFEvaluator.SignalNoiseRatioAggregationBuffer) evaluator.getNewAggregationBuffer();
+        evaluator.reset(agg);
+
+        final double[][] featuress = new double[][] { {5.1, 3.5, 1.4, 0.2}, {4.9, 3.d, 1.4, 0.2},
+                {7.d, 3.2, 4.7, 1.4}, {6.4, 3.2, 4.5, 1.5}, {6.3, 3.3, 6.d, 2.5},
+                {5.8, 2.7, 5.1, 1.9}};
+
+        final int[][] labelss = new int[][] { {1}, {1}, {1}, {1}, {1}, {1}}; // a single class causes UDFArgumentException
+
+        for (int i = 0; i < featuress.length; i++) {
+            final List<IntWritable> labels = new ArrayList<IntWritable>();
+            for (int label : labelss[i]) {
+                labels.add(new IntWritable(label));
+            }
+            evaluator.iterate(agg,
+                new Object[] {WritableUtils.toWritableList(featuress[i]), labels});
+        }
+    }
+}
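
For orientation, the statistic these tests exercise is the per-feature signal-to-noise ratio, commonly defined for two classes as |mean0 - mean1| / (sd0 + sd1). That definition also explains the corner cases: the result is infinite when both standard deviations are zero but the means differ, and zero when the class means coincide. A minimal two-class sketch (the helper name is illustrative; whether Hivemall uses sample or population standard deviation, and how it combines more than two classes, is not visible in this excerpt):

    // Two-class SNR of a single feature: |mu0 - mu1| / (sigma0 + sigma1).
    static double snr(double mean0, double sd0, double mean1, double sd1) {
        final double diff = Math.abs(mean0 - mean1);
        if (diff == 0.d) {
            return 0.d; // all-constant feature; the tests above expect 0, not NaN
        }
        return diff / (sd0 + sd1); // +Infinity when both deviations are zero
    }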

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/fad2941f/core/src/test/java/hivemall/tools/array/SelectKBeatUDFTest.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/hivemall/tools/array/SelectKBeatUDFTest.java b/core/src/test/java/hivemall/tools/array/SelectKBeatUDFTest.java
new file mode 100644
index 0000000..3e3fc12
--- /dev/null
+++ b/core/src/test/java/hivemall/tools/array/SelectKBeatUDFTest.java
@@ -0,0 +1,69 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.tools.array;
+
+import hivemall.utils.hadoop.WritableUtils;
+
+import java.util.List;
+
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDF;
+import org.apache.hadoop.hive.serde2.io.DoubleWritable;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
+import org.junit.Assert;
+import org.junit.Test;
+
+public class SelectKBeatUDFTest {
+
+    @Test
+    public void test() throws Exception {
+        final SelectKBestUDF selectKBest = new SelectKBestUDF();
+        final int k = 2;
+        final double[] data = new double[] {250.29999999999998, 170.90000000000003, 73.2,
+                12.199999999999996};
+        final double[] importanceList = new double[] {292.1666753739119, 152.70000455081467,
+                187.93333893418327, 59.93333511948589};
+
+        final GenericUDF.DeferredObject[] dObjs = new GenericUDF.DeferredObject[] {
+                new GenericUDF.DeferredJavaObject(WritableUtils.toWritableList(data)),
+                new GenericUDF.DeferredJavaObject(WritableUtils.toWritableList(importanceList)),
+                new GenericUDF.DeferredJavaObject(k)};
+
+        selectKBest.initialize(new ObjectInspector[] {
+                ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector),
+                ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector),
+                ObjectInspectorUtils.getConstantObjectInspector(
+                    PrimitiveObjectInspectorFactory.javaIntObjectInspector, k)});
+        final List<DoubleWritable> resultObj = selectKBest.evaluate(dObjs);
+
+        Assert.assertEquals(k, resultObj.size());
+
+        final double[] result = new double[k];
+        for (int i = 0; i < k; i++) {
+            result[i] = resultObj.get(i).get();
+        }
+
+        final double[] answer = new double[] {250.29999999999998, 73.2};
+
+        Assert.assertArrayEquals(answer, result, 0.d);
+        selectKBest.close();
+    }
+}
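
The expected output above pins down the semantics: keep the entries of `data` at the positions of the k largest `importanceList` scores (indices 0 and 2 here), preserving their original order. A standalone re-computation sketch (the helper name is illustrative, and the actual UDF's tie-breaking behavior is not visible in this excerpt):

    static double[] selectKBest(double[] data, double[] importance, int k) {
        // find the k-th largest importance score and use it as a threshold
        final double[] sorted = importance.clone();
        java.util.Arrays.sort(sorted);
        final double threshold = sorted[sorted.length - k];
        // keep values whose importance reaches the threshold, in original order
        final double[] out = new double[k];
        int j = 0;
        for (int i = 0; i < data.length && j < k; i++) {
            if (importance[i] >= threshold) {
                out[j++] = data[i];
            }
        }
        return out;
    }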

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/fad2941f/core/src/test/java/hivemall/tools/matrix/TransposeAndDotUDAFTest.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/hivemall/tools/matrix/TransposeAndDotUDAFTest.java b/core/src/test/java/hivemall/tools/matrix/TransposeAndDotUDAFTest.java
new file mode 100644
index 0000000..f705a89
--- /dev/null
+++ b/core/src/test/java/hivemall/tools/matrix/TransposeAndDotUDAFTest.java
@@ -0,0 +1,59 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.tools.matrix;
+
+import hivemall.utils.hadoop.WritableUtils;
+
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
+import org.apache.hadoop.hive.ql.udf.generic.SimpleGenericUDAFParameterInfo;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
+import org.junit.Assert;
+import org.junit.Test;
+
+public class TransposeAndDotUDAFTest {
+
+    @Test
+    public void test() throws Exception {
+        final TransposeAndDotUDAF tad = new TransposeAndDotUDAF();
+
+        final double[][] matrix0 = new double[][] { {1, -2}, {-1, 3}};
+        final double[][] matrix1 = new double[][] { {1, 2}, {3, 4}};
+
+        final ObjectInspector[] OIs = new ObjectInspector[] {
+                ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector),
+                ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector)};
+        final GenericUDAFEvaluator evaluator = tad.getEvaluator(new SimpleGenericUDAFParameterInfo(
+            OIs, false, false));
+        evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, OIs);
+        TransposeAndDotUDAF.TransposeAndDotUDAFEvaluator.TransposeAndDotAggregationBuffer agg = (TransposeAndDotUDAF.TransposeAndDotUDAFEvaluator.TransposeAndDotAggregationBuffer) evaluator.getNewAggregationBuffer();
+        evaluator.reset(agg);
+        for (int i = 0; i < matrix0.length; i++) {
+            evaluator.iterate(agg, new Object[] {WritableUtils.toWritableList(matrix0[i]),
+                    WritableUtils.toWritableList(matrix1[i])});
+        }
+
+        final double[][] answer = new double[][] { {-2.0, -2.0}, {7.0, 8.0}};
+
+        for (int i = 0; i < answer.length; i++) {
+            Assert.assertArrayEquals(answer[i], agg.aggMatrix[i], 0.d);
+        }
+    }
+}
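
The expected matrix is simply matrix0^T * matrix1 accumulated row by row, i.e. answer[j][k] = sum over rows i of matrix0[i][j] * matrix1[i][k]; with the inputs above this gives {{-2, -2}, {7, 8}}, matching the assertion. A standalone re-computation sketch (the helper name is illustrative):

    // transpose-and-dot: out[j][k] = sum over rows i of x[i][j] * y[i][k]
    static double[][] transposeAndDot(double[][] x, double[][] y) {
        final double[][] out = new double[x[0].length][y[0].length];
        for (int i = 0; i < x.length; i++) {
            for (int j = 0; j < x[i].length; j++) {
                for (int k = 0; k < y[i].length; k++) {
                    out[j][k] += x[i][j] * y[i][k];
                }
            }
        }
        return out;
    }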

