Close #15: Implement Feature Selection functions (chi2, snr)
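This adds four functions for feature selection: the chi2() UDF (per-feature chi-square scores and p-values), the snr() UDAF (per-feature signal-to-noise ratio over one-hot class labels), and the helpers select_k_best() (UDF) and transpose_and_dot() (UDAF). A minimal HiveQL sketch of the intended flow follows; the table `input` and the columns `features` / `one_hot_label` / `observed` / `expected` are hypothetical placeholders, and the documented recipe lives in docs/gitbook/ft_engineering/feature_selection.md:

    -- per-feature importance as summed pairwise SNR
    SELECT snr(features, one_hot_label) AS importance_list
    FROM input;

    -- chi-square scores and p-values; `observed` and `expected` are
    -- class-by-feature matrices, e.g. built with transpose_and_dot()
    SELECT chi2(observed, expected).chi2 AS chi2_scores
    FROM contingency_stats;

    -- keep the k=2 best features; note `importance_list` must be a
    -- constant per the UDF signature (e.g., supplied via a cross join)
    SELECT select_k_best(features, importance_list, 2) AS selected
    FROM input;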
Project: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/commit/fad2941f Tree: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/tree/fad2941f Diff: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/diff/fad2941f Branch: refs/heads/master Commit: fad2941fdb0309dd6fbf22f2f936cbf0003b1c4a Parents: 518e232 Author: amaya <amaya...@users.noreply.github.com> Authored: Tue Dec 13 16:46:07 2016 +0900 Committer: myui <yuin...@gmail.com> Committed: Tue Dec 13 16:46:07 2016 +0900 ---------------------------------------------------------------------- .../hivemall/ftvec/selection/ChiSquareUDF.java | 173 +++++++++ .../ftvec/selection/SignalNoiseRatioUDAF.java | 370 +++++++++++++++++++ .../hivemall/tools/array/SelectKBestUDF.java | 163 ++++++++ .../tools/matrix/TransposeAndDotUDAF.java | 222 +++++++++++ .../java/hivemall/utils/hadoop/HiveUtils.java | 22 +- .../hivemall/utils/hadoop/WritableUtils.java | 16 + .../java/hivemall/utils/lang/Preconditions.java | 30 ++ .../java/hivemall/utils/math/StatsUtils.java | 91 +++++ .../ftvec/selection/ChiSquareUDFTest.java | 82 ++++ .../selection/SignalNoiseRatioUDAFTest.java | 342 +++++++++++++++++ .../tools/array/SelectKBeatUDFTest.java | 69 ++++ .../tools/matrix/TransposeAndDotUDAFTest.java | 59 +++ .../hivemall/utils/lang/PreconditionsTest.java | 37 ++ docs/gitbook/SUMMARY.md | 2 + .../gitbook/ft_engineering/feature_selection.md | 155 ++++++++ docs/gitbook/ft_engineering/quantify.md | 2 +- resources/ddl/define-all-as-permanent.hive | 20 + resources/ddl/define-all.hive | 20 + resources/ddl/define-all.spark | 50 ++- resources/ddl/define-udfs.td.hql | 4 + .../apache/spark/sql/hive/GroupedDataEx.scala | 29 +- .../org/apache/spark/sql/hive/HivemallOps.scala | 18 + .../spark/sql/hive/HivemallOpsSuite.scala | 105 +++++- .../spark/sql/hive/HivemallGroupedDataset.scala | 26 ++ .../org/apache/spark/sql/hive/HivemallOps.scala | 20 + .../spark/sql/hive/HivemallOpsSuite.scala | 100 +++++ 26 files changed, 2203 insertions(+), 24 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/fad2941f/core/src/main/java/hivemall/ftvec/selection/ChiSquareUDF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/ftvec/selection/ChiSquareUDF.java b/core/src/main/java/hivemall/ftvec/selection/ChiSquareUDF.java new file mode 100644 index 0000000..9ada4e5 --- /dev/null +++ b/core/src/main/java/hivemall/ftvec/selection/ChiSquareUDF.java @@ -0,0 +1,173 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package hivemall.ftvec.selection; + +import hivemall.utils.hadoop.HiveUtils; +import hivemall.utils.hadoop.WritableUtils; +import hivemall.utils.lang.Preconditions; +import hivemall.utils.math.StatsUtils; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Map; + +import org.apache.hadoop.hive.ql.exec.Description; +import org.apache.hadoop.hive.ql.exec.UDFArgumentException; +import org.apache.hadoop.hive.ql.exec.UDFArgumentLengthException; +import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException; +import org.apache.hadoop.hive.ql.metadata.HiveException; +import org.apache.hadoop.hive.ql.udf.UDFType; +import org.apache.hadoop.hive.ql.udf.generic.GenericUDF; +import org.apache.hadoop.hive.serde2.io.DoubleWritable; +import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory; +import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; + +@Description(name = "chi2", + value = "_FUNC_(array<array<number>> observed, array<array<number>> expected)" + + " - Returns chi2_val and p_val of each column as <array<double>, array<double>>") +@UDFType(deterministic = true, stateful = false) +public final class ChiSquareUDF extends GenericUDF { + + private ListObjectInspector observedOI; + private ListObjectInspector observedRowOI; + private PrimitiveObjectInspector observedElOI; + private ListObjectInspector expectedOI; + private ListObjectInspector expectedRowOI; + private PrimitiveObjectInspector expectedElOI; + + private int nFeatures = -1; + private double[] observedRow = null; // to reuse + private double[] expectedRow = null; // to reuse + private double[][] observed = null; // shape = (#features, #classes) + private double[][] expected = null; // shape = (#features, #classes) + + private List<DoubleWritable>[] result; + + @SuppressWarnings("unchecked") + @Override + public ObjectInspector initialize(ObjectInspector[] OIs) throws UDFArgumentException { + if (OIs.length != 2) { + throw new UDFArgumentLengthException("Specify two arguments: " + OIs.length); + } + if (!HiveUtils.isNumberListListOI(OIs[0])) { + throw new UDFArgumentTypeException(0, + "Only array<array<number>> type argument is acceptable but " + OIs[0].getTypeName() + + " was passed as `observed`"); + } + if (!HiveUtils.isNumberListListOI(OIs[1])) { + throw new UDFArgumentTypeException(1, + "Only array<array<number>> type argument is acceptable but " + OIs[1].getTypeName() + + " was passed as `expected`"); + } + + this.observedOI = HiveUtils.asListOI(OIs[0]); + this.observedRowOI = HiveUtils.asListOI(observedOI.getListElementObjectInspector()); + this.observedElOI = HiveUtils.asDoubleCompatibleOI(observedRowOI.getListElementObjectInspector()); + this.expectedOI = HiveUtils.asListOI(OIs[1]); + this.expectedRowOI = HiveUtils.asListOI(expectedOI.getListElementObjectInspector()); + this.expectedElOI = HiveUtils.asDoubleCompatibleOI(expectedRowOI.getListElementObjectInspector()); + this.result = new List[2]; + + List<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>(); + fieldOIs.add(ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector)); + 
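// one list OI per output field: the "chi2" scores first, then the "pvalue"s + 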
fieldOIs.add(ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector)); + + return ObjectInspectorFactory.getStandardStructObjectInspector( + Arrays.asList("chi2", "pvalue"), fieldOIs); + } + + @Override + public List<DoubleWritable>[] evaluate(DeferredObject[] dObj) throws HiveException { + List<?> observedObj = observedOI.getList(dObj[0].get()); // shape = (#classes, #features) + List<?> expectedObj = expectedOI.getList(dObj[1].get()); // shape = (#classes, #features) + + if (observedObj == null || expectedObj == null) { + return null; + } + + final int nClasses = observedObj.size(); + Preconditions.checkArgument(nClasses == expectedObj.size(), UDFArgumentException.class); + + // explode and transpose matrix + for (int i = 0; i < nClasses; i++) { + Object observedObjRow = observedObj.get(i); + Object expectedObjRow = expectedObj.get(i); + + Preconditions.checkNotNull(observedObjRow, UDFArgumentException.class); + Preconditions.checkNotNull(expectedObjRow, UDFArgumentException.class); + + if (observedRow == null) { + observedRow = HiveUtils.asDoubleArray(observedObjRow, observedRowOI, observedElOI, + false); + expectedRow = HiveUtils.asDoubleArray(expectedObjRow, expectedRowOI, expectedElOI, + false); + nFeatures = observedRow.length; + observed = new double[nFeatures][nClasses]; + expected = new double[nFeatures][nClasses]; + } else { + HiveUtils.toDoubleArray(observedObjRow, observedRowOI, observedElOI, observedRow, + false); + HiveUtils.toDoubleArray(expectedObjRow, expectedRowOI, expectedElOI, expectedRow, + false); + } + + for (int j = 0; j < nFeatures; j++) { + observed[j][i] = observedRow[j]; + expected[j][i] = expectedRow[j]; + } + } + + Map.Entry<double[], double[]> chi2 = StatsUtils.chiSquare(observed, expected); + + result[0] = WritableUtils.toWritableList(chi2.getKey(), result[0]); + result[1] = WritableUtils.toWritableList(chi2.getValue(), result[1]); + return result; + } + + @Override + public void close() throws IOException { + // help GC + this.observedRow = null; + this.expectedRow = null; + this.observed = null; + this.expected = null; + this.result = null; + } + + @Override + public String getDisplayString(String[] children) { + final StringBuilder sb = new StringBuilder(); + sb.append("chi2"); + sb.append("("); + if (children.length > 0) { + sb.append(children[0]); + for (int i = 1; i < children.length; i++) { + sb.append(", "); + sb.append(children[i]); + } + } + sb.append(")"); + return sb.toString(); + } +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/fad2941f/core/src/main/java/hivemall/ftvec/selection/SignalNoiseRatioUDAF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/ftvec/selection/SignalNoiseRatioUDAF.java b/core/src/main/java/hivemall/ftvec/selection/SignalNoiseRatioUDAF.java new file mode 100644 index 0000000..da0de59 --- /dev/null +++ b/core/src/main/java/hivemall/ftvec/selection/SignalNoiseRatioUDAF.java @@ -0,0 +1,370 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package hivemall.ftvec.selection; + +import hivemall.utils.hadoop.HiveUtils; +import hivemall.utils.hadoop.WritableUtils; +import hivemall.utils.lang.Preconditions; +import hivemall.utils.lang.SizeOf; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +import javax.annotation.Nonnull; + +import org.apache.hadoop.hive.ql.exec.Description; +import org.apache.hadoop.hive.ql.exec.UDFArgumentException; +import org.apache.hadoop.hive.ql.exec.UDFArgumentLengthException; +import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException; +import org.apache.hadoop.hive.ql.metadata.HiveException; +import org.apache.hadoop.hive.ql.parse.SemanticException; +import org.apache.hadoop.hive.ql.udf.generic.AbstractGenericUDAFResolver; +import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator; +import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFParameterInfo; +import org.apache.hadoop.hive.serde2.io.DoubleWritable; +import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory; +import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.StructField; +import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.DoubleObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.LongObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils; + +@Description(name = "snr", value = "_FUNC_(array<number> features, array<int> one-hot class label)" + + " - Returns Signal Noise Ratio for each feature as array<double>") +public class SignalNoiseRatioUDAF extends AbstractGenericUDAFResolver { + + @Override + public GenericUDAFEvaluator getEvaluator(GenericUDAFParameterInfo info) + throws SemanticException { + final ObjectInspector[] OIs = info.getParameterObjectInspectors(); + + if (OIs.length != 2) { + throw new UDFArgumentLengthException("Specify two arguments: " + OIs.length); + } + if (!HiveUtils.isNumberListOI(OIs[0])) { + throw new UDFArgumentTypeException(0, + "Only array<number> type argument is acceptable but " + OIs[0].getTypeName() + + " was passed as `features`"); + } + if (!HiveUtils.isListOI(OIs[1]) + || !HiveUtils.isIntegerOI(((ListObjectInspector) OIs[1]).getListElementObjectInspector())) { + throw new UDFArgumentTypeException(1, + "Only array<int> type argument is acceptable but " + OIs[1].getTypeName() + + " was passed as `labels`"); + } + + return new SignalNoiseRatioUDAFEvaluator(); + } + + static class SignalNoiseRatioUDAFEvaluator extends GenericUDAFEvaluator { + // PARTIAL1 and COMPLETE + private ListObjectInspector featuresOI; + private PrimitiveObjectInspector featureOI; + private ListObjectInspector labelsOI; + private PrimitiveObjectInspector labelOI; + + 
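// the row OIs above read raw input in PARTIAL1/COMPLETE; the struct OIs below deserialize partial aggregates + 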
// PARTIAL2 and FINAL + private StructObjectInspector structOI; + private StructField countsField, meansField, variancesField; + private ListObjectInspector countsOI; + private LongObjectInspector countOI; + private ListObjectInspector meansOI; + private ListObjectInspector meanListOI; + private DoubleObjectInspector meanElemOI; + private ListObjectInspector variancesOI; + private ListObjectInspector varianceListOI; + private DoubleObjectInspector varianceElemOI; + + @AggregationType(estimable = true) + static class SignalNoiseRatioAggregationBuffer extends AbstractAggregationBuffer { + long[] counts; + double[][] means; + double[][] variances; + + @Override + public int estimate() { + return counts == null ? 0 : SizeOf.LONG * counts.length + SizeOf.DOUBLE + * means.length * means[0].length + SizeOf.DOUBLE * variances.length + * variances[0].length; + } + + public void init(int nClasses, int nFeatures) { + this.counts = new long[nClasses]; + this.means = new double[nClasses][nFeatures]; + this.variances = new double[nClasses][nFeatures]; + } + + public void reset() { + if (counts != null) { + Arrays.fill(counts, 0); + for (double[] mean : means) { + Arrays.fill(mean, 0.d); + } + for (double[] variance : variances) { + Arrays.fill(variance, 0.d); + } + } + } + } + + @Override + public ObjectInspector init(Mode mode, ObjectInspector[] OIs) throws HiveException { + super.init(mode, OIs); + + // initialize input + if (mode == Mode.PARTIAL1 || mode == Mode.COMPLETE) {// from original data + this.featuresOI = HiveUtils.asListOI(OIs[0]); + this.featureOI = HiveUtils.asDoubleCompatibleOI(featuresOI.getListElementObjectInspector()); + this.labelsOI = HiveUtils.asListOI(OIs[1]); + this.labelOI = HiveUtils.asIntegerOI(labelsOI.getListElementObjectInspector()); + } else {// from partial aggregation + this.structOI = (StructObjectInspector) OIs[0]; + this.countsField = structOI.getStructFieldRef("counts"); + this.countsOI = HiveUtils.asListOI(countsField.getFieldObjectInspector()); + this.countOI = HiveUtils.asLongOI(countsOI.getListElementObjectInspector()); + this.meansField = structOI.getStructFieldRef("means"); + this.meansOI = HiveUtils.asListOI(meansField.getFieldObjectInspector()); + this.meanListOI = HiveUtils.asListOI(meansOI.getListElementObjectInspector()); + this.meanElemOI = HiveUtils.asDoubleOI(meanListOI.getListElementObjectInspector()); + this.variancesField = structOI.getStructFieldRef("variances"); + this.variancesOI = HiveUtils.asListOI(variancesField.getFieldObjectInspector()); + this.varianceListOI = HiveUtils.asListOI(variancesOI.getListElementObjectInspector()); + this.varianceElemOI = HiveUtils.asDoubleOI(varianceListOI.getListElementObjectInspector()); + } + + // initialize output + if (mode == Mode.PARTIAL1 || mode == Mode.PARTIAL2) {// terminatePartial + List<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>(); + fieldOIs.add(ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableLongObjectInspector)); + fieldOIs.add(ObjectInspectorFactory.getStandardListObjectInspector(ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector))); + fieldOIs.add(ObjectInspectorFactory.getStandardListObjectInspector(ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector))); + return ObjectInspectorFactory.getStandardStructObjectInspector( + Arrays.asList("counts", "means", "variances"), fieldOIs); + } else {// terminate + return 
ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector); + } + } + + @Override + public AbstractAggregationBuffer getNewAggregationBuffer() throws HiveException { + SignalNoiseRatioAggregationBuffer myAgg = new SignalNoiseRatioAggregationBuffer(); + reset(myAgg); + return myAgg; + } + + @Override + public void reset(@SuppressWarnings("deprecation") AggregationBuffer agg) + throws HiveException { + SignalNoiseRatioAggregationBuffer myAgg = (SignalNoiseRatioAggregationBuffer) agg; + myAgg.reset(); + } + + @Override + public void iterate(@SuppressWarnings("deprecation") AggregationBuffer agg, + Object[] parameters) throws HiveException { + final Object featuresObj = parameters[0]; + final Object labelsObj = parameters[1]; + + Preconditions.checkNotNull(featuresObj); + Preconditions.checkNotNull(labelsObj); + + final SignalNoiseRatioAggregationBuffer myAgg = (SignalNoiseRatioAggregationBuffer) agg; + + final List<?> labels = labelsOI.getList(labelsObj); + final int nClasses = labels.size(); + Preconditions.checkArgument(nClasses >= 2, UDFArgumentException.class); + + final List<?> features = featuresOI.getList(featuresObj); + final int nFeatures = features.size(); + Preconditions.checkArgument(nFeatures >= 1, UDFArgumentException.class); + + if (myAgg.counts == null) { + myAgg.init(nClasses, nFeatures); + } else { + Preconditions.checkArgument(nClasses == myAgg.counts.length, + UDFArgumentException.class); + Preconditions.checkArgument(nFeatures == myAgg.means[0].length, + UDFArgumentException.class); + } + + // incrementally calculates means and variance + // http://i.stanford.edu/pub/cstr/reports/cs/tr/79/773/CS-TR-79-773.pdf + final int clazz = hotIndex(labels, labelOI); + final long n = myAgg.counts[clazz]; + myAgg.counts[clazz]++; + for (int i = 0; i < nFeatures; i++) { + final double x = PrimitiveObjectInspectorUtils.getDouble(features.get(i), featureOI); + final double meanN = myAgg.means[clazz][i]; + final double varianceN = myAgg.variances[clazz][i]; + myAgg.means[clazz][i] = (n * meanN + x) / (n + 1.d); + myAgg.variances[clazz][i] = (n * varianceN + (x - meanN) + * (x - myAgg.means[clazz][i])) + / (n + 1.d); + } + } + + private static int hotIndex(@Nonnull List<?> labels, PrimitiveObjectInspector labelOI) + throws UDFArgumentException { + final int nClasses = labels.size(); + + int clazz = -1; + for (int i = 0; i < nClasses; i++) { + final int label = PrimitiveObjectInspectorUtils.getInt(labels.get(i), labelOI); + if (label == 1) {// assumes one hot encoding + if (clazz != -1) { + throw new UDFArgumentException( + "Specify one-hot vectorized array. Multiple hot elements found."); + } + clazz = i; + } else { + if (label != 0) { + throw new UDFArgumentException( + "Assumed one-hot encoding (0/1) but found an invalid label: " + label); + } + } + } + if (clazz == -1) { + throw new UDFArgumentException( + "Specify one-hot vectorized array for label. 
Hot element not found."); + } + return clazz; + } + + @Override + public void merge(@SuppressWarnings("deprecation") AggregationBuffer agg, Object other) + throws HiveException { + if (other == null) { + return; + } + + final SignalNoiseRatioAggregationBuffer myAgg = (SignalNoiseRatioAggregationBuffer) agg; + + final List<?> counts = countsOI.getList(structOI.getStructFieldData(other, countsField)); + final List<?> means = meansOI.getList(structOI.getStructFieldData(other, meansField)); + final List<?> variances = variancesOI.getList(structOI.getStructFieldData(other, + variancesField)); + + final int nClasses = counts.size(); + final int nFeatures = meanListOI.getListLength(means.get(0)); + if (myAgg.counts == null) { + myAgg.init(nClasses, nFeatures); + } + + for (int i = 0; i < nClasses; i++) { + final long n = myAgg.counts[i]; + final long cnt = PrimitiveObjectInspectorUtils.getLong(counts.get(i), countOI); + + // no need to merge class `i` + if (cnt == 0) { + continue; + } + + final List<?> mean = meanListOI.getList(means.get(i)); + final List<?> variance = varianceListOI.getList(variances.get(i)); + + myAgg.counts[i] += cnt; + for (int j = 0; j < nFeatures; j++) { + final double meanN = myAgg.means[i][j]; + final double meanM = PrimitiveObjectInspectorUtils.getDouble(mean.get(j), + meanElemOI); + final double varianceN = myAgg.variances[i][j]; + final double varianceM = PrimitiveObjectInspectorUtils.getDouble( + variance.get(j), varianceElemOI); + + if (n == 0) {// only assign `other` into `myAgg` + myAgg.means[i][j] = meanM; + myAgg.variances[i][j] = varianceM; + } else { + // merge by Chan's method + // http://i.stanford.edu/pub/cstr/reports/cs/tr/79/773/CS-TR-79-773.pdf + myAgg.means[i][j] = (n * meanN + cnt * meanM) / (double) (n + cnt); + myAgg.variances[i][j] = (varianceN * (n - 1) + varianceM * (cnt - 1) + Math.pow( + meanN - meanM, 2) * n * cnt / (n + cnt)) + / (n + cnt - 1); + } + } + } + } + + @Override + public Object terminatePartial(@SuppressWarnings("deprecation") AggregationBuffer agg) + throws HiveException { + final SignalNoiseRatioAggregationBuffer myAgg = (SignalNoiseRatioAggregationBuffer) agg; + + final Object[] partialResult = new Object[3]; + partialResult[0] = WritableUtils.toWritableList(myAgg.counts); + final List<List<DoubleWritable>> means = new ArrayList<List<DoubleWritable>>(); + for (double[] mean : myAgg.means) { + means.add(WritableUtils.toWritableList(mean)); + } + partialResult[1] = means; + final List<List<DoubleWritable>> variances = new ArrayList<List<DoubleWritable>>(); + for (double[] variance : myAgg.variances) { + variances.add(WritableUtils.toWritableList(variance)); + } + partialResult[2] = variances; + return partialResult; + } + + @Override + public Object terminate(@SuppressWarnings("deprecation") AggregationBuffer agg) + throws HiveException { + final SignalNoiseRatioAggregationBuffer myAgg = (SignalNoiseRatioAggregationBuffer) agg; + + final int nClasses = myAgg.counts.length; + final int nFeatures = myAgg.means[0].length; + + // compute SNR among classes for each feature + final double[] result = new double[nFeatures]; + final double[] sds = new double[nClasses]; // for memorization + for (int i = 0; i < nFeatures; i++) { + sds[0] = Math.sqrt(myAgg.variances[0][i]); + for (int j = 1; j < nClasses; j++) { + sds[j] = Math.sqrt(myAgg.variances[j][i]); + // `ns[j] == 0` means no feature entry belongs to class `j`. Then, skip the entry. 
+ if (myAgg.counts[j] == 0) { + continue; + } + for (int k = 0; k < j; k++) { + // avoid comparing between classes having only single entry + if (myAgg.counts[k] == 0 || (myAgg.counts[j] == 1 && myAgg.counts[k] == 1)) { + continue; + } + + // SUM(snr) GROUP BY feature + final double snr = Math.abs(myAgg.means[j][i] - myAgg.means[k][i]) + / (sds[j] + sds[k]); + // if `NaN`(when diff between means and both sds are zero, IOW, all related values are equal), + // regard feature `i` as meaningless between class `j` and `k`. So, skip the entry. + if (!Double.isNaN(snr)) { + result[i] += snr; // accept `Infinity` + } + } + } + } + + return WritableUtils.toWritableList(result); + } + } +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/fad2941f/core/src/main/java/hivemall/tools/array/SelectKBestUDF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/tools/array/SelectKBestUDF.java b/core/src/main/java/hivemall/tools/array/SelectKBestUDF.java new file mode 100644 index 0000000..b363166 --- /dev/null +++ b/core/src/main/java/hivemall/tools/array/SelectKBestUDF.java @@ -0,0 +1,163 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package hivemall.tools.array; + +import hivemall.utils.hadoop.HiveUtils; +import hivemall.utils.lang.Preconditions; + +import java.io.IOException; +import java.util.AbstractMap; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.Comparator; +import java.util.List; +import java.util.Map; + +import org.apache.hadoop.hive.ql.exec.Description; +import org.apache.hadoop.hive.ql.exec.UDFArgumentException; +import org.apache.hadoop.hive.ql.exec.UDFArgumentLengthException; +import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException; +import org.apache.hadoop.hive.ql.metadata.HiveException; +import org.apache.hadoop.hive.ql.udf.UDFType; +import org.apache.hadoop.hive.ql.udf.generic.GenericUDF; +import org.apache.hadoop.hive.serde2.io.DoubleWritable; +import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory; +import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; + +@Description(name = "select_k_best", + value = "_FUNC_(array<number> array, const array<number> importance, const int k)" + + " - Returns selected top-k elements as array<double>") +@UDFType(deterministic = true, stateful = false) +public final class SelectKBestUDF extends GenericUDF { + + private ListObjectInspector featuresOI; + private PrimitiveObjectInspector featureOI; + private ListObjectInspector importanceListOI; + private PrimitiveObjectInspector importanceElemOI; + + private int _k; + private List<DoubleWritable> _result; + private int[] _topKIndices; + + @Override + public ObjectInspector initialize(ObjectInspector[] OIs) throws UDFArgumentException { + if (OIs.length != 3) { + throw new UDFArgumentLengthException("Specify three arguments: " + OIs.length); + } + + if (!HiveUtils.isNumberListOI(OIs[0])) { + throw new UDFArgumentTypeException(0, + "Only array<number> type argument is acceptable but " + OIs[0].getTypeName() + + " was passed as `features`"); + } + if (!HiveUtils.isNumberListOI(OIs[1])) { + throw new UDFArgumentTypeException(1, + "Only array<number> type argument is acceptable but " + OIs[1].getTypeName() + + " was passed as `importance_list`"); + } + if (!HiveUtils.isIntegerOI(OIs[2])) { + throw new UDFArgumentTypeException(2, "Only int type argument is acceptable but " + + OIs[2].getTypeName() + " was passed as `k`"); + } + + this.featuresOI = HiveUtils.asListOI(OIs[0]); + this.featureOI = HiveUtils.asDoubleCompatibleOI(featuresOI.getListElementObjectInspector()); + this.importanceListOI = HiveUtils.asListOI(OIs[1]); + this.importanceElemOI = HiveUtils.asDoubleCompatibleOI(importanceListOI.getListElementObjectInspector()); + + this._k = HiveUtils.getConstInt(OIs[2]); + Preconditions.checkArgument(_k >= 1, UDFArgumentException.class); + final DoubleWritable[] array = new DoubleWritable[_k]; + for (int i = 0; i < array.length; i++) { + array[i] = new DoubleWritable(); + } + this._result = Arrays.asList(array); + + return ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector); + } + + @Override + public List<DoubleWritable> evaluate(DeferredObject[] dObj) throws HiveException { + final double[] features = HiveUtils.asDoubleArray(dObj[0].get(), featuresOI, featureOI); + final double[] importanceList = 
HiveUtils.asDoubleArray(dObj[1].get(), importanceListOI, + importanceElemOI); + + Preconditions.checkNotNull(features, UDFArgumentException.class); + Preconditions.checkNotNull(importanceList, UDFArgumentException.class); + Preconditions.checkArgument(features.length == importanceList.length, + UDFArgumentException.class); + Preconditions.checkArgument(features.length >= _k, UDFArgumentException.class); + + int[] topKIndices = _topKIndices; + if (topKIndices == null) { + final List<Map.Entry<Integer, Double>> list = new ArrayList<Map.Entry<Integer, Double>>(); + for (int i = 0; i < importanceList.length; i++) { + list.add(new AbstractMap.SimpleEntry<Integer, Double>(i, importanceList[i])); + } + Collections.sort(list, new Comparator<Map.Entry<Integer, Double>>() { + @Override + public int compare(Map.Entry<Integer, Double> o1, Map.Entry<Integer, Double> o2) { + return o1.getValue() > o2.getValue() ? -1 : 1; + } + }); + + topKIndices = new int[_k]; + for (int i = 0; i < topKIndices.length; i++) { + topKIndices[i] = list.get(i).getKey(); + } + this._topKIndices = topKIndices; + } + + final List<DoubleWritable> result = _result; + for (int i = 0; i < topKIndices.length; i++) { + int idx = topKIndices[i]; + DoubleWritable d = result.get(i); + double f = features[idx]; + d.set(f); + } + return result; + } + + @Override + public void close() throws IOException { + // help GC + this._result = null; + this._topKIndices = null; + } + + @Override + public String getDisplayString(String[] children) { + final StringBuilder sb = new StringBuilder(); + sb.append("select_k_best"); + sb.append("("); + if (children.length > 0) { + sb.append(children[0]); + for (int i = 1; i < children.length; i++) { + sb.append(", "); + sb.append(children[i]); + } + } + sb.append(")"); + return sb.toString(); + } +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/fad2941f/core/src/main/java/hivemall/tools/matrix/TransposeAndDotUDAF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/tools/matrix/TransposeAndDotUDAF.java b/core/src/main/java/hivemall/tools/matrix/TransposeAndDotUDAF.java new file mode 100644 index 0000000..440bbe6 --- /dev/null +++ b/core/src/main/java/hivemall/tools/matrix/TransposeAndDotUDAF.java @@ -0,0 +1,222 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package hivemall.tools.matrix; + +import hivemall.utils.hadoop.HiveUtils; +import hivemall.utils.hadoop.WritableUtils; +import hivemall.utils.lang.Preconditions; +import hivemall.utils.lang.SizeOf; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +import org.apache.hadoop.hive.ql.exec.Description; +import org.apache.hadoop.hive.ql.exec.UDFArgumentException; +import org.apache.hadoop.hive.ql.exec.UDFArgumentLengthException; +import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException; +import org.apache.hadoop.hive.ql.metadata.HiveException; +import org.apache.hadoop.hive.ql.parse.SemanticException; +import org.apache.hadoop.hive.ql.udf.generic.AbstractGenericUDAFResolver; +import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator; +import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFParameterInfo; +import org.apache.hadoop.hive.serde2.io.DoubleWritable; +import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory; +import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.DoubleObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; + +@Description( + name = "transpose_and_dot", + value = "_FUNC_(array<number> matrix0_row, array<number> matrix1_row)" + + " - Returns dot(matrix0.T, matrix1) as array<array<double>>, shape = (matrix0.#cols, matrix1.#cols)") +public final class TransposeAndDotUDAF extends AbstractGenericUDAFResolver { + + @Override + public GenericUDAFEvaluator getEvaluator(GenericUDAFParameterInfo info) + throws SemanticException { + ObjectInspector[] OIs = info.getParameterObjectInspectors(); + + if (OIs.length != 2) { + throw new UDFArgumentLengthException("Specify two arguments."); + } + + if (!HiveUtils.isNumberListOI(OIs[0])) { + throw new UDFArgumentTypeException(0, + "Only array<number> type argument is acceptable but " + OIs[0].getTypeName() + + " was passed as `matrix0_row`"); + } + + if (!HiveUtils.isNumberListOI(OIs[1])) { + throw new UDFArgumentTypeException(1, + "Only array<number> type argument is acceptable but " + OIs[1].getTypeName() + + " was passed as `matrix1_row`"); + } + + return new TransposeAndDotUDAFEvaluator(); + } + + static final class TransposeAndDotUDAFEvaluator extends GenericUDAFEvaluator { + // PARTIAL1 and COMPLETE + private ListObjectInspector matrix0RowOI; + private PrimitiveObjectInspector matrix0ElOI; + private ListObjectInspector matrix1RowOI; + private PrimitiveObjectInspector matrix1ElOI; + + // PARTIAL2 and FINAL + private ListObjectInspector aggMatrixOI; + private ListObjectInspector aggMatrixRowOI; + private DoubleObjectInspector aggMatrixElOI; + + private double[] matrix0Row; + private double[] matrix1Row; + + @AggregationType(estimable = true) + static class TransposeAndDotAggregationBuffer extends AbstractAggregationBuffer { + double[][] aggMatrix; + + @Override + public int estimate() { + return aggMatrix != null ? 
aggMatrix.length * aggMatrix[0].length * SizeOf.DOUBLE + : 0; + } + + public void init(int n, int m) { + this.aggMatrix = new double[n][m]; + } + + public void reset() { + if (aggMatrix != null) { + for (double[] row : aggMatrix) { + Arrays.fill(row, 0.d); + } + } + } + } + + @Override + public ObjectInspector init(Mode mode, ObjectInspector[] OIs) throws HiveException { + super.init(mode, OIs); + + if (mode == Mode.PARTIAL1 || mode == Mode.COMPLETE) { + this.matrix0RowOI = HiveUtils.asListOI(OIs[0]); + this.matrix0ElOI = HiveUtils.asDoubleCompatibleOI(matrix0RowOI.getListElementObjectInspector()); + this.matrix1RowOI = HiveUtils.asListOI(OIs[1]); + this.matrix1ElOI = HiveUtils.asDoubleCompatibleOI(matrix1RowOI.getListElementObjectInspector()); + } else { + this.aggMatrixOI = HiveUtils.asListOI(OIs[0]); + this.aggMatrixRowOI = HiveUtils.asListOI(aggMatrixOI.getListElementObjectInspector()); + this.aggMatrixElOI = HiveUtils.asDoubleOI(aggMatrixRowOI.getListElementObjectInspector()); + } + + return ObjectInspectorFactory.getStandardListObjectInspector(ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector)); + } + + @Override + public AbstractAggregationBuffer getNewAggregationBuffer() throws HiveException { + TransposeAndDotAggregationBuffer myAgg = new TransposeAndDotAggregationBuffer(); + reset(myAgg); + return myAgg; + } + + @Override + public void reset(@SuppressWarnings("deprecation") AggregationBuffer agg) + throws HiveException { + TransposeAndDotAggregationBuffer myAgg = (TransposeAndDotAggregationBuffer) agg; + myAgg.reset(); + } + + @Override + public void iterate(@SuppressWarnings("deprecation") AggregationBuffer agg, + Object[] parameters) throws HiveException { + final Object matrix0RowObj = parameters[0]; + final Object matrix1RowObj = parameters[1]; + + Preconditions.checkNotNull(matrix0RowObj, UDFArgumentException.class); + Preconditions.checkNotNull(matrix1RowObj, UDFArgumentException.class); + + final TransposeAndDotAggregationBuffer myAgg = (TransposeAndDotAggregationBuffer) agg; + + if (matrix0Row == null) { + matrix0Row = new double[matrix0RowOI.getListLength(matrix0RowObj)]; + } + if (matrix1Row == null) { + matrix1Row = new double[matrix1RowOI.getListLength(matrix1RowObj)]; + } + + HiveUtils.toDoubleArray(matrix0RowObj, matrix0RowOI, matrix0ElOI, matrix0Row, false); + HiveUtils.toDoubleArray(matrix1RowObj, matrix1RowOI, matrix1ElOI, matrix1Row, false); + + if (myAgg.aggMatrix == null) { + myAgg.init(matrix0Row.length, matrix1Row.length); + } + + for (int i = 0; i < matrix0Row.length; i++) { + for (int j = 0; j < matrix1Row.length; j++) { + myAgg.aggMatrix[i][j] += matrix0Row[i] * matrix1Row[j]; + } + } + } + + @Override + public void merge(@SuppressWarnings("deprecation") AggregationBuffer agg, Object other) + throws HiveException { + if (other == null) { + return; + } + + final TransposeAndDotAggregationBuffer myAgg = (TransposeAndDotAggregationBuffer) agg; + + final List<?> matrix = aggMatrixOI.getList(other); + final int n = matrix.size(); + final double[] row = new double[aggMatrixRowOI.getListLength(matrix.get(0))]; + for (int i = 0; i < n; i++) { + HiveUtils.toDoubleArray(matrix.get(i), aggMatrixRowOI, aggMatrixElOI, row, false); + + if (myAgg.aggMatrix == null) { + myAgg.init(n, row.length); + } + + for (int j = 0; j < row.length; j++) { + myAgg.aggMatrix[i][j] += row[j]; + } + } + } + + @Override + public Object terminatePartial(@SuppressWarnings("deprecation") AggregationBuffer agg) + throws 
HiveException { + return terminate(agg); + } + + @Override + public Object terminate(@SuppressWarnings("deprecation") AggregationBuffer agg) + throws HiveException { + final TransposeAndDotAggregationBuffer myAgg = (TransposeAndDotAggregationBuffer) agg; + + final List<List<DoubleWritable>> result = new ArrayList<List<DoubleWritable>>(); + for (double[] row : myAgg.aggMatrix) { + result.add(WritableUtils.toWritableList(row)); + } + return result; + } + } +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/fad2941f/core/src/main/java/hivemall/utils/hadoop/HiveUtils.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/utils/hadoop/HiveUtils.java b/core/src/main/java/hivemall/utils/hadoop/HiveUtils.java index d8b1aef..8188b7a 100644 --- a/core/src/main/java/hivemall/utils/hadoop/HiveUtils.java +++ b/core/src/main/java/hivemall/utils/hadoop/HiveUtils.java @@ -59,6 +59,7 @@ import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.primitive.BinaryObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.primitive.BooleanObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.DoubleObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.primitive.IntObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.primitive.LongObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils; @@ -200,8 +201,7 @@ public final class HiveUtils { return BOOLEAN_TYPE_NAME.equals(typeName); } - public static boolean isNumberOI(@Nonnull final ObjectInspector argOI) - throws UDFArgumentTypeException { + public static boolean isNumberOI(@Nonnull final ObjectInspector argOI) { if (argOI.getCategory() != Category.PRIMITIVE) { return false; } @@ -246,6 +246,16 @@ public final class HiveUtils { return oi.getCategory() == Category.MAP; } + public static boolean isNumberListOI(@Nonnull final ObjectInspector oi) { + return isListOI(oi) + && isNumberOI(((ListObjectInspector) oi).getListElementObjectInspector()); + } + + public static boolean isNumberListListOI(@Nonnull final ObjectInspector oi) { + return isListOI(oi) + && isNumberListOI(((ListObjectInspector) oi).getListElementObjectInspector()); + } + public static boolean isPrimitiveTypeInfo(@Nonnull TypeInfo typeInfo) { return typeInfo.getCategory() == ObjectInspector.Category.PRIMITIVE; } @@ -687,6 +697,14 @@ public final class HiveUtils { return (LongObjectInspector) argOI; } + public static DoubleObjectInspector asDoubleOI(@Nonnull final ObjectInspector argOI) + throws UDFArgumentException { + if (!DOUBLE_TYPE_NAME.equals(argOI.getTypeName())) { + throw new UDFArgumentException("Argument type must be DOUBLE: " + argOI.getTypeName()); + } + return (DoubleObjectInspector) argOI; + } + public static PrimitiveObjectInspector asIntCompatibleOI(@Nonnull final ObjectInspector argOI) throws UDFArgumentTypeException { if (argOI.getCategory() != Category.PRIMITIVE) { http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/fad2941f/core/src/main/java/hivemall/utils/hadoop/WritableUtils.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/utils/hadoop/WritableUtils.java b/core/src/main/java/hivemall/utils/hadoop/WritableUtils.java index a4f2691..a9c7390 100644 --- 
a/core/src/main/java/hivemall/utils/hadoop/WritableUtils.java +++ b/core/src/main/java/hivemall/utils/hadoop/WritableUtils.java @@ -25,7 +25,9 @@ import java.util.List; import javax.annotation.CheckForNull; import javax.annotation.Nonnull; +import javax.annotation.Nullable; +import org.apache.hadoop.hive.ql.exec.UDFArgumentException; import org.apache.hadoop.hive.serde2.io.DoubleWritable; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils.ObjectInspectorCopyOption; @@ -142,6 +144,20 @@ public final class WritableUtils { return list; } + @Nonnull + public static List<DoubleWritable> toWritableList(@Nonnull final double[] src, + @Nullable List<DoubleWritable> list) throws UDFArgumentException { + if (list == null) { + return toWritableList(src); + } + + Preconditions.checkArgument(src.length == list.size(), UDFArgumentException.class); + for (int i = 0; i < src.length; i++) { + list.set(i, new DoubleWritable(src[i])); + } + return list; + } + public static Text val(final String v) { return new Text(v); } http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/fad2941f/core/src/main/java/hivemall/utils/lang/Preconditions.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/utils/lang/Preconditions.java b/core/src/main/java/hivemall/utils/lang/Preconditions.java index 4fa2bdd..9f76bd6 100644 --- a/core/src/main/java/hivemall/utils/lang/Preconditions.java +++ b/core/src/main/java/hivemall/utils/lang/Preconditions.java @@ -18,6 +18,7 @@ */ package hivemall.utils.lang; +import javax.annotation.Nonnull; import javax.annotation.Nullable; public final class Preconditions { @@ -31,6 +32,21 @@ public final class Preconditions { return reference; } + public static <T, E extends Throwable> T checkNotNull(@Nullable T reference, + @Nonnull Class<E> clazz) throws E { + if (reference == null) { + final E throwable; + try { + throwable = clazz.newInstance(); + } catch (InstantiationException | IllegalAccessException e) { + throw new IllegalStateException( + "Failed to instantiate a class: " + clazz.getName(), e); + } + throw throwable; + } + return reference; + } + public static <T> T checkNotNull(T reference, @Nullable Object errorMessage) { if (reference == null) { throw new NullPointerException(String.valueOf(errorMessage)); @@ -44,6 +60,20 @@ public final class Preconditions { } } + public static <E extends Throwable> void checkArgument(boolean expression, + @Nonnull Class<E> clazz) throws E { + if (!expression) { + final E throwable; + try { + throwable = clazz.newInstance(); + } catch (InstantiationException | IllegalAccessException e) { + throw new IllegalStateException( + "Failed to instantiate a class: " + clazz.getName(), e); + } + throw throwable; + } + } + public static void checkArgument(boolean expression, @Nullable Object errorMessage) { if (!expression) { throw new IllegalArgumentException(String.valueOf(errorMessage)); http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/fad2941f/core/src/main/java/hivemall/utils/math/StatsUtils.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/utils/math/StatsUtils.java b/core/src/main/java/hivemall/utils/math/StatsUtils.java index 812f619..599bf51 100644 --- a/core/src/main/java/hivemall/utils/math/StatsUtils.java +++ b/core/src/main/java/hivemall/utils/math/StatsUtils.java @@ -22,11 +22,19 @@ import 
hivemall.utils.lang.Preconditions; import javax.annotation.Nonnull; +import org.apache.commons.math3.distribution.ChiSquaredDistribution; +import org.apache.commons.math3.exception.DimensionMismatchException; +import org.apache.commons.math3.exception.NotPositiveException; import org.apache.commons.math3.linear.DecompositionSolver; import org.apache.commons.math3.linear.LUDecomposition; import org.apache.commons.math3.linear.RealMatrix; import org.apache.commons.math3.linear.RealVector; import org.apache.commons.math3.linear.SingularValueDecomposition; +import org.apache.commons.math3.util.FastMath; +import org.apache.commons.math3.util.MathArrays; + +import java.util.AbstractMap; +import java.util.Map; public final class StatsUtils { @@ -189,4 +197,87 @@ public final class StatsUtils { return 1.d - numerator / denominator; } + /** + * @param observed means non-negative vector + * @param expected means positive vector + * @return chi2 value + */ + public static double chiSquare(@Nonnull final double[] observed, + @Nonnull final double[] expected) { + if (observed.length < 2) { + throw new DimensionMismatchException(observed.length, 2); + } + if (expected.length != observed.length) { + throw new DimensionMismatchException(observed.length, expected.length); + } + MathArrays.checkPositive(expected); + for (double d : observed) { + if (d < 0.d) { + throw new NotPositiveException(d); + } + } + + double sumObserved = 0.d; + double sumExpected = 0.d; + for (int i = 0; i < observed.length; i++) { + sumObserved += observed[i]; + sumExpected += expected[i]; + } + double ratio = 1.d; + boolean rescale = false; + if (FastMath.abs(sumObserved - sumExpected) > 10e-6) { + ratio = sumObserved / sumExpected; + rescale = true; + } + double sumSq = 0.d; + for (int i = 0; i < observed.length; i++) { + if (rescale) { + final double dev = observed[i] - ratio * expected[i]; + sumSq += dev * dev / (ratio * expected[i]); + } else { + final double dev = observed[i] - expected[i]; + sumSq += dev * dev / expected[i]; + } + } + return sumSq; + } + + /** + * @param observed means non-negative vector + * @param expected means positive vector + * @return p value + */ + public static double chiSquareTest(@Nonnull final double[] observed, + @Nonnull final double[] expected) { + final ChiSquaredDistribution distribution = new ChiSquaredDistribution( + expected.length - 1.d); + return 1.d - distribution.cumulativeProbability(chiSquare(observed, expected)); + } + + /** + * This method offers effective calculation for multiple entries rather than calculation + * individually + * + * @param observeds means non-negative matrix + * @param expecteds means positive matrix + * @return (chi2 value[], p value[]) + */ + public static Map.Entry<double[], double[]> chiSquare(@Nonnull final double[][] observeds, + @Nonnull final double[][] expecteds) { + Preconditions.checkArgument(observeds.length == expecteds.length); + + final int len = expecteds.length; + final int lenOfEach = expecteds[0].length; + + final ChiSquaredDistribution distribution = new ChiSquaredDistribution(lenOfEach - 1.d); + + final double[] chi2s = new double[len]; + final double[] ps = new double[len]; + for (int i = 0; i < len; i++) { + chi2s[i] = chiSquare(observeds[i], expecteds[i]); + ps[i] = 1.d - distribution.cumulativeProbability(chi2s[i]); + } + + return new AbstractMap.SimpleEntry<double[], double[]>(chi2s, ps); + } } http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/fad2941f/core/src/test/java/hivemall/ftvec/selection/ChiSquareUDFTest.java 
---------------------------------------------------------------------- diff --git a/core/src/test/java/hivemall/ftvec/selection/ChiSquareUDFTest.java b/core/src/test/java/hivemall/ftvec/selection/ChiSquareUDFTest.java new file mode 100644 index 0000000..fd742bb --- /dev/null +++ b/core/src/test/java/hivemall/ftvec/selection/ChiSquareUDFTest.java @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package hivemall.ftvec.selection; + +import hivemall.utils.hadoop.WritableUtils; + +import java.util.ArrayList; +import java.util.List; + +import org.apache.hadoop.hive.ql.udf.generic.GenericUDF; +import org.apache.hadoop.hive.serde2.io.DoubleWritable; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; +import org.junit.Assert; +import org.junit.Test; + +public class ChiSquareUDFTest { + + @Test + public void testIris() throws Exception { + final ChiSquareUDF chi2 = new ChiSquareUDF(); + final List<List<DoubleWritable>> observed = new ArrayList<List<DoubleWritable>>(); + final List<List<DoubleWritable>> expected = new ArrayList<List<DoubleWritable>>(); + final GenericUDF.DeferredObject[] dObjs = new GenericUDF.DeferredObject[] { + new GenericUDF.DeferredJavaObject(observed), + new GenericUDF.DeferredJavaObject(expected)}; + + final double[][] matrix0 = new double[][] { + {250.29999999999998, 170.90000000000003, 73.2, 12.199999999999996}, + {296.8, 138.50000000000003, 212.99999999999997, 66.3}, + {329.3999999999999, 148.7, 277.59999999999997, 101.29999999999998}}; + final double[][] matrix1 = new double[][] { + {292.1666753739119, 152.70000455081467, 187.93333893418327, 59.93333511948589}, + {292.1666753739119, 152.70000455081467, 187.93333893418327, 59.93333511948589}, + {292.1666753739119, 152.70000455081467, 187.93333893418327, 59.93333511948589}}; + + for (double[] row : matrix0) { + observed.add(WritableUtils.toWritableList(row)); + } + for (double[] row : matrix1) { + expected.add(WritableUtils.toWritableList(row)); + } + + chi2.initialize(new ObjectInspector[] { + ObjectInspectorFactory.getStandardListObjectInspector(ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector)), + ObjectInspectorFactory.getStandardListObjectInspector(ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector))}); + final List<DoubleWritable>[] result = chi2.evaluate(dObjs); + final double[] result0 = new double[matrix0[0].length]; + final double[] result1 = new double[matrix0[0].length]; + for (int i = 0; i < result0.length; i++) { + 
result0[i] = result[0].get(i).get(); + result1[i] = result[1].get(i).get(); + } + + // compare results to one of scikit-learn + final double[] answer0 = new double[] {10.81782088, 3.59449902, 116.16984746, 67.24482759}; + final double[] answer1 = new double[] {4.47651499e-03, 1.65754167e-01, 5.94344354e-26, + 2.50017968e-15}; + + Assert.assertArrayEquals(answer0, result0, 1e-5); + Assert.assertArrayEquals(answer1, result1, 1e-5); + chi2.close(); + } + +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/fad2941f/core/src/test/java/hivemall/ftvec/selection/SignalNoiseRatioUDAFTest.java ---------------------------------------------------------------------- diff --git a/core/src/test/java/hivemall/ftvec/selection/SignalNoiseRatioUDAFTest.java b/core/src/test/java/hivemall/ftvec/selection/SignalNoiseRatioUDAFTest.java new file mode 100644 index 0000000..79570e3 --- /dev/null +++ b/core/src/test/java/hivemall/ftvec/selection/SignalNoiseRatioUDAFTest.java @@ -0,0 +1,342 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package hivemall.ftvec.selection; + +import hivemall.utils.hadoop.WritableUtils; + +import java.util.ArrayList; +import java.util.List; + +import org.apache.hadoop.hive.ql.exec.UDFArgumentException; +import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator; +import org.apache.hadoop.hive.ql.udf.generic.SimpleGenericUDAFParameterInfo; +import org.apache.hadoop.hive.serde2.io.DoubleWritable; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; +import org.apache.hadoop.io.IntWritable; +import org.junit.Assert; +import org.junit.Test; + +public class SignalNoiseRatioUDAFTest { + + @Test + public void snrBinaryClass() throws Exception { + // this test is based on *subset* of iris data set + final SignalNoiseRatioUDAF snr = new SignalNoiseRatioUDAF(); + final ObjectInspector[] OIs = new ObjectInspector[] { + ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector), + ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableIntObjectInspector)}; + final SignalNoiseRatioUDAF.SignalNoiseRatioUDAFEvaluator evaluator = (SignalNoiseRatioUDAF.SignalNoiseRatioUDAFEvaluator) snr.getEvaluator(new SimpleGenericUDAFParameterInfo( + OIs, false, false)); + evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, OIs); + final SignalNoiseRatioUDAF.SignalNoiseRatioUDAFEvaluator.SignalNoiseRatioAggregationBuffer agg = (SignalNoiseRatioUDAF.SignalNoiseRatioUDAFEvaluator.SignalNoiseRatioAggregationBuffer) evaluator.getNewAggregationBuffer(); + evaluator.reset(agg); + + final double[][] features = new double[][] { {5.1, 3.5, 1.4, 0.2}, {4.9, 3.d, 1.4, 0.2}, + {4.7, 3.2, 1.3, 0.2}, {7.d, 3.2, 4.7, 1.4}, {6.4, 3.2, 4.5, 1.5}, + {6.9, 3.1, 4.9, 1.5}}; + + final int[][] labels = new int[][] { {1, 0}, {1, 0}, {1, 0}, {0, 1}, {0, 1}, {0, 1}}; + + for (int i = 0; i < features.length; i++) { + final List<IntWritable> labelList = new ArrayList<IntWritable>(); + for (int label : labels[i]) { + labelList.add(new IntWritable(label)); + } + evaluator.iterate(agg, new Object[] {WritableUtils.toWritableList(features[i]), + labelList}); + } + + @SuppressWarnings("unchecked") + final List<DoubleWritable> resultObj = (List<DoubleWritable>) evaluator.terminate(agg); + final int size = resultObj.size(); + final double[] result = new double[size]; + for (int i = 0; i < size; i++) { + result[i] = resultObj.get(i).get(); + } + + // compare with result by numpy + final double[] answer = new double[] {4.38425236, 0.26390002, 15.83984511, 26.87005769}; + + Assert.assertArrayEquals(answer, result, 1e-5); + } + + @Test + public void snrMultipleClassNormalCase() throws Exception { + // this test is based on *subset* of iris data set + final SignalNoiseRatioUDAF snr = new SignalNoiseRatioUDAF(); + final ObjectInspector[] OIs = new ObjectInspector[] { + ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector), + ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableIntObjectInspector)}; + final SignalNoiseRatioUDAF.SignalNoiseRatioUDAFEvaluator evaluator = (SignalNoiseRatioUDAF.SignalNoiseRatioUDAFEvaluator) snr.getEvaluator(new SimpleGenericUDAFParameterInfo( + OIs, false, false)); + evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, OIs); + final 
SignalNoiseRatioUDAF.SignalNoiseRatioUDAFEvaluator.SignalNoiseRatioAggregationBuffer agg = (SignalNoiseRatioUDAF.SignalNoiseRatioUDAFEvaluator.SignalNoiseRatioAggregationBuffer) evaluator.getNewAggregationBuffer(); + evaluator.reset(agg); + + final double[][] features = new double[][] { {5.1, 3.5, 1.4, 0.2}, {4.9, 3.d, 1.4, 0.2}, + {7.d, 3.2, 4.7, 1.4}, {6.4, 3.2, 4.5, 1.5}, {6.3, 3.3, 6.d, 2.5}, + {5.8, 2.7, 5.1, 1.9}}; + + final int[][] labels = new int[][] { {1, 0, 0}, {1, 0, 0}, {0, 1, 0}, {0, 1, 0}, {0, 0, 1}, + {0, 0, 1}}; + + for (int i = 0; i < features.length; i++) { + final List<IntWritable> labelList = new ArrayList<IntWritable>(); + for (int label : labels[i]) { + labelList.add(new IntWritable(label)); + } + evaluator.iterate(agg, new Object[] {WritableUtils.toWritableList(features[i]), + labelList}); + } + + @SuppressWarnings("unchecked") + final List<DoubleWritable> resultObj = (List<DoubleWritable>) evaluator.terminate(agg); + final int size = resultObj.size(); + final double[] result = new double[size]; + for (int i = 0; i < size; i++) { + result[i] = resultObj.get(i).get(); + } + + // compare with results from scikit-learn + final double[] answer = new double[] {8.43181818, 1.32121212, 42.94949495, 33.80952381}; + + Assert.assertArrayEquals(answer, result, 1e-5); + } + + @Test + public void snrMultipleClassCornerCase0() throws Exception { + final SignalNoiseRatioUDAF snr = new SignalNoiseRatioUDAF(); + final ObjectInspector[] OIs = new ObjectInspector[] { + ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector), + ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableIntObjectInspector)}; + final SignalNoiseRatioUDAF.SignalNoiseRatioUDAFEvaluator evaluator = (SignalNoiseRatioUDAF.SignalNoiseRatioUDAFEvaluator) snr.getEvaluator(new SimpleGenericUDAFParameterInfo( + OIs, false, false)); + evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, OIs); + final SignalNoiseRatioUDAF.SignalNoiseRatioUDAFEvaluator.SignalNoiseRatioAggregationBuffer agg = (SignalNoiseRatioUDAF.SignalNoiseRatioUDAFEvaluator.SignalNoiseRatioAggregationBuffer) evaluator.getNewAggregationBuffer(); + evaluator.reset(agg); + + // all c0[0] and c1[0] are equal + // all c1[1] and c2[1] are equal + // all c*[2] are equal + // all c*[3] are different + final double[][] features = new double[][] { {3.5, 1.4, 0.3, 5.1}, {3.5, 1.5, 0.3, 5.2}, + {3.5, 4.5, 0.3, 7.d}, {3.5, 4.5, 0.3, 6.4}, {3.3, 4.5, 0.3, 6.3}}; + + final int[][] labels = new int[][] { {1, 0, 0}, {1, 0, 0}, // class `0` + {0, 1, 0}, {0, 1, 0}, // class `1` + {0, 0, 1}}; // class `2`, only a single entry + + for (int i = 0; i < features.length; i++) { + final List<IntWritable> labelList = new ArrayList<IntWritable>(); + for (int label : labels[i]) { + labelList.add(new IntWritable(label)); + } + evaluator.iterate(agg, new Object[] {WritableUtils.toWritableList(features[i]), + labelList}); + } + + @SuppressWarnings("unchecked") + final List<DoubleWritable> resultObj = (List<DoubleWritable>) evaluator.terminate(agg); + final int size = resultObj.size(); + final double[] result = new double[size]; + for (int i = 0; i < size; i++) { + result[i] = resultObj.get(i).get(); + } + + final double[] answer = new double[] {Double.POSITIVE_INFINITY, 121.99999999999989, 0.d, + 28.761904761904734}; + + Assert.assertArrayEquals(answer, result, 1e-5); + } + + @Test + public void snrMultipleClassCornerCase1() throws Exception { + final SignalNoiseRatioUDAF snr = new
SignalNoiseRatioUDAF(); + final ObjectInspector[] OIs = new ObjectInspector[] { + ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector), + ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableIntObjectInspector)}; + final SignalNoiseRatioUDAF.SignalNoiseRatioUDAFEvaluator evaluator = (SignalNoiseRatioUDAF.SignalNoiseRatioUDAFEvaluator) snr.getEvaluator(new SimpleGenericUDAFParameterInfo( + OIs, false, false)); + evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, OIs); + final SignalNoiseRatioUDAF.SignalNoiseRatioUDAFEvaluator.SignalNoiseRatioAggregationBuffer agg = (SignalNoiseRatioUDAF.SignalNoiseRatioUDAFEvaluator.SignalNoiseRatioAggregationBuffer) evaluator.getNewAggregationBuffer(); + evaluator.reset(agg); + + final double[][] features = new double[][] { {5.1, 3.5, 1.4, 0.2}, {4.9, 3.d, 1.4, 0.2}, + {7.d, 3.2, 4.7, 1.4}, {6.3, 3.3, 6.d, 2.5}, {6.4, 3.2, 4.5, 1.5}}; + + // multiple classes have only a single entry + final int[][] labels = new int[][] { {1, 0, 0}, {1, 0, 0}, {1, 0, 0}, // class `0` + {0, 1, 0}, // class `1`, only a single entry + {0, 0, 1}}; // class `2`, only a single entry + + for (int i = 0; i < features.length; i++) { + final List<IntWritable> labelList = new ArrayList<IntWritable>(); + for (int label : labels[i]) { + labelList.add(new IntWritable(label)); + } + evaluator.iterate(agg, new Object[] {WritableUtils.toWritableList(features[i]), + labelList}); + } + + @SuppressWarnings("unchecked") + final List<DoubleWritable> resultObj = (List<DoubleWritable>) evaluator.terminate(agg); + final List<Double> result = new ArrayList<Double>(); + for (DoubleWritable dw : resultObj) { + result.add(dw.get()); + } + + Assert.assertFalse(result.contains(Double.POSITIVE_INFINITY)); + } + + @Test + public void snrMultipleClassCornerCase2() throws Exception { + final SignalNoiseRatioUDAF snr = new SignalNoiseRatioUDAF(); + final ObjectInspector[] OIs = new ObjectInspector[] { + ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector), + ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableIntObjectInspector)}; + final SignalNoiseRatioUDAF.SignalNoiseRatioUDAFEvaluator evaluator = (SignalNoiseRatioUDAF.SignalNoiseRatioUDAFEvaluator) snr.getEvaluator(new SimpleGenericUDAFParameterInfo( + OIs, false, false)); + evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, OIs); + final SignalNoiseRatioUDAF.SignalNoiseRatioUDAFEvaluator.SignalNoiseRatioAggregationBuffer agg = (SignalNoiseRatioUDAF.SignalNoiseRatioUDAFEvaluator.SignalNoiseRatioAggregationBuffer) evaluator.getNewAggregationBuffer(); + evaluator.reset(agg); + + // all [0] are equal + // all [1] are equal *within each class* + final double[][] features = new double[][] { {1.d, 1.d, 1.4, 0.2}, {1.d, 1.d, 1.4, 0.2}, + {1.d, 2.d, 4.7, 1.4}, {1.d, 2.d, 4.5, 1.5}, {1.d, 3.d, 6.d, 2.5}, + {1.d, 3.d, 5.1, 1.9}}; + + final int[][] labels = new int[][] { {1, 0, 0}, {1, 0, 0}, {0, 1, 0}, {0, 1, 0}, {0, 0, 1}, + {0, 0, 1}}; + + for (int i = 0; i < features.length; i++) { + final List<IntWritable> labelList = new ArrayList<IntWritable>(); + for (int label : labels[i]) { + labelList.add(new IntWritable(label)); + } + evaluator.iterate(agg, new Object[] {WritableUtils.toWritableList(features[i]), + labelList}); + } + + @SuppressWarnings("unchecked") + final List<DoubleWritable> resultObj = (List<DoubleWritable>) evaluator.terminate(agg); + final int size =
resultObj.size(); + final double[] result = new double[size]; + for (int i = 0; i < size; i++) { + result[i] = resultObj.get(i).get(); + } + + final double[] answer = new double[] {0.d, Double.POSITIVE_INFINITY, 42.94949495, + 33.80952381}; + + Assert.assertArrayEquals(answer, result, 1e-5); + } + + @Test(expected = UDFArgumentException.class) + public void shouldFail0() throws Exception { + final SignalNoiseRatioUDAF snr = new SignalNoiseRatioUDAF(); + + final ObjectInspector[] OIs = new ObjectInspector[] { + ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector), + ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableIntObjectInspector)}; + final SignalNoiseRatioUDAF.SignalNoiseRatioUDAFEvaluator evaluator = (SignalNoiseRatioUDAF.SignalNoiseRatioUDAFEvaluator) snr.getEvaluator(new SimpleGenericUDAFParameterInfo( + OIs, false, false)); + evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, OIs); + final SignalNoiseRatioUDAF.SignalNoiseRatioUDAFEvaluator.SignalNoiseRatioAggregationBuffer agg = (SignalNoiseRatioUDAF.SignalNoiseRatioUDAFEvaluator.SignalNoiseRatioAggregationBuffer) evaluator.getNewAggregationBuffer(); + evaluator.reset(agg); + + final double[][] featuress = new double[][] { {5.1, 3.5, 1.4, 0.2}, {4.9, 3.d, 1.4, 0.2}, + {7.d, 3.2, 4.7, 1.4}, {6.4, 3.2, 4.5, 1.5}, {6.3, 3.3, 6.d, 2.5}, + {5.8, 2.7, 5.1, 1.9}}; + + final int[][] labelss = new int[][] { {0, 0, 0}, // causes UDFArgumentException (no class label set) + {1, 0, 0}, {0, 1, 0}, {0, 1, 0}, {0, 0, 1}, {0, 0, 1}}; + + for (int i = 0; i < featuress.length; i++) { + final List<IntWritable> labels = new ArrayList<IntWritable>(); + for (int label : labelss[i]) { + labels.add(new IntWritable(label)); + } + evaluator.iterate(agg, + new Object[] {WritableUtils.toWritableList(featuress[i]), labels}); + } + } + + @Test(expected = UDFArgumentException.class) + public void shouldFail1() throws Exception { + final SignalNoiseRatioUDAF snr = new SignalNoiseRatioUDAF(); + + final ObjectInspector[] OIs = new ObjectInspector[] { + ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector), + ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableIntObjectInspector)}; + final SignalNoiseRatioUDAF.SignalNoiseRatioUDAFEvaluator evaluator = (SignalNoiseRatioUDAF.SignalNoiseRatioUDAFEvaluator) snr.getEvaluator(new SimpleGenericUDAFParameterInfo( + OIs, false, false)); + evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, OIs); + final SignalNoiseRatioUDAF.SignalNoiseRatioUDAFEvaluator.SignalNoiseRatioAggregationBuffer agg = (SignalNoiseRatioUDAF.SignalNoiseRatioUDAFEvaluator.SignalNoiseRatioAggregationBuffer) evaluator.getNewAggregationBuffer(); + evaluator.reset(agg); + + final double[][] featuress = new double[][] { {5.1, 3.5, 1.4, 0.2}, + {4.9, 3.d, 1.4}, // causes UDFArgumentException (row has only 3 features) + {7.d, 3.2, 4.7, 1.4}, {6.4, 3.2, 4.5, 1.5}, {6.3, 3.3, 6.d, 2.5}, + {5.8, 2.7, 5.1, 1.9}}; + + final int[][] labelss = new int[][] { {1, 0, 0}, {1, 0, 0}, {0, 1, 0}, {0, 1, 0}, + {0, 0, 1}, {0, 0, 1}}; + + for (int i = 0; i < featuress.length; i++) { + final List<IntWritable> labels = new ArrayList<IntWritable>(); + for (int label : labelss[i]) { + labels.add(new IntWritable(label)); + } + evaluator.iterate(agg, + new Object[] {WritableUtils.toWritableList(featuress[i]), labels}); + } + } + + @Test(expected = UDFArgumentException.class) + public void shouldFail2() throws Exception
{ + final SignalNoiseRatioUDAF snr = new SignalNoiseRatioUDAF(); + + final ObjectInspector[] OIs = new ObjectInspector[] { + ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector), + ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableIntObjectInspector)}; + final SignalNoiseRatioUDAF.SignalNoiseRatioUDAFEvaluator evaluator = (SignalNoiseRatioUDAF.SignalNoiseRatioUDAFEvaluator) snr.getEvaluator(new SimpleGenericUDAFParameterInfo( + OIs, false, false)); + evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, OIs); + final SignalNoiseRatioUDAF.SignalNoiseRatioUDAFEvaluator.SignalNoiseRatioAggregationBuffer agg = (SignalNoiseRatioUDAF.SignalNoiseRatioUDAFEvaluator.SignalNoiseRatioAggregationBuffer) evaluator.getNewAggregationBuffer(); + evaluator.reset(agg); + + final double[][] featuress = new double[][] { {5.1, 3.5, 1.4, 0.2}, {4.9, 3.d, 1.4, 0.2}, + {7.d, 3.2, 4.7, 1.4}, {6.4, 3.2, 4.5, 1.5}, {6.3, 3.3, 6.d, 2.5}, + {5.8, 2.7, 5.1, 1.9}}; + + final int[][] labelss = new int[][] { {1}, {1}, {1}, {1}, {1}, {1}}; // causes UDFArgumentException (only a single class) + + for (int i = 0; i < featuress.length; i++) { + final List<IntWritable> labels = new ArrayList<IntWritable>(); + for (int label : labelss[i]) { + labels.add(new IntWritable(label)); + } + evaluator.iterate(agg, + new Object[] {WritableUtils.toWritableList(featuress[i]), labels}); + } + } +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/fad2941f/core/src/test/java/hivemall/tools/array/SelectKBeatUDFTest.java ---------------------------------------------------------------------- diff --git a/core/src/test/java/hivemall/tools/array/SelectKBeatUDFTest.java b/core/src/test/java/hivemall/tools/array/SelectKBeatUDFTest.java new file mode 100644 index 0000000..3e3fc12 --- /dev/null +++ b/core/src/test/java/hivemall/tools/array/SelectKBeatUDFTest.java @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License.
+ */ +package hivemall.tools.array; + +import hivemall.utils.hadoop.WritableUtils; + +import java.util.List; + +import org.apache.hadoop.hive.ql.udf.generic.GenericUDF; +import org.apache.hadoop.hive.serde2.io.DoubleWritable; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; +import org.junit.Assert; +import org.junit.Test; + +public class SelectKBeatUDFTest { + + @Test + public void test() throws Exception { + final SelectKBestUDF selectKBest = new SelectKBestUDF(); + final int k = 2; + final double[] data = new double[] {250.29999999999998, 170.90000000000003, 73.2, + 12.199999999999996}; + final double[] importanceList = new double[] {292.1666753739119, 152.70000455081467, + 187.93333893418327, 59.93333511948589}; + + final GenericUDF.DeferredObject[] dObjs = new GenericUDF.DeferredObject[] { + new GenericUDF.DeferredJavaObject(WritableUtils.toWritableList(data)), + new GenericUDF.DeferredJavaObject(WritableUtils.toWritableList(importanceList)), + new GenericUDF.DeferredJavaObject(k)}; + + selectKBest.initialize(new ObjectInspector[] { + ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector), + ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector), + ObjectInspectorUtils.getConstantObjectInspector( + PrimitiveObjectInspectorFactory.javaIntObjectInspector, k)}); + @SuppressWarnings("unchecked") + final List<DoubleWritable> resultObj = (List<DoubleWritable>) selectKBest.evaluate(dObjs); + + Assert.assertEquals(k, resultObj.size()); + + final double[] result = new double[k]; + for (int i = 0; i < k; i++) { + result[i] = resultObj.get(i).get(); + } + + final double[] answer = new double[] {250.29999999999998, 73.2}; + + Assert.assertArrayEquals(answer, result, 0.d); + selectKBest.close(); + } +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/fad2941f/core/src/test/java/hivemall/tools/matrix/TransposeAndDotUDAFTest.java ---------------------------------------------------------------------- diff --git a/core/src/test/java/hivemall/tools/matrix/TransposeAndDotUDAFTest.java b/core/src/test/java/hivemall/tools/matrix/TransposeAndDotUDAFTest.java new file mode 100644 index 0000000..f705a89 --- /dev/null +++ b/core/src/test/java/hivemall/tools/matrix/TransposeAndDotUDAFTest.java @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License.
+ */ +package hivemall.tools.matrix; + +import hivemall.utils.hadoop.WritableUtils; + +import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator; +import org.apache.hadoop.hive.ql.udf.generic.SimpleGenericUDAFParameterInfo; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; +import org.junit.Assert; +import org.junit.Test; + +public class TransposeAndDotUDAFTest { + + @Test + public void test() throws Exception { + final TransposeAndDotUDAF tad = new TransposeAndDotUDAF(); + + final double[][] matrix0 = new double[][] { {1, -2}, {-1, 3}}; + final double[][] matrix1 = new double[][] { {1, 2}, {3, 4}}; + + final ObjectInspector[] OIs = new ObjectInspector[] { + ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector), + ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector)}; + final GenericUDAFEvaluator evaluator = tad.getEvaluator(new SimpleGenericUDAFParameterInfo( + OIs, false, false)); + evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, OIs); + TransposeAndDotUDAF.TransposeAndDotUDAFEvaluator.TransposeAndDotAggregationBuffer agg = (TransposeAndDotUDAF.TransposeAndDotUDAFEvaluator.TransposeAndDotAggregationBuffer) evaluator.getNewAggregationBuffer(); + evaluator.reset(agg); + for (int i = 0; i < matrix0.length; i++) { + evaluator.iterate(agg, new Object[] {WritableUtils.toWritableList(matrix0[i]), + WritableUtils.toWritableList(matrix1[i])}); + } + + final double[][] answer = new double[][] { {-2.0, -2.0}, {7.0, 8.0}}; + + for (int i = 0; i < answer.length; i++) { + Assert.assertArrayEquals(answer[i], agg.aggMatrix[i], 0.d); + } + } +}
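----------------------------------------------------------------------

For reference, the sketch below shows how two of the functions exercised by the tests above, chi2 and select_k_best, could be driven end-to-end from plain Java, reusing the ObjectInspector conventions of the test code. It is a minimal, hypothetical example and not part of the commit: the class name FeatureSelectionSketch, the toy observed/expected matrices, and the assumption that chi2's struct result materializes as an Object[] holding two List<DoubleWritable>s (as ChiSquareUDFTest suggests) are illustrative only.

package hivemall.example;

import hivemall.ftvec.selection.ChiSquareUDF;
import hivemall.tools.array.SelectKBestUDF;
import hivemall.utils.hadoop.WritableUtils;

import java.util.Arrays;
import java.util.List;

import org.apache.hadoop.hive.ql.udf.generic.GenericUDF;
import org.apache.hadoop.hive.serde2.io.DoubleWritable;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;

public final class FeatureSelectionSketch {

    public static void main(String[] args) throws Exception {
        // Toy 2x2 contingency counts; row/column orientation follows the
        // layout documented in ChiSquareUDF. Rows of `expected` are chosen to
        // have the same sums as `observed`, as a valid chi2 input requires.
        final double[][] observed = { {10.d, 20.d}, {30.d, 40.d}};
        final double[][] expected = { {15.d, 15.d}, {25.d, 45.d}};

        // ObjectInspector for array<array<double>>, as in the tests above.
        final ObjectInspector matrixOI =
                ObjectInspectorFactory.getStandardListObjectInspector(
                    ObjectInspectorFactory.getStandardListObjectInspector(
                        PrimitiveObjectInspectorFactory.writableDoubleObjectInspector));

        final ChiSquareUDF chi2 = new ChiSquareUDF();
        chi2.initialize(new ObjectInspector[] {matrixOI, matrixOI});

        final Object chi2Result = chi2.evaluate(new GenericUDF.DeferredObject[] {
                new GenericUDF.DeferredJavaObject(toMatrix(observed)),
                new GenericUDF.DeferredJavaObject(toMatrix(expected))});

        // chi2 returns <array<double> chi2_vals, array<double> p_vals>; assumed
        // here to arrive as an Object[] of two lists, as in ChiSquareUDFTest.
        @SuppressWarnings("unchecked")
        final List<DoubleWritable> chi2Vals = (List<DoubleWritable>) ((Object[]) chi2Result)[0];

        final double[] importance = new double[chi2Vals.size()];
        for (int i = 0; i < importance.length; i++) {
            importance[i] = chi2Vals.get(i).get();
        }

        // Keep the k feature values with the highest chi2 scores.
        final int k = 1;
        final SelectKBestUDF selectKBest = new SelectKBestUDF();
        selectKBest.initialize(new ObjectInspector[] {
                ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector),
                ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector),
                ObjectInspectorUtils.getConstantObjectInspector(
                    PrimitiveObjectInspectorFactory.javaIntObjectInspector, k)});

        final double[] featureRow = {1.d, 2.d}; // one hypothetical feature row
        final Object selected = selectKBest.evaluate(new GenericUDF.DeferredObject[] {
                new GenericUDF.DeferredJavaObject(WritableUtils.toWritableList(featureRow)),
                new GenericUDF.DeferredJavaObject(WritableUtils.toWritableList(importance)),
                new GenericUDF.DeferredJavaObject(k)});

        System.out.println(selected); // the k best-scoring feature values

        chi2.close();
        selectKBest.close();
    }

    // Builds the List<List<DoubleWritable>> form of array<array<double>>.
    private static List<List<DoubleWritable>> toMatrix(double[][] rows) {
        return Arrays.asList(
            WritableUtils.toWritableList(rows[0]),
            WritableUtils.toWritableList(rows[1]));
    }
}

In a production Hive session the same pipeline would instead be expressed in HiveQL through the chi2 and select_k_best functions registered by the DDL files in this commit; the Java form is shown only because it mirrors the surrounding test harnesses.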