add snr


Project: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/commit/22a608ee
Tree: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/tree/22a608ee
Diff: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/diff/22a608ee

Branch: refs/heads/JIRA-22/pr-385
Commit: 22a608ee1c7239b2953183b5341f80c58b1e7045
Parents: 5088ef3
Author: amaya <g...@sapphire.in.net>
Authored: Mon Sep 26 17:07:55 2016 +0900
Committer: amaya <g...@sapphire.in.net>
Committed: Mon Sep 26 17:15:22 2016 +0900

----------------------------------------------------------------------
 .../ftvec/selection/SignalNoiseRatioUDAF.java   | 327 +++++++++++++++++++
 .../selection/SignalNoiseRatioUDAFTest.java     | 174 ++++++++++
 resources/ddl/define-all-as-permanent.hive      |   3 +
 resources/ddl/define-all.hive                   |   3 +
 resources/ddl/define-all.spark                  |   3 +
 resources/ddl/define-udfs.td.hql                |   1 +
 6 files changed, 511 insertions(+)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/22a608ee/core/src/main/java/hivemall/ftvec/selection/SignalNoiseRatioUDAF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/ftvec/selection/SignalNoiseRatioUDAF.java b/core/src/main/java/hivemall/ftvec/selection/SignalNoiseRatioUDAF.java
new file mode 100644
index 0000000..b7b9126
--- /dev/null
+++ b/core/src/main/java/hivemall/ftvec/selection/SignalNoiseRatioUDAF.java
@@ -0,0 +1,327 @@
+/*
+ * Hivemall: Hive scalable Machine Learning Library
+ *
+ * Copyright (C) 2015 Makoto YUI
+ * Copyright (C) 2013-2015 National Institute of Advanced Industrial Science and Technology (AIST)
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *         http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package hivemall.ftvec.selection;
+
+import hivemall.utils.hadoop.HiveUtils;
+import hivemall.utils.hadoop.WritableUtils;
+import hivemall.utils.lang.Preconditions;
+import org.apache.commons.math3.util.FastMath;
+import org.apache.hadoop.hive.ql.exec.Description;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentLengthException;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
+import org.apache.hadoop.hive.ql.metadata.HiveException;
+import org.apache.hadoop.hive.ql.parse.SemanticException;
+import org.apache.hadoop.hive.ql.udf.generic.AbstractGenericUDAFResolver;
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFParameterInfo;
+import org.apache.hadoop.hive.serde2.io.DoubleWritable;
+import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
+import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.StructField;
+import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
+import 
org.apache.hadoop.hive.serde2.objectinspector.primitive.DoubleObjectInspector;
+import 
org.apache.hadoop.hive.serde2.objectinspector.primitive.LongObjectInspector;
+import 
org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
+import 
org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+
+@Description(name = "snr", value = "_FUNC_(array<number> features, array<int> 
one-hot class label)"
+        + " - Returns SNR values of each feature as array<double>")
+public class SignalNoiseRatioUDAF extends AbstractGenericUDAFResolver {
+    @Override
+    public GenericUDAFEvaluator getEvaluator(GenericUDAFParameterInfo info)
+            throws SemanticException {
+        final ObjectInspector[] OIs = info.getParameterObjectInspectors();
+
+        if (OIs.length != 2) {
+            throw new UDFArgumentLengthException("Specify two arguments.");
+        }
+
+        if (!HiveUtils.isNumberListOI(OIs[0])) {
+            throw new UDFArgumentTypeException(0,
+                "Only array<number> type argument is acceptable but " + 
OIs[0].getTypeName()
+                        + " was passed as `features`");
+        }
+
+        if (!HiveUtils.isListOI(OIs[1])
+                || !HiveUtils.isIntegerOI(((ListObjectInspector) 
OIs[1]).getListElementObjectInspector())) {
+            throw new UDFArgumentTypeException(1,
+                "Only array<int> type argument is acceptable but " + 
OIs[1].getTypeName()
+                        + " was passed as `labels`");
+        }
+
+        return new SignalNoiseRatioUDAFEvaluator();
+    }
+
+    static class SignalNoiseRatioUDAFEvaluator extends GenericUDAFEvaluator {
+        // PARTIAL1 and COMPLETE
+        private ListObjectInspector featuresOI;
+        private PrimitiveObjectInspector featureOI;
+        private ListObjectInspector labelsOI;
+        private PrimitiveObjectInspector labelOI;
+
+        // PARTIAL2 and FINAL
+        private StructObjectInspector structOI;
+        private StructField nsField, meanssField, variancessField;
+        private ListObjectInspector nsOI;
+        private LongObjectInspector nOI;
+        private ListObjectInspector meanssOI;
+        private ListObjectInspector meansOI;
+        private DoubleObjectInspector meanOI;
+        private ListObjectInspector variancessOI;
+        private ListObjectInspector variancesOI;
+        private DoubleObjectInspector varianceOI;
+
+        @AggregationType(estimable = true)
+        static class SignalNoiseRatioAggregationBuffer extends 
AbstractAggregationBuffer {
+            long[] ns;
+            double[][] meanss;
+            double[][] variancess;
+
+            @Override
+            public int estimate() {
+                return ns == null ? 0 : 8 * ns.length + 8 * meanss.length * 
meanss[0].length + 8
+                        * variancess.length * variancess[0].length;
+            }
+
+            public void init(int nClasses, int nFeatures) {
+                ns = new long[nClasses];
+                meanss = new double[nClasses][nFeatures];
+                variancess = new double[nClasses][nFeatures];
+            }
+
+            public void reset() {
+                if (ns != null) {
+                    Arrays.fill(ns, 0);
+                    for (double[] means : meanss) {
+                        Arrays.fill(means, 0.d);
+                    }
+                    for (double[] variances : variancess) {
+                        Arrays.fill(variances, 0.d);
+                    }
+                }
+            }
+        }
+
+        @Override
+        public ObjectInspector init(Mode mode, ObjectInspector[] OIs) throws 
HiveException {
+            super.init(mode, OIs);
+
+            if (mode == Mode.PARTIAL1 || mode == Mode.COMPLETE) {
+                featuresOI = HiveUtils.asListOI(OIs[0]);
+                featureOI = 
HiveUtils.asDoubleCompatibleOI(featuresOI.getListElementObjectInspector());
+                labelsOI = HiveUtils.asListOI(OIs[1]);
+                labelOI = 
HiveUtils.asIntegerOI(labelsOI.getListElementObjectInspector());
+            } else {
+                structOI = (StructObjectInspector) OIs[0];
+                nsField = structOI.getStructFieldRef("ns");
+                nsOI = HiveUtils.asListOI(nsField.getFieldObjectInspector());
+                nOI = HiveUtils.asLongOI(nsOI.getListElementObjectInspector());
+                meanssField = structOI.getStructFieldRef("meanss");
+                meanssOI = 
HiveUtils.asListOI(meanssField.getFieldObjectInspector());
+                meansOI = 
HiveUtils.asListOI(meanssOI.getListElementObjectInspector());
+                meanOI = 
HiveUtils.asDoubleOI(meansOI.getListElementObjectInspector());
+                variancessField = structOI.getStructFieldRef("variancess");
+                variancessOI = 
HiveUtils.asListOI(variancessField.getFieldObjectInspector());
+                variancesOI = 
HiveUtils.asListOI(variancessOI.getListElementObjectInspector());
+                varianceOI = 
HiveUtils.asDoubleOI(variancesOI.getListElementObjectInspector());
+            }
+
+            if (mode == Mode.PARTIAL1 || mode == Mode.PARTIAL2) {
+                List<ObjectInspector> fieldOIs = new 
ArrayList<ObjectInspector>();
+                
fieldOIs.add(ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableLongObjectInspector));
+                
fieldOIs.add(ObjectInspectorFactory.getStandardListObjectInspector(ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector)));
+                
fieldOIs.add(ObjectInspectorFactory.getStandardListObjectInspector(ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector)));
+                return ObjectInspectorFactory.getStandardStructObjectInspector(
+                    Arrays.asList("ns", "meanss", "variancess"), fieldOIs);
+            } else {
+                return 
ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector);
+            }
+        }
+
+        @Override
+        public AbstractAggregationBuffer getNewAggregationBuffer() throws 
HiveException {
+            final SignalNoiseRatioAggregationBuffer myAgg = new 
SignalNoiseRatioAggregationBuffer();
+            reset(myAgg);
+            return myAgg;
+        }
+
+        @Override
+        public void reset(AggregationBuffer agg) throws HiveException {
+            final SignalNoiseRatioAggregationBuffer myAgg = 
(SignalNoiseRatioAggregationBuffer) agg;
+            myAgg.reset();
+        }
+
+        @Override
+        public void iterate(AggregationBuffer agg, Object[] parameters) throws 
HiveException {
+            final Object featuresObj = parameters[0];
+            final Object labelsObj = parameters[1];
+
+            Preconditions.checkNotNull(featuresObj);
+            Preconditions.checkNotNull(labelsObj);
+
+            final SignalNoiseRatioAggregationBuffer myAgg = 
(SignalNoiseRatioAggregationBuffer) agg;
+
+            // read class
+            final List labels = labelsOI.getList(labelsObj);
+            final int nClasses = labels.size();
+
+            // to calc SNR between classes
+            Preconditions.checkArgument(nClasses >= 2);
+
+            int clazz = -1;
+            for (int i = 0; i < nClasses; i++) {
+                int label = 
PrimitiveObjectInspectorUtils.getInt(labels.get(i), labelOI);
+                if (label == 1 && clazz == -1) {
+                    clazz = i;
+                } else if (label == 1) {
+                    throw new UDFArgumentException(
+                        "Specify one-hot vectorized array. Multiple hot 
elements found.");
+                }
+            }
+            if (clazz == -1) {
+                throw new UDFArgumentException(
+                    "Specify one-hot vectorized array. Hot element not 
found.");
+            }
+
+            final List features = featuresOI.getList(featuresObj);
+            final int nFeatures = features.size();
+
+            Preconditions.checkArgument(nFeatures >= 1);
+
+            if (myAgg.ns == null) {
+                // init
+                myAgg.init(nClasses, nFeatures);
+            } else {
+                Preconditions.checkArgument(nClasses == myAgg.ns.length);
+                Preconditions.checkArgument(nFeatures == 
myAgg.meanss[0].length);
+            }
+
+            // calc incrementally
+            final long n = myAgg.ns[clazz];
+            myAgg.ns[clazz]++;
+            for (int i = 0; i < nFeatures; i++) {
+                final double x = 
PrimitiveObjectInspectorUtils.getDouble(features.get(i), featureOI);
+                final double meanN = myAgg.meanss[clazz][i];
+                final double varianceN = myAgg.variancess[clazz][i];
+                myAgg.meanss[clazz][i] = (n * meanN + x) / (n + 1.d);
+                myAgg.variancess[clazz][i] = (n * varianceN + (x - meanN)
+                        * (x - myAgg.meanss[clazz][i]))
+                        / (n + 1.d);
+            }
+        }
+
+        @Override
+        public void merge(AggregationBuffer agg, Object other) throws 
HiveException {
+            if (other == null) {
+                return;
+            }
+
+            final SignalNoiseRatioAggregationBuffer myAgg = 
(SignalNoiseRatioAggregationBuffer) agg;
+
+            final List ns = nsOI.getList(structOI.getStructFieldData(other, 
nsField));
+            final List meanss = 
meanssOI.getList(structOI.getStructFieldData(other, meanssField));
+            final List variancess = 
variancessOI.getList(structOI.getStructFieldData(other,
+                variancessField));
+
+            final int nClasses = ns.size();
+            final int nFeatures = meansOI.getListLength(meanss.get(0));
+            if (myAgg.ns == null) {
+                // init
+                myAgg.init(nClasses, nFeatures);
+            }
+            for (int i = 0; i < nClasses; i++) {
+                final long n = myAgg.ns[i];
+                final long m = 
PrimitiveObjectInspectorUtils.getLong(ns.get(i), nOI);
+                final List means = meansOI.getList(meanss.get(i));
+                final List variances = variancesOI.getList(variancess.get(i));
+
+                myAgg.ns[i] += m;
+                for (int j = 0; j < nFeatures; j++) {
+                    final double meanN = myAgg.meanss[i][j];
+                    final double meanM = 
PrimitiveObjectInspectorUtils.getDouble(means.get(j),
+                        meanOI);
+                    final double varianceN = myAgg.variancess[i][j];
+                    final double varianceM = 
PrimitiveObjectInspectorUtils.getDouble(
+                        variances.get(j), varianceOI);
+                    myAgg.meanss[i][j] = (n * meanN + m * meanM) / (double) (n 
+ m);
+                    myAgg.variancess[i][j] = (varianceN * (n - 1) + varianceM 
* (m - 1) + FastMath.pow(
+                        meanN - meanM, 2) * n * m / (n + m))
+                            / (n + m - 1);
+                }
+            }
+        }
+
+        @Override
+        public Object terminatePartial(AggregationBuffer agg) throws 
HiveException {
+            final SignalNoiseRatioAggregationBuffer myAgg = 
(SignalNoiseRatioAggregationBuffer) agg;
+
+            final Object[] partialResult = new Object[3];
+            partialResult[0] = WritableUtils.toWritableList(myAgg.ns);
+            final List<List<DoubleWritable>> meanss = new 
ArrayList<List<DoubleWritable>>();
+            for (double[] means : myAgg.meanss) {
+                meanss.add(WritableUtils.toWritableList(means));
+            }
+            partialResult[1] = meanss;
+            final List<List<DoubleWritable>> variancess = new 
ArrayList<List<DoubleWritable>>();
+            for (double[] variances : myAgg.variancess) {
+                variancess.add(WritableUtils.toWritableList(variances));
+            }
+            partialResult[2] = variancess;
+            return partialResult;
+        }
+
+        @Override
+        public Object terminate(AggregationBuffer agg) throws HiveException {
+            final SignalNoiseRatioAggregationBuffer myAgg = 
(SignalNoiseRatioAggregationBuffer) agg;
+
+            final int nClasses = myAgg.ns.length;
+            final int nFeatures = myAgg.meanss[0].length;
+
+            // calc SNR between classes each feature
+            final double[] result = new double[nFeatures];
+            final double[] sds = new double[nClasses]; // memo
+            for (int i = 0; i < nFeatures; i++) {
+                sds[0] = FastMath.sqrt(myAgg.variancess[0][i]);
+                for (int j = 1; j < nClasses; j++) {
+                    sds[j] = FastMath.sqrt(myAgg.variancess[j][i]);
+                    if (Double.isNaN(sds[j])) {
+                        continue;
+                    }
+                    for (int k = 0; k < j; k++) {
+                        if (Double.isNaN(sds[k])) {
+                            continue;
+                        }
+                        result[i] += FastMath.abs(myAgg.meanss[j][i] - 
myAgg.meanss[k][i])
+                                / (sds[j] + sds[k]);
+                    }
+                }
+            }
+
+            // SUM(snr) GROUP BY feature
+            return WritableUtils.toWritableList(result);
+        }
+    }
+}

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/22a608ee/core/src/test/java/hivemall/ftvec/selection/SignalNoiseRatioUDAFTest.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/hivemall/ftvec/selection/SignalNoiseRatioUDAFTest.java b/core/src/test/java/hivemall/ftvec/selection/SignalNoiseRatioUDAFTest.java
new file mode 100644
index 0000000..4655545
--- /dev/null
+++ b/core/src/test/java/hivemall/ftvec/selection/SignalNoiseRatioUDAFTest.java
@@ -0,0 +1,174 @@
+/*
+ * Hivemall: Hive scalable Machine Learning Library
+ *
+ * Copyright (C) 2016 Makoto YUI
+ * Copyright (C) 2013-2015 National Institute of Advanced Industrial Science and Technology (AIST)
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *         http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package hivemall.ftvec.selection;
+
+import hivemall.utils.hadoop.WritableUtils;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
+import org.apache.hadoop.hive.ql.udf.generic.SimpleGenericUDAFParameterInfo;
+import org.apache.hadoop.hive.serde2.io.DoubleWritable;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
+import 
org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
+import org.apache.hadoop.io.IntWritable;
+import org.junit.Assert;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.ExpectedException;
+
+import java.util.ArrayList;
+import java.util.List;
+
+public class SignalNoiseRatioUDAFTest {
+    @Rule
+    public ExpectedException expectedException = ExpectedException.none();
+
+    @Test
+    public void test() throws Exception {
+        final SignalNoiseRatioUDAF snr = new SignalNoiseRatioUDAF();
+        final ObjectInspector[] OIs = new ObjectInspector[] {
+                
ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector),
+                
ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableIntObjectInspector)};
+        final SignalNoiseRatioUDAF.SignalNoiseRatioUDAFEvaluator evaluator = 
(SignalNoiseRatioUDAF.SignalNoiseRatioUDAFEvaluator) snr.getEvaluator(new 
SimpleGenericUDAFParameterInfo(
+            OIs, false, false));
+        evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, OIs);
+        final 
SignalNoiseRatioUDAF.SignalNoiseRatioUDAFEvaluator.SignalNoiseRatioAggregationBuffer
 agg = 
(SignalNoiseRatioUDAF.SignalNoiseRatioUDAFEvaluator.SignalNoiseRatioAggregationBuffer)
 evaluator.getNewAggregationBuffer();
+        evaluator.reset(agg);
+
+        final double[][] featuress = new double[][] { {5.1, 3.5, 1.4, 0.2}, 
{4.9, 3.d, 1.4, 0.2},
+                {7.d, 3.2, 4.7, 1.4}, {6.4, 3.2, 4.5, 1.5}, {6.3, 3.3, 6.d, 
2.5},
+                {5.8, 2.7, 5.1, 1.9}};
+
+        final int[][] labelss = new int[][] { {1, 0, 0}, {1, 0, 0}, {0, 1, 0}, 
{0, 1, 0},
+                {0, 0, 1}, {0, 0, 1}};
+
+        for (int i = 0; i < featuress.length; i++) {
+            final List<IntWritable> labels = new ArrayList<IntWritable>();
+            for (int label : labelss[i]) {
+                labels.add(new IntWritable(label));
+            }
+            evaluator.iterate(agg,
+                new Object[] {WritableUtils.toWritableList(featuress[i]), 
labels});
+        }
+
+        @SuppressWarnings("unchecked")
+        final List<DoubleWritable> resultObj = (ArrayList<DoubleWritable>) 
evaluator.terminate(agg);
+        final int size = resultObj.size();
+        final double[] result = new double[size];
+        for (int i = 0; i < size; i++) {
+            result[i] = resultObj.get(i).get();
+        }
+        final double[] answer = new double[] {8.431818181818192, 
1.3212121212121217,
+                42.94949494949499, 33.80952380952378};
+        Assert.assertArrayEquals(answer, result, 0.d);
+    }
+
+    @Test
+    public void shouldFail0() throws Exception {
+        expectedException.expect(UDFArgumentException.class);
+
+        final SignalNoiseRatioUDAF snr = new SignalNoiseRatioUDAF();
+        final ObjectInspector[] OIs = new ObjectInspector[] {
+                
ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector),
+                
ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableIntObjectInspector)};
+        final SignalNoiseRatioUDAF.SignalNoiseRatioUDAFEvaluator evaluator = 
(SignalNoiseRatioUDAF.SignalNoiseRatioUDAFEvaluator) snr.getEvaluator(new 
SimpleGenericUDAFParameterInfo(
+            OIs, false, false));
+        evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, OIs);
+        final 
SignalNoiseRatioUDAF.SignalNoiseRatioUDAFEvaluator.SignalNoiseRatioAggregationBuffer
 agg = 
(SignalNoiseRatioUDAF.SignalNoiseRatioUDAFEvaluator.SignalNoiseRatioAggregationBuffer)
 evaluator.getNewAggregationBuffer();
+        evaluator.reset(agg);
+
+        final double[][] featuress = new double[][] { {5.1, 3.5, 1.4, 0.2}, 
{4.9, 3.d, 1.4, 0.2},
+                {7.d, 3.2, 4.7, 1.4}, {6.4, 3.2, 4.5, 1.5}, {6.3, 3.3, 6.d, 
2.5},
+                {5.8, 2.7, 5.1, 1.9}};
+
+        final int[][] labelss = new int[][] { {0, 0, 0}, // cause 
UDFArgumentException
+                {1, 0, 0}, {0, 1, 0}, {0, 1, 0}, {0, 0, 1}, {0, 0, 1}};
+
+        for (int i = 0; i < featuress.length; i++) {
+            final List<IntWritable> labels = new ArrayList<IntWritable>();
+            for (int label : labelss[i]) {
+                labels.add(new IntWritable(label));
+            }
+            evaluator.iterate(agg,
+                new Object[] {WritableUtils.toWritableList(featuress[i]), 
labels});
+        }
+    }
+
+    @Test
+    public void shouldFail1() throws Exception {
+        expectedException.expect(IllegalArgumentException.class);
+
+        final SignalNoiseRatioUDAF snr = new SignalNoiseRatioUDAF();
+        final ObjectInspector[] OIs = new ObjectInspector[] {
+                
ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector),
+                
ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableIntObjectInspector)};
+        final SignalNoiseRatioUDAF.SignalNoiseRatioUDAFEvaluator evaluator = 
(SignalNoiseRatioUDAF.SignalNoiseRatioUDAFEvaluator) snr.getEvaluator(new 
SimpleGenericUDAFParameterInfo(
+            OIs, false, false));
+        evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, OIs);
+        final 
SignalNoiseRatioUDAF.SignalNoiseRatioUDAFEvaluator.SignalNoiseRatioAggregationBuffer
 agg = 
(SignalNoiseRatioUDAF.SignalNoiseRatioUDAFEvaluator.SignalNoiseRatioAggregationBuffer)
 evaluator.getNewAggregationBuffer();
+        evaluator.reset(agg);
+
+        final double[][] featuress = new double[][] { {5.1, 3.5, 1.4, 0.2},
+                {4.9, 3.d, 1.4}, // cause IllegalArgumentException
+                {7.d, 3.2, 4.7, 1.4}, {6.4, 3.2, 4.5, 1.5}, {6.3, 3.3, 6.d, 
2.5},
+                {5.8, 2.7, 5.1, 1.9}};
+
+        final int[][] labelss = new int[][] { {1, 0, 0}, {1, 0, 0}, {0, 1, 0}, 
{0, 1, 0},
+                {0, 0, 1}, {0, 0, 1}};
+
+        for (int i = 0; i < featuress.length; i++) {
+            final List<IntWritable> labels = new ArrayList<IntWritable>();
+            for (int label : labelss[i]) {
+                labels.add(new IntWritable(label));
+            }
+            evaluator.iterate(agg,
+                new Object[] {WritableUtils.toWritableList(featuress[i]), 
labels});
+        }
+    }
+
+    @Test
+    public void shouldFail2() throws Exception {
+        expectedException.expect(IllegalArgumentException.class);
+
+        final SignalNoiseRatioUDAF snr = new SignalNoiseRatioUDAF();
+        final ObjectInspector[] OIs = new ObjectInspector[] {
+                
ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector),
+                
ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableIntObjectInspector)};
+        final SignalNoiseRatioUDAF.SignalNoiseRatioUDAFEvaluator evaluator = 
(SignalNoiseRatioUDAF.SignalNoiseRatioUDAFEvaluator) snr.getEvaluator(new 
SimpleGenericUDAFParameterInfo(
+            OIs, false, false));
+        evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, OIs);
+        final 
SignalNoiseRatioUDAF.SignalNoiseRatioUDAFEvaluator.SignalNoiseRatioAggregationBuffer
 agg = 
(SignalNoiseRatioUDAF.SignalNoiseRatioUDAFEvaluator.SignalNoiseRatioAggregationBuffer)
 evaluator.getNewAggregationBuffer();
+        evaluator.reset(agg);
+
+        final double[][] featuress = new double[][] { {5.1, 3.5, 1.4, 0.2}, 
{4.9, 3.d, 1.4, 0.2},
+                {7.d, 3.2, 4.7, 1.4}, {6.4, 3.2, 4.5, 1.5}, {6.3, 3.3, 6.d, 
2.5},
+                {5.8, 2.7, 5.1, 1.9}};
+
+        final int[][] labelss = new int[][] { {1}, {1}, {1}, {1}, {1}, {1}}; 
// cause IllegalArgumentException
+
+        for (int i = 0; i < featuress.length; i++) {
+            final List<IntWritable> labels = new ArrayList<IntWritable>();
+            for (int label : labelss[i]) {
+                labels.add(new IntWritable(label));
+            }
+            evaluator.iterate(agg,
+                new Object[] {WritableUtils.toWritableList(featuress[i]), 
labels});
+        }
+    }
+}

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/22a608ee/resources/ddl/define-all-as-permanent.hive
----------------------------------------------------------------------
diff --git a/resources/ddl/define-all-as-permanent.hive b/resources/ddl/define-all-as-permanent.hive
index b515b24..10e72b7 100644
--- a/resources/ddl/define-all-as-permanent.hive
+++ b/resources/ddl/define-all-as-permanent.hive
@@ -209,6 +209,9 @@ CREATE FUNCTION l2_normalize as 'hivemall.ftvec.scaling.L2NormalizationUDF' USIN
 DROP FUNCTION IF EXISTS chi2;
 CREATE FUNCTION chi2 as 'hivemall.ftvec.selection.ChiSquareUDF' USING JAR 
'${hivemall_jar}';
 
+DROP FUNCTION IF EXISTS snr;
+CREATE FUNCTION snr as 'hivemall.ftvec.selection.SignalNoiseRatioUDAF' USING 
JAR '${hivemall_jar}';
+
 --------------------
 -- misc functions --
 --------------------

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/22a608ee/resources/ddl/define-all.hive
----------------------------------------------------------------------
diff --git a/resources/ddl/define-all.hive b/resources/ddl/define-all.hive
index 2124892..04b519e 100644
--- a/resources/ddl/define-all.hive
+++ b/resources/ddl/define-all.hive
@@ -205,6 +205,9 @@ create temporary function l2_normalize as 'hivemall.ftvec.scaling.L2Normalizatio
 drop temporary function chi2;
 create temporary function chi2 as 'hivemall.ftvec.selection.ChiSquareUDF';
 
+drop temporary function snr;
+create temporary function snr as 
'hivemall.ftvec.selection.SignalNoiseRatioUDAF';
+
 -----------------------------------
 -- Feature engineering functions --
 -----------------------------------

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/22a608ee/resources/ddl/define-all.spark
----------------------------------------------------------------------
diff --git a/resources/ddl/define-all.spark b/resources/ddl/define-all.spark
index 47f0ce5..65c2346 100644
--- a/resources/ddl/define-all.spark
+++ b/resources/ddl/define-all.spark
@@ -190,6 +190,9 @@ sqlContext.sql("CREATE TEMPORARY FUNCTION normalize AS 'hivemall.ftvec.scaling.L
 sqlContext.sql("DROP TEMPORARY FUNCTION IF EXISTS chi2")
 sqlContext.sql("CREATE TEMPORARY FUNCTION chi2 AS 
'hivemall.ftvec.selection.ChiSquareUDF'")
 
+sqlContext.sql("DROP TEMPORARY FUNCTION IF EXISTS snr")
+sqlContext.sql("CREATE TEMPORARY FUNCTION snr AS 
'hivemall.ftvec.selection.SignalNoiseRatioUDAF'")
+
 /**
  * misc functions
  */

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/22a608ee/resources/ddl/define-udfs.td.hql
----------------------------------------------------------------------
diff --git a/resources/ddl/define-udfs.td.hql b/resources/ddl/define-udfs.td.hql
index fd7dc1d..7aa537a 100644
--- a/resources/ddl/define-udfs.td.hql
+++ b/resources/ddl/define-udfs.td.hql
@@ -51,6 +51,7 @@ create temporary function rescale as 'hivemall.ftvec.scaling.RescaleUDF';
 create temporary function zscore as 'hivemall.ftvec.scaling.ZScoreUDF';
 create temporary function l2_normalize as 
'hivemall.ftvec.scaling.L2NormalizationUDF';
 create temporary function chi2 as 'hivemall.ftvec.selection.ChiSquareUDF';
+create temporary function snr as 
'hivemall.ftvec.selection.SignalNoiseRatioUDAF';
 create temporary function amplify as 'hivemall.ftvec.amplify.AmplifierUDTF';
 create temporary function rand_amplify as 
'hivemall.ftvec.amplify.RandomAmplifierUDTF';
 create temporary function add_bias as 'hivemall.ftvec.AddBiasUDF';

Reply via email to