Modify SNR computation to handle corner cases
Project: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/commit/4cfa4e5a Tree: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/tree/4cfa4e5a Diff: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/diff/4cfa4e5a Branch: refs/heads/JIRA-22/pr-385 Commit: 4cfa4e5ac15a6535b187c23616c205696a1cd13b Parents: 8e2842c Author: amaya <g...@sapphire.in.net> Authored: Wed Sep 28 18:26:01 2016 +0900 Committer: amaya <g...@sapphire.in.net> Committed: Wed Sep 28 18:29:28 2016 +0900 ---------------------------------------------------------------------- .../ftvec/selection/SignalNoiseRatioUDAF.java | 48 +++++-- .../selection/SignalNoiseRatioUDAFTest.java | 135 ++++++++++++++++++- 2 files changed, 167 insertions(+), 16 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/4cfa4e5a/core/src/main/java/hivemall/ftvec/selection/SignalNoiseRatioUDAF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/ftvec/selection/SignalNoiseRatioUDAF.java b/core/src/main/java/hivemall/ftvec/selection/SignalNoiseRatioUDAF.java index b7b9126..507aefa 100644 --- a/core/src/main/java/hivemall/ftvec/selection/SignalNoiseRatioUDAF.java +++ b/core/src/main/java/hivemall/ftvec/selection/SignalNoiseRatioUDAF.java @@ -21,7 +21,6 @@ package hivemall.ftvec.selection; import hivemall.utils.hadoop.HiveUtils; import hivemall.utils.hadoop.WritableUtils; import hivemall.utils.lang.Preconditions; -import org.apache.commons.math3.util.FastMath; import org.apache.hadoop.hive.ql.exec.Description; import org.apache.hadoop.hive.ql.exec.UDFArgumentException; import org.apache.hadoop.hive.ql.exec.UDFArgumentLengthException; @@ -193,7 +192,7 @@ public class SignalNoiseRatioUDAF extends AbstractGenericUDAFResolver { int clazz = -1; for (int i = 0; i < nClasses; i++) { - int 
label = PrimitiveObjectInspectorUtils.getInt(labels.get(i), labelOI); + final int label = PrimitiveObjectInspectorUtils.getInt(labels.get(i), labelOI); if (label == 1 && clazz == -1) { clazz = i; } else if (label == 1) { @@ -255,6 +254,12 @@ public class SignalNoiseRatioUDAF extends AbstractGenericUDAFResolver { for (int i = 0; i < nClasses; i++) { final long n = myAgg.ns[i]; final long m = PrimitiveObjectInspectorUtils.getLong(ns.get(i), nOI); + + // no need to merge class `i` + if (m == 0) { + continue; + } + final List means = meansOI.getList(meanss.get(i)); final List variances = variancesOI.getList(variancess.get(i)); @@ -266,10 +271,19 @@ public class SignalNoiseRatioUDAF extends AbstractGenericUDAFResolver { final double varianceN = myAgg.variancess[i][j]; final double varianceM = PrimitiveObjectInspectorUtils.getDouble( variances.get(j), varianceOI); - myAgg.meanss[i][j] = (n * meanN + m * meanM) / (double) (n + m); - myAgg.variancess[i][j] = (varianceN * (n - 1) + varianceM * (m - 1) + FastMath.pow( - meanN - meanM, 2) * n * m / (n + m)) - / (n + m - 1); + + if (n == 0) { + // only assign `other` into `myAgg` + myAgg.meanss[i][j] = meanM; + myAgg.variancess[i][j] = varianceM; + } else { + // merge by Chan's method + // http://i.stanford.edu/pub/cstr/reports/cs/tr/79/773/CS-TR-79-773.pdf + myAgg.meanss[i][j] = (n * meanN + m * meanM) / (double) (n + m); + myAgg.variancess[i][j] = (varianceN * (n - 1) + varianceM * (m - 1) + Math.pow( + meanN - meanM, 2) * n * m / (n + m)) + / (n + m - 1); + } } } } @@ -302,25 +316,33 @@ public class SignalNoiseRatioUDAF extends AbstractGenericUDAFResolver { // calc SNR between classes each feature final double[] result = new double[nFeatures]; - final double[] sds = new double[nClasses]; // memo + final double[] sds = new double[nClasses]; // for memorization for (int i = 0; i < nFeatures; i++) { - sds[0] = FastMath.sqrt(myAgg.variancess[0][i]); + sds[0] = Math.sqrt(myAgg.variancess[0][i]); for (int j = 1; j < nClasses; 
j++) { - sds[j] = FastMath.sqrt(myAgg.variancess[j][i]); - if (Double.isNaN(sds[j])) { + sds[j] = Math.sqrt(myAgg.variancess[j][i]); + // `ns[j] == 0` means no feature entry belongs to class `j`, skip + if (myAgg.ns[j] == 0) { continue; } for (int k = 0; k < j; k++) { - if (Double.isNaN(sds[k])) { + // avoid comparing between classes having only single entry + if (myAgg.ns[k] == 0 || (myAgg.ns[j] == 1 && myAgg.ns[k] == 1)) { continue; } - result[i] += FastMath.abs(myAgg.meanss[j][i] - myAgg.meanss[k][i]) + + // SUM(snr) GROUP BY feature + final double snr = Math.abs(myAgg.meanss[j][i] - myAgg.meanss[k][i]) / (sds[j] + sds[k]); + // if `NaN`(when diff between means and both sds are zero, IOW, all related values are equal), + // regard feature `i` as meaningless between class `j` and `k` and skip + if (!Double.isNaN(snr)) { + result[i] += snr; // accept `Infinity` + } } } } - // SUM(snr) GROUP BY feature return WritableUtils.toWritableList(result); } } http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/4cfa4e5a/core/src/test/java/hivemall/ftvec/selection/SignalNoiseRatioUDAFTest.java ---------------------------------------------------------------------- diff --git a/core/src/test/java/hivemall/ftvec/selection/SignalNoiseRatioUDAFTest.java b/core/src/test/java/hivemall/ftvec/selection/SignalNoiseRatioUDAFTest.java index 56a01d0..a4744d9 100644 --- a/core/src/test/java/hivemall/ftvec/selection/SignalNoiseRatioUDAFTest.java +++ b/core/src/test/java/hivemall/ftvec/selection/SignalNoiseRatioUDAFTest.java @@ -68,7 +68,7 @@ public class SignalNoiseRatioUDAFTest { } @SuppressWarnings("unchecked") - final List<DoubleWritable> resultObj = (ArrayList<DoubleWritable>) evaluator.terminate(agg); + final List<DoubleWritable> resultObj = (List<DoubleWritable>) evaluator.terminate(agg); final int size = resultObj.size(); final double[] result = new double[size]; for (int i = 0; i < size; i++) { @@ -82,7 +82,7 @@ public class SignalNoiseRatioUDAFTest { } @Test - public 
void snrMultipleClass() throws Exception { + public void snrMultipleClassNormalCase() throws Exception { // this test is based on *subset* of iris data set final SignalNoiseRatioUDAF snr = new SignalNoiseRatioUDAF(); final ObjectInspector[] OIs = new ObjectInspector[] { @@ -111,7 +111,7 @@ public class SignalNoiseRatioUDAFTest { } @SuppressWarnings("unchecked") - final List<DoubleWritable> resultObj = (ArrayList<DoubleWritable>) evaluator.terminate(agg); + final List<DoubleWritable> resultObj = (List<DoubleWritable>) evaluator.terminate(agg); final int size = resultObj.size(); final double[] result = new double[size]; for (int i = 0; i < size; i++) { @@ -125,6 +125,135 @@ public class SignalNoiseRatioUDAFTest { } @Test + public void snrMultipleClassCornerCase0() throws Exception { + final SignalNoiseRatioUDAF snr = new SignalNoiseRatioUDAF(); + final ObjectInspector[] OIs = new ObjectInspector[] { + ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector), + ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableIntObjectInspector)}; + final SignalNoiseRatioUDAF.SignalNoiseRatioUDAFEvaluator evaluator = (SignalNoiseRatioUDAF.SignalNoiseRatioUDAFEvaluator) snr.getEvaluator(new SimpleGenericUDAFParameterInfo( + OIs, false, false)); + evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, OIs); + final SignalNoiseRatioUDAF.SignalNoiseRatioUDAFEvaluator.SignalNoiseRatioAggregationBuffer agg = (SignalNoiseRatioUDAF.SignalNoiseRatioUDAFEvaluator.SignalNoiseRatioAggregationBuffer) evaluator.getNewAggregationBuffer(); + evaluator.reset(agg); + + // all c0[0] and c1[0] are equal + // all c1[1] and c2[1] are equal + // all c*[2] are equal + // all c*[3] are different + final double[][] features = new double[][] { {3.5, 1.4, 0.3, 5.1}, {3.5, 1.5, 0.3, 5.2}, + {3.5, 4.5, 0.3, 7.d}, {3.5, 4.5, 0.3, 6.4}, {3.3, 4.5, 0.3, 6.3}}; + + final int[][] labels = new int[][] { {1, 0, 0}, {1, 0, 
0}, // class `0` + {0, 1, 0}, {0, 1, 0}, // class `1` + {0, 0, 1}}; // class `2`, only single entry + + for (int i = 0; i < features.length; i++) { + final List<IntWritable> labelList = new ArrayList<IntWritable>(); + for (int label : labels[i]) { + labelList.add(new IntWritable(label)); + } + evaluator.iterate(agg, new Object[] {WritableUtils.toWritableList(features[i]), + labelList}); + } + + @SuppressWarnings("unchecked") + final List<DoubleWritable> resultObj = (List<DoubleWritable>) evaluator.terminate(agg); + final int size = resultObj.size(); + final double[] result = new double[size]; + for (int i = 0; i < size; i++) { + result[i] = resultObj.get(i).get(); + } + + final double[] answer = new double[] {Double.POSITIVE_INFINITY, 121.99999999999989, 0.d, + 28.761904761904734}; + + Assert.assertArrayEquals(answer, result, 1e-5); + } + + @Test + public void snrMultipleClassCornerCase1() throws Exception { + final SignalNoiseRatioUDAF snr = new SignalNoiseRatioUDAF(); + final ObjectInspector[] OIs = new ObjectInspector[] { + ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector), + ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableIntObjectInspector)}; + final SignalNoiseRatioUDAF.SignalNoiseRatioUDAFEvaluator evaluator = (SignalNoiseRatioUDAF.SignalNoiseRatioUDAFEvaluator) snr.getEvaluator(new SimpleGenericUDAFParameterInfo( + OIs, false, false)); + evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, OIs); + final SignalNoiseRatioUDAF.SignalNoiseRatioUDAFEvaluator.SignalNoiseRatioAggregationBuffer agg = (SignalNoiseRatioUDAF.SignalNoiseRatioUDAFEvaluator.SignalNoiseRatioAggregationBuffer) evaluator.getNewAggregationBuffer(); + evaluator.reset(agg); + + final double[][] features = new double[][] { {5.1, 3.5, 1.4, 0.2}, {4.9, 3.d, 1.4, 0.2}, + {7.d, 3.2, 4.7, 1.4}, {6.3, 3.3, 6.d, 2.5}, {6.4, 3.2, 4.5, 1.5}}; + + // has multiple single entries + final int[][] 
labels = new int[][] { {1, 0, 0}, {1, 0, 0}, {1, 0, 0}, // class `0` + {0, 1, 0}, // class `1`, only single entry + {0, 0, 1}}; // class `2`, only single entry + + for (int i = 0; i < features.length; i++) { + final List<IntWritable> labelList = new ArrayList<IntWritable>(); + for (int label : labels[i]) { + labelList.add(new IntWritable(label)); + } + evaluator.iterate(agg, new Object[] {WritableUtils.toWritableList(features[i]), + labelList}); + } + + @SuppressWarnings("unchecked") + final List<DoubleWritable> resultObj = (List<DoubleWritable>) evaluator.terminate(agg); + final List<Double> result = new ArrayList<Double>(); + for (DoubleWritable dw : resultObj) { + result.add(dw.get()); + } + + Assert.assertFalse(result.contains(Double.POSITIVE_INFINITY)); + } + + @Test + public void snrMultipleClassCornerCase2() throws Exception { + final SignalNoiseRatioUDAF snr = new SignalNoiseRatioUDAF(); + final ObjectInspector[] OIs = new ObjectInspector[] { + ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector), + ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableIntObjectInspector)}; + final SignalNoiseRatioUDAF.SignalNoiseRatioUDAFEvaluator evaluator = (SignalNoiseRatioUDAF.SignalNoiseRatioUDAFEvaluator) snr.getEvaluator(new SimpleGenericUDAFParameterInfo( + OIs, false, false)); + evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, OIs); + final SignalNoiseRatioUDAF.SignalNoiseRatioUDAFEvaluator.SignalNoiseRatioAggregationBuffer agg = (SignalNoiseRatioUDAF.SignalNoiseRatioUDAFEvaluator.SignalNoiseRatioAggregationBuffer) evaluator.getNewAggregationBuffer(); + evaluator.reset(agg); + + // all [0] are equal + // all [1] are equal *each class* + final double[][] features = new double[][] { {1.d, 1.d, 1.4, 0.2}, {1.d, 1.d, 1.4, 0.2}, + {1.d, 2.d, 4.7, 1.4}, {1.d, 2.d, 4.5, 1.5}, {1.d, 3.d, 6.d, 2.5}, + {1.d, 3.d, 5.1, 1.9}}; + + final int[][] labels = new int[][] { {1, 
0, 0}, {1, 0, 0}, {0, 1, 0}, {0, 1, 0}, {0, 0, 1}, + {0, 0, 1}}; + + for (int i = 0; i < features.length; i++) { + final List<IntWritable> labelList = new ArrayList<IntWritable>(); + for (int label : labels[i]) { + labelList.add(new IntWritable(label)); + } + evaluator.iterate(agg, new Object[] {WritableUtils.toWritableList(features[i]), + labelList}); + } + + @SuppressWarnings("unchecked") + final List<DoubleWritable> resultObj = (List<DoubleWritable>) evaluator.terminate(agg); + final int size = resultObj.size(); + final double[] result = new double[size]; + for (int i = 0; i < size; i++) { + result[i] = resultObj.get(i).get(); + } + + final double[] answer = new double[] {0.d, Double.POSITIVE_INFINITY, 42.94949495, + 33.80952381}; + + Assert.assertArrayEquals(answer, result, 1e-5); + } + + @Test public void shouldFail0() throws Exception { expectedException.expect(UDFArgumentException.class);