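Refine feature-selection tests: check ChiSquareUDF and SignalNoiseRatioUDAF against scikit-learn/numpy reference values with a 1e-5 tolerance, and add a binary-class SNR test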
Project: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/commit/8e2842cf Tree: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/tree/8e2842cf Diff: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/diff/8e2842cf Branch: refs/heads/JIRA-22/pr-385 Commit: 8e2842cf8c272642feaa76bf95e8fa463b0322dc Parents: 1347de9 Author: amaya <g...@sapphire.in.net> Authored: Wed Sep 28 14:24:19 2016 +0900 Committer: amaya <g...@sapphire.in.net> Committed: Wed Sep 28 14:24:19 2016 +0900 ---------------------------------------------------------------------- .../ftvec/selection/ChiSquareUDFTest.java | 12 ++-- .../selection/SignalNoiseRatioUDAFTest.java | 71 ++++++++++++++++---- 2 files changed, 64 insertions(+), 19 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8e2842cf/core/src/test/java/hivemall/ftvec/selection/ChiSquareUDFTest.java ---------------------------------------------------------------------- diff --git a/core/src/test/java/hivemall/ftvec/selection/ChiSquareUDFTest.java b/core/src/test/java/hivemall/ftvec/selection/ChiSquareUDFTest.java index 38f7f57..d5880b8 100644 --- a/core/src/test/java/hivemall/ftvec/selection/ChiSquareUDFTest.java +++ b/core/src/test/java/hivemall/ftvec/selection/ChiSquareUDFTest.java @@ -69,12 +69,12 @@ public class ChiSquareUDFTest { result1[i] = Double.valueOf(((List) result[1]).get(i).toString()); } - final double[] answer0 = new double[] {10.817820878493995, 3.5944990176817315, - 116.16984746363957, 67.24482558215503}; - final double[] answer1 = new double[] {0.004476514990225833, 0.16575416718561453, 0.d, - 2.55351295663786e-15}; + // compare with results by scikit-learn + final double[] answer0 = new double[] {10.81782088, 3.59449902, 116.16984746, 67.24482759}; + final double[] answer1 = new double[] {4.47651499e-03, 1.65754167e-01, 5.94344354e-26, + 
2.50017968e-15}; - Assert.assertArrayEquals(answer0, result0, 0.d); - Assert.assertArrayEquals(answer1, result1, 0.d); + Assert.assertArrayEquals(answer0, result0, 1e-5); + Assert.assertArrayEquals(answer1, result1, 1e-5); } } http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8e2842cf/core/src/test/java/hivemall/ftvec/selection/SignalNoiseRatioUDAFTest.java ---------------------------------------------------------------------- diff --git a/core/src/test/java/hivemall/ftvec/selection/SignalNoiseRatioUDAFTest.java b/core/src/test/java/hivemall/ftvec/selection/SignalNoiseRatioUDAFTest.java index 4655545..56a01d0 100644 --- a/core/src/test/java/hivemall/ftvec/selection/SignalNoiseRatioUDAFTest.java +++ b/core/src/test/java/hivemall/ftvec/selection/SignalNoiseRatioUDAFTest.java @@ -40,7 +40,8 @@ public class SignalNoiseRatioUDAFTest { public ExpectedException expectedException = ExpectedException.none(); @Test - public void test() throws Exception { + public void snrBinaryClass() throws Exception { + // this test is based on *subset* of iris data set final SignalNoiseRatioUDAF snr = new SignalNoiseRatioUDAF(); final ObjectInspector[] OIs = new ObjectInspector[] { ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector), @@ -51,20 +52,62 @@ public class SignalNoiseRatioUDAFTest { final SignalNoiseRatioUDAF.SignalNoiseRatioUDAFEvaluator.SignalNoiseRatioAggregationBuffer agg = (SignalNoiseRatioUDAF.SignalNoiseRatioUDAFEvaluator.SignalNoiseRatioAggregationBuffer) evaluator.getNewAggregationBuffer(); evaluator.reset(agg); - final double[][] featuress = new double[][] { {5.1, 3.5, 1.4, 0.2}, {4.9, 3.d, 1.4, 0.2}, + final double[][] features = new double[][] { {5.1, 3.5, 1.4, 0.2}, {4.9, 3.d, 1.4, 0.2}, + {4.7, 3.2, 1.3, 0.2}, {7.d, 3.2, 4.7, 1.4}, {6.4, 3.2, 4.5, 1.5}, + {6.9, 3.1, 4.9, 1.5}}; + + final int[][] labels = new int[][] { {1, 0}, {1, 0}, {1, 0}, {0, 1}, {0, 1}, {0, 1}}; + + for (int i = 0; i 
< features.length; i++) { + final List<IntWritable> labelList = new ArrayList<IntWritable>(); + for (int label : labels[i]) { + labelList.add(new IntWritable(label)); + } + evaluator.iterate(agg, new Object[] {WritableUtils.toWritableList(features[i]), + labelList}); + } + + @SuppressWarnings("unchecked") + final List<DoubleWritable> resultObj = (ArrayList<DoubleWritable>) evaluator.terminate(agg); + final int size = resultObj.size(); + final double[] result = new double[size]; + for (int i = 0; i < size; i++) { + result[i] = resultObj.get(i).get(); + } + + // compare with result by numpy + final double[] answer = new double[] {4.38425236, 0.26390002, 15.83984511, 26.87005769}; + + Assert.assertArrayEquals(answer, result, 1e-5); + } + + @Test + public void snrMultipleClass() throws Exception { + // this test is based on *subset* of iris data set + final SignalNoiseRatioUDAF snr = new SignalNoiseRatioUDAF(); + final ObjectInspector[] OIs = new ObjectInspector[] { + ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector), + ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableIntObjectInspector)}; + final SignalNoiseRatioUDAF.SignalNoiseRatioUDAFEvaluator evaluator = (SignalNoiseRatioUDAF.SignalNoiseRatioUDAFEvaluator) snr.getEvaluator(new SimpleGenericUDAFParameterInfo( + OIs, false, false)); + evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, OIs); + final SignalNoiseRatioUDAF.SignalNoiseRatioUDAFEvaluator.SignalNoiseRatioAggregationBuffer agg = (SignalNoiseRatioUDAF.SignalNoiseRatioUDAFEvaluator.SignalNoiseRatioAggregationBuffer) evaluator.getNewAggregationBuffer(); + evaluator.reset(agg); + + final double[][] features = new double[][] { {5.1, 3.5, 1.4, 0.2}, {4.9, 3.d, 1.4, 0.2}, {7.d, 3.2, 4.7, 1.4}, {6.4, 3.2, 4.5, 1.5}, {6.3, 3.3, 6.d, 2.5}, {5.8, 2.7, 5.1, 1.9}}; - final int[][] labelss = new int[][] { {1, 0, 0}, {1, 0, 0}, {0, 1, 0}, {0, 1, 0}, - {0, 0, 1}, 
{0, 0, 1}}; + final int[][] labels = new int[][] { {1, 0, 0}, {1, 0, 0}, {0, 1, 0}, {0, 1, 0}, {0, 0, 1}, + {0, 0, 1}}; - for (int i = 0; i < featuress.length; i++) { - final List<IntWritable> labels = new ArrayList<IntWritable>(); - for (int label : labelss[i]) { - labels.add(new IntWritable(label)); + for (int i = 0; i < features.length; i++) { + final List<IntWritable> labelList = new ArrayList<IntWritable>(); + for (int label : labels[i]) { + labelList.add(new IntWritable(label)); } - evaluator.iterate(agg, - new Object[] {WritableUtils.toWritableList(featuress[i]), labels}); + evaluator.iterate(agg, new Object[] {WritableUtils.toWritableList(features[i]), + labelList}); } @SuppressWarnings("unchecked") @@ -74,9 +117,11 @@ public class SignalNoiseRatioUDAFTest { for (int i = 0; i < size; i++) { result[i] = resultObj.get(i).get(); } - final double[] answer = new double[] {8.431818181818192, 1.3212121212121217, - 42.94949494949499, 33.80952380952378}; - Assert.assertArrayEquals(answer, result, 0.d); + + // compare with result by scikit-learn + final double[] answer = new double[] {8.43181818, 1.32121212, 42.94949495, 33.80952381}; + + Assert.assertArrayEquals(answer, result, 1e-5); } @Test