http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/test/java/hivemall/smile/classification/RandomForestClassifierUDTFTest.java ---------------------------------------------------------------------- diff --git a/core/src/test/java/hivemall/smile/classification/RandomForestClassifierUDTFTest.java b/core/src/test/java/hivemall/smile/classification/RandomForestClassifierUDTFTest.java index 5f8518b..d682093 100644 --- a/core/src/test/java/hivemall/smile/classification/RandomForestClassifierUDTFTest.java +++ b/core/src/test/java/hivemall/smile/classification/RandomForestClassifierUDTFTest.java @@ -18,15 +18,23 @@ */ package hivemall.smile.classification; +import hivemall.classifier.KernelExpansionPassiveAggressiveUDTF; +import hivemall.utils.codec.Base91; import hivemall.utils.lang.mutable.MutableInt; import java.io.BufferedInputStream; +import java.io.BufferedReader; import java.io.IOException; import java.io.InputStream; +import java.io.InputStreamReader; import java.net.URL; import java.text.ParseException; import java.util.ArrayList; import java.util.List; +import java.util.StringTokenizer; +import java.util.zip.GZIPInputStream; + +import javax.annotation.Nonnull; import org.apache.hadoop.hive.ql.metadata.HiveException; import org.apache.hadoop.hive.ql.udf.generic.Collector; @@ -34,6 +42,8 @@ import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils; import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; +import org.apache.hadoop.io.IntWritable; +import org.apache.hadoop.io.Text; import org.junit.Assert; import org.junit.Test; @@ -43,7 +53,7 @@ import smile.data.parser.ArffParser; public class RandomForestClassifierUDTFTest { @Test - public void testIris() throws IOException, ParseException, HiveException { + public void testIrisDense() 
throws IOException, ParseException, HiveException { URL url = new URL( "https://gist.githubusercontent.com/myui/143fa9d05bd6e7db0114/raw/500f178316b802f1cade6e3bf8dc814a96e84b1e/iris.arff"); InputStream is = new BufferedInputStream(url.openStream()); @@ -85,4 +95,278 @@ public class RandomForestClassifierUDTFTest { Assert.assertEquals(49, count.getValue()); } + @Test + public void testIrisSparse() throws IOException, ParseException, HiveException { + URL url = new URL( + "https://gist.githubusercontent.com/myui/143fa9d05bd6e7db0114/raw/500f178316b802f1cade6e3bf8dc814a96e84b1e/iris.arff"); + InputStream is = new BufferedInputStream(url.openStream()); + + ArffParser arffParser = new ArffParser(); + arffParser.setResponseIndex(4); + + AttributeDataset iris = arffParser.parse(is); + int size = iris.size(); + double[][] x = iris.toArray(new double[size][]); + int[] y = iris.toArray(new int[size]); + + RandomForestClassifierUDTF udtf = new RandomForestClassifierUDTF(); + ObjectInspector param = ObjectInspectorUtils.getConstantObjectInspector( + PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-trees 49"); + udtf.initialize(new ObjectInspector[] { + ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.javaStringObjectInspector), + PrimitiveObjectInspectorFactory.javaIntObjectInspector, param}); + + final List<String> xi = new ArrayList<String>(x[0].length); + for (int i = 0; i < size; i++) { + double[] row = x[i]; + for (int j = 0; j < row.length; j++) { + xi.add(j + ":" + row[j]); + } + udtf.process(new Object[] {xi, y[i]}); + xi.clear(); + } + + final MutableInt count = new MutableInt(0); + Collector collector = new Collector() { + public void collect(Object input) throws HiveException { + count.addValue(1); + } + }; + + udtf.setCollector(collector); + udtf.close(); + + Assert.assertEquals(49, count.getValue()); + } + + @Test + public void testIrisSparseDenseEquals() throws IOException, ParseException, HiveException { + String 
urlString = "https://gist.githubusercontent.com/myui/143fa9d05bd6e7db0114/raw/500f178316b802f1cade6e3bf8dc814a96e84b1e/iris.arff"; + DecisionTree.Node denseNode = getDecisionTreeFromDenseInput(urlString); + DecisionTree.Node sparseNode = getDecisionTreeFromSparseInput(urlString); + + URL url = new URL(urlString); + InputStream is = new BufferedInputStream(url.openStream()); + ArffParser arffParser = new ArffParser(); + arffParser.setResponseIndex(4); + + AttributeDataset iris = arffParser.parse(is); + int size = iris.size(); + double[][] x = iris.toArray(new double[size][]); + + int diff = 0; + for (int i = 0; i < size; i++) { + if (denseNode.predict(x[i]) != sparseNode.predict(x[i])) { + diff++; + } + } + + Assert.assertTrue("large diff " + diff + " between two predictions", diff < 10); + } + + private static DecisionTree.Node getDecisionTreeFromDenseInput(String urlString) + throws IOException, ParseException, HiveException { + URL url = new URL(urlString); + InputStream is = new BufferedInputStream(url.openStream()); + + ArffParser arffParser = new ArffParser(); + arffParser.setResponseIndex(4); + + AttributeDataset iris = arffParser.parse(is); + int size = iris.size(); + double[][] x = iris.toArray(new double[size][]); + int[] y = iris.toArray(new int[size]); + + RandomForestClassifierUDTF udtf = new RandomForestClassifierUDTF(); + ObjectInspector param = ObjectInspectorUtils.getConstantObjectInspector( + PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-trees 1 -seed 71"); + udtf.initialize(new ObjectInspector[] { + ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.javaDoubleObjectInspector), + PrimitiveObjectInspectorFactory.javaIntObjectInspector, param}); + + final List<Double> xi = new ArrayList<Double>(x[0].length); + for (int i = 0; i < size; i++) { + for (int j = 0; j < x[i].length; j++) { + xi.add(j, x[i][j]); + } + udtf.process(new Object[] {xi, y[i]}); + xi.clear(); + } + + final Text[] placeholder = new 
Text[1]; + Collector collector = new Collector() { + public void collect(Object input) throws HiveException { + Object[] forward = (Object[]) input; + placeholder[0] = (Text) forward[2]; + } + }; + + udtf.setCollector(collector); + udtf.close(); + + Text modelTxt = placeholder[0]; + Assert.assertNotNull(modelTxt); + + byte[] b = Base91.decode(modelTxt.getBytes(), 0, modelTxt.getLength()); + DecisionTree.Node node = DecisionTree.deserializeNode(b, b.length, true); + return node; + } + + private static DecisionTree.Node getDecisionTreeFromSparseInput(String urlString) + throws IOException, ParseException, HiveException { + URL url = new URL(urlString); + InputStream is = new BufferedInputStream(url.openStream()); + + ArffParser arffParser = new ArffParser(); + arffParser.setResponseIndex(4); + + AttributeDataset iris = arffParser.parse(is); + int size = iris.size(); + double[][] x = iris.toArray(new double[size][]); + int[] y = iris.toArray(new int[size]); + + RandomForestClassifierUDTF udtf = new RandomForestClassifierUDTF(); + ObjectInspector param = ObjectInspectorUtils.getConstantObjectInspector( + PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-trees 1 -seed 71"); + udtf.initialize(new ObjectInspector[] { + ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.javaStringObjectInspector), + PrimitiveObjectInspectorFactory.javaIntObjectInspector, param}); + + final List<String> xi = new ArrayList<String>(x[0].length); + for (int i = 0; i < size; i++) { + final double[] row = x[i]; + for (int j = 0; j < row.length; j++) { + xi.add(j + ":" + row[j]); + } + udtf.process(new Object[] {xi, y[i]}); + xi.clear(); + } + + final Text[] placeholder = new Text[1]; + Collector collector = new Collector() { + public void collect(Object input) throws HiveException { + Object[] forward = (Object[]) input; + placeholder[0] = (Text) forward[2]; + } + }; + + udtf.setCollector(collector); + udtf.close(); + + Text modelTxt = 
placeholder[0]; + Assert.assertNotNull(modelTxt); + + byte[] b = Base91.decode(modelTxt.getBytes(), 0, modelTxt.getLength()); + DecisionTree.Node node = DecisionTree.deserializeNode(b, b.length, true); + return node; + } + + @Test + public void testNews20MultiClassSparse() throws IOException, ParseException, HiveException { + final int numTrees = 10; + RandomForestClassifierUDTF udtf = new RandomForestClassifierUDTF(); + ObjectInspector param = ObjectInspectorUtils.getConstantObjectInspector( + PrimitiveObjectInspectorFactory.javaStringObjectInspector, + "-stratified_sampling -seed 71 -trees " + numTrees); + udtf.initialize(new ObjectInspector[] { + ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.javaStringObjectInspector), + PrimitiveObjectInspectorFactory.javaIntObjectInspector, param}); + + + BufferedReader news20 = readFile("news20-multiclass.gz"); + ArrayList<String> features = new ArrayList<String>(); + String line = news20.readLine(); + while (line != null) { + StringTokenizer tokens = new StringTokenizer(line, " "); + int label = Integer.parseInt(tokens.nextToken()); + while (tokens.hasMoreTokens()) { + features.add(tokens.nextToken()); + } + Assert.assertFalse(features.isEmpty()); + udtf.process(new Object[] {features, label}); + + features.clear(); + line = news20.readLine(); + } + news20.close(); + + final MutableInt count = new MutableInt(0); + final MutableInt oobErrors = new MutableInt(0); + final MutableInt oobTests = new MutableInt(0); + Collector collector = new Collector() { + public void collect(Object input) throws HiveException { + Object[] forward = (Object[]) input; + oobErrors.addValue(((IntWritable) forward[4]).get()); + oobTests.addValue(((IntWritable) forward[5]).get()); + count.addValue(1); + } + }; + udtf.setCollector(collector); + udtf.close(); + + Assert.assertEquals(numTrees, count.getValue()); + float oobErrorRate = ((float) oobErrors.getValue()) / oobTests.getValue(); + // TODO why multi-class 
classification so bad?? + Assert.assertTrue("oob error rate is too high: " + oobErrorRate, oobErrorRate < 0.8); + } + + @Test + public void testNews20BinarySparse() throws IOException, ParseException, HiveException { + final int numTrees = 10; + RandomForestClassifierUDTF udtf = new RandomForestClassifierUDTF(); + ObjectInspector param = ObjectInspectorUtils.getConstantObjectInspector( + PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-seed 71 -trees " + + numTrees); + udtf.initialize(new ObjectInspector[] { + ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.javaStringObjectInspector), + PrimitiveObjectInspectorFactory.javaIntObjectInspector, param}); + + BufferedReader news20 = readFile("news20-small.binary.gz"); + ArrayList<String> features = new ArrayList<String>(); + String line = news20.readLine(); + while (line != null) { + StringTokenizer tokens = new StringTokenizer(line, " "); + int label = Integer.parseInt(tokens.nextToken()); + if (label == -1) { + label = 0; + } + while (tokens.hasMoreTokens()) { + features.add(tokens.nextToken()); + } + if (!features.isEmpty()) { + udtf.process(new Object[] {features, label}); + features.clear(); + } + line = news20.readLine(); + } + news20.close(); + + final MutableInt count = new MutableInt(0); + final MutableInt oobErrors = new MutableInt(0); + final MutableInt oobTests = new MutableInt(0); + Collector collector = new Collector() { + public void collect(Object input) throws HiveException { + Object[] forward = (Object[]) input; + oobErrors.addValue(((IntWritable) forward[4]).get()); + oobTests.addValue(((IntWritable) forward[5]).get()); + count.addValue(1); + } + }; + udtf.setCollector(collector); + udtf.close(); + + Assert.assertEquals(numTrees, count.getValue()); + float oobErrorRate = ((float) oobErrors.getValue()) / oobTests.getValue(); + Assert.assertTrue("oob error rate is too high: " + oobErrorRate, oobErrorRate < 0.3); + } + + + @Nonnull + private static 
BufferedReader readFile(@Nonnull String fileName) throws IOException { + InputStream is = KernelExpansionPassiveAggressiveUDTF.class.getResourceAsStream(fileName); + if (fileName.endsWith(".gz")) { + is = new GZIPInputStream(is); + } + return new BufferedReader(new InputStreamReader(is)); + } }
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/test/java/hivemall/smile/regression/RegressionTreeTest.java ---------------------------------------------------------------------- diff --git a/core/src/test/java/hivemall/smile/regression/RegressionTreeTest.java b/core/src/test/java/hivemall/smile/regression/RegressionTreeTest.java index 20f44b3..eae625d 100644 --- a/core/src/test/java/hivemall/smile/regression/RegressionTreeTest.java +++ b/core/src/test/java/hivemall/smile/regression/RegressionTreeTest.java @@ -18,7 +18,16 @@ */ package hivemall.smile.regression; +import hivemall.math.matrix.Matrix; +import hivemall.math.matrix.builders.CSRMatrixBuilder; +import hivemall.math.matrix.dense.RowMajorDenseMatrix2d; +import hivemall.math.random.RandomNumberGeneratorFactory; import hivemall.smile.data.Attribute; +import hivemall.smile.data.Attribute.NumericAttribute; + +import java.util.Arrays; + +import javax.annotation.Nonnull; import org.apache.hadoop.hive.ql.metadata.HiveException; import org.junit.Assert; @@ -30,7 +39,7 @@ import smile.validation.LOOCV; public class RegressionTreeTest { @Test - public void testPredict() { + public void testPredictDense() { double[][] longley = { {234.289, 235.6, 159.0, 107.608, 1947, 60.323}, {259.426, 232.5, 145.6, 108.632, 1948, 61.122}, @@ -53,10 +62,51 @@ public class RegressionTreeTest { 112.6, 114.2, 115.7, 116.9}; Attribute[] attrs = new Attribute[longley[0].length]; - for (int i = 0; i < attrs.length; i++) { - attrs[i] = new Attribute.NumericAttribute(i); + Arrays.fill(attrs, new NumericAttribute()); + + int n = longley.length; + LOOCV loocv = new LOOCV(n); + double rss = 0.0; + for (int i = 0; i < n; i++) { + double[][] trainx = Math.slice(longley, loocv.train[i]); + double[] trainy = Math.slice(y, loocv.train[i]); + int maxLeafs = 10; + RegressionTree tree = new RegressionTree(attrs, matrix(trainx, true), trainy, maxLeafs, + RandomNumberGeneratorFactory.createPRNG(i)); + + double r = 
y[loocv.test[i]] - tree.predict(longley[loocv.test[i]]); + rss += r * r; } + Assert.assertTrue("MSE = " + (rss / n), (rss / n) < 42); + } + + @Test + public void testPredictSparse() { + + double[][] longley = { {234.289, 235.6, 159.0, 107.608, 1947, 60.323}, + {259.426, 232.5, 145.6, 108.632, 1948, 61.122}, + {258.054, 368.2, 161.6, 109.773, 1949, 60.171}, + {284.599, 335.1, 165.0, 110.929, 1950, 61.187}, + {328.975, 209.9, 309.9, 112.075, 1951, 63.221}, + {346.999, 193.2, 359.4, 113.270, 1952, 63.639}, + {365.385, 187.0, 354.7, 115.094, 1953, 64.989}, + {363.112, 357.8, 335.0, 116.219, 1954, 63.761}, + {397.469, 290.4, 304.8, 117.388, 1955, 66.019}, + {419.180, 282.2, 285.7, 118.734, 1956, 67.857}, + {442.769, 293.6, 279.8, 120.445, 1957, 68.169}, + {444.546, 468.1, 263.7, 121.950, 1958, 66.513}, + {482.704, 381.3, 255.2, 123.366, 1959, 68.655}, + {502.601, 393.1, 251.4, 125.368, 1960, 69.564}, + {518.173, 480.6, 257.2, 127.852, 1961, 69.331}, + {554.894, 400.7, 282.7, 130.081, 1962, 70.551}}; + + double[] y = {83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, + 112.6, 114.2, 115.7, 116.9}; + + Attribute[] attrs = new Attribute[longley[0].length]; + Arrays.fill(attrs, new NumericAttribute()); + int n = longley.length; LOOCV loocv = new LOOCV(n); double rss = 0.0; @@ -64,8 +114,8 @@ public class RegressionTreeTest { double[][] trainx = Math.slice(longley, loocv.train[i]); double[] trainy = Math.slice(y, loocv.train[i]); int maxLeafs = 10; - smile.math.Random rand = new smile.math.Random(i); - RegressionTree tree = new RegressionTree(attrs, trainx, trainy, maxLeafs, rand); + RegressionTree tree = new RegressionTree(attrs, matrix(trainx, false), trainy, + maxLeafs, RandomNumberGeneratorFactory.createPRNG(i)); double r = y[loocv.test[i]] - tree.predict(longley[loocv.test[i]]); rss += r * r; @@ -98,9 +148,7 @@ public class RegressionTreeTest { 112.6, 114.2, 115.7, 116.9}; Attribute[] attrs = new Attribute[longley[0].length]; - for (int i = 0; 
i < attrs.length; i++) { - attrs[i] = new Attribute.NumericAttribute(i); - } + Arrays.fill(attrs, new NumericAttribute()); int n = longley.length; LOOCV loocv = new LOOCV(n); @@ -108,7 +156,7 @@ public class RegressionTreeTest { double[][] trainx = Math.slice(longley, loocv.train[i]); double[] trainy = Math.slice(y, loocv.train[i]); int maxLeafs = Integer.MAX_VALUE; - RegressionTree tree = new RegressionTree(attrs, trainx, trainy, maxLeafs); + RegressionTree tree = new RegressionTree(attrs, matrix(trainx, true), trainy, maxLeafs); byte[] b = tree.predictSerCodegen(true); RegressionTree.Node node = RegressionTree.deserializeNode(b, b.length, true); @@ -119,4 +167,19 @@ public class RegressionTreeTest { Assert.assertEquals(expected, actual, 0.d); } } + + @Nonnull + private static Matrix matrix(@Nonnull final double[][] x, boolean dense) { + if (dense) { + return new RowMajorDenseMatrix2d(x, x[0].length); + } else { + int numRows = x.length; + CSRMatrixBuilder builder = new CSRMatrixBuilder(1024); + for (int i = 0; i < numRows; i++) { + builder.nextRow(x[i]); + } + return builder.buildMatrix(); + } + } + } http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/test/java/hivemall/smile/tools/TreePredictUDFTest.java ---------------------------------------------------------------------- diff --git a/core/src/test/java/hivemall/smile/tools/TreePredictUDFTest.java b/core/src/test/java/hivemall/smile/tools/TreePredictUDFTest.java index 504ea86..65feeeb 100644 --- a/core/src/test/java/hivemall/smile/tools/TreePredictUDFTest.java +++ b/core/src/test/java/hivemall/smile/tools/TreePredictUDFTest.java @@ -18,13 +18,12 @@ */ package hivemall.smile.tools; -import static org.junit.Assert.assertEquals; -import hivemall.smile.ModelType; +import hivemall.math.matrix.dense.RowMajorDenseMatrix2d; import hivemall.smile.classification.DecisionTree; import hivemall.smile.data.Attribute; import hivemall.smile.regression.RegressionTree; import 
hivemall.smile.utils.SmileExtUtils; -import hivemall.smile.vm.StackMachine; +import hivemall.utils.codec.Base91; import hivemall.utils.lang.ArrayUtils; import java.io.BufferedInputStream; @@ -42,6 +41,8 @@ import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils; import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; import org.apache.hadoop.io.IntWritable; +import org.apache.hadoop.io.Text; +import org.junit.Assert; import org.junit.Test; import smile.data.AttributeDataset; @@ -49,7 +50,7 @@ import smile.data.parser.ArffParser; import smile.math.Math; import smile.validation.CrossValidation; import smile.validation.LOOCV; -import smile.validation.Validation; +import smile.validation.RMSE; public class TreePredictUDFTest { private static final boolean DEBUG = false; @@ -76,8 +77,9 @@ public class TreePredictUDFTest { int[] trainy = Math.slice(y, loocv.train[i]); Attribute[] attrs = SmileExtUtils.convertAttributeTypes(iris.attributes()); - DecisionTree tree = new DecisionTree(attrs, trainx, trainy, 4); - assertEquals(tree.predict(x[loocv.test[i]]), evalPredict(tree, x[loocv.test[i]])); + DecisionTree tree = new DecisionTree(attrs, new RowMajorDenseMatrix2d(trainx, + x[0].length), trainy, 4); + Assert.assertEquals(tree.predict(x[loocv.test[i]]), evalPredict(tree, x[loocv.test[i]])); } } @@ -103,10 +105,11 @@ public class TreePredictUDFTest { double[][] testx = Math.slice(datax, cv.test[i]); Attribute[] attrs = SmileExtUtils.convertAttributeTypes(data.attributes()); - RegressionTree tree = new RegressionTree(attrs, trainx, trainy, 20); + RegressionTree tree = new RegressionTree(attrs, new RowMajorDenseMatrix2d(trainx, + trainx[0].length), trainy, 20); for (int j = 0; j < testx.length; j++) { - assertEquals(tree.predict(testx[j]), evalPredict(tree, testx[j]), 1.0); + Assert.assertEquals(tree.predict(testx[j]), evalPredict(tree, 
testx[j]), 1.0); } } } @@ -142,52 +145,60 @@ public class TreePredictUDFTest { } Attribute[] attrs = SmileExtUtils.convertAttributeTypes(data.attributes()); - RegressionTree tree = new RegressionTree(attrs, trainx, trainy, 20); - debugPrint(String.format("RMSE = %.4f\n", Validation.test(tree, testx, testy))); + RegressionTree tree = new RegressionTree(attrs, new RowMajorDenseMatrix2d(trainx, + trainx[0].length), trainy, 20); + debugPrint(String.format("RMSE = %.4f\n", rmse(tree, testx, testy))); for (int i = m; i < n; i++) { - assertEquals(tree.predict(testx[i - m]), evalPredict(tree, testx[i - m]), 1.0); + Assert.assertEquals(tree.predict(testx[i - m]), evalPredict(tree, testx[i - m]), 1.0); } } + private static <T> double rmse(RegressionTree regression, double[][] x, double[] y) { + final int n = x.length; + final double[] predictions = new double[n]; + for (int i = 0; i < n; i++) { + predictions[i] = regression.predict(x[i]); + } + return new RMSE().measure(y, predictions); + } + private static int evalPredict(DecisionTree tree, double[] x) throws HiveException, IOException { - String opScript = tree.predictOpCodegen(StackMachine.SEP); - debugPrint(opScript); + byte[] b = tree.predictSerCodegen(true); + byte[] encoded = Base91.encode(b); + Text model = new Text(encoded); TreePredictUDF udf = new TreePredictUDF(); udf.initialize(new ObjectInspector[] { PrimitiveObjectInspectorFactory.javaStringObjectInspector, - PrimitiveObjectInspectorFactory.javaIntObjectInspector, - PrimitiveObjectInspectorFactory.javaStringObjectInspector, + PrimitiveObjectInspectorFactory.writableStringObjectInspector, ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.javaDoubleObjectInspector), ObjectInspectorUtils.getConstantObjectInspector( PrimitiveObjectInspectorFactory.javaBooleanObjectInspector, true)}); DeferredObject[] arguments = new DeferredObject[] {new DeferredJavaObject("model_id#1"), - new DeferredJavaObject(ModelType.opscode.getId()), - new 
DeferredJavaObject(opScript), new DeferredJavaObject(ArrayUtils.toList(x)), + new DeferredJavaObject(model), new DeferredJavaObject(ArrayUtils.toList(x)), new DeferredJavaObject(true)}; - IntWritable result = (IntWritable) udf.evaluate(arguments); + Object[] result = (Object[]) udf.evaluate(arguments); udf.close(); - return result.get(); + return ((IntWritable) result[0]).get(); } private static double evalPredict(RegressionTree tree, double[] x) throws HiveException, IOException { - String opScript = tree.predictOpCodegen(StackMachine.SEP); - debugPrint(opScript); + byte[] b = tree.predictSerCodegen(true); + byte[] encoded = Base91.encode(b); + Text model = new Text(encoded); TreePredictUDF udf = new TreePredictUDF(); udf.initialize(new ObjectInspector[] { PrimitiveObjectInspectorFactory.javaStringObjectInspector, - PrimitiveObjectInspectorFactory.javaIntObjectInspector, - PrimitiveObjectInspectorFactory.javaStringObjectInspector, + PrimitiveObjectInspectorFactory.writableStringObjectInspector, ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.javaDoubleObjectInspector), ObjectInspectorUtils.getConstantObjectInspector( PrimitiveObjectInspectorFactory.javaBooleanObjectInspector, false)}); DeferredObject[] arguments = new DeferredObject[] {new DeferredJavaObject("model_id#1"), - new DeferredJavaObject(ModelType.opscode.getId()), - new DeferredJavaObject(opScript), new DeferredJavaObject(ArrayUtils.toList(x)), + new DeferredJavaObject(model), new DeferredJavaObject(ArrayUtils.toList(x)), new DeferredJavaObject(false)}; DoubleWritable result = (DoubleWritable) udf.evaluate(arguments); http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/test/java/hivemall/smile/vm/StackMachineTest.java ---------------------------------------------------------------------- diff --git a/core/src/test/java/hivemall/smile/vm/StackMachineTest.java b/core/src/test/java/hivemall/smile/vm/StackMachineTest.java deleted file mode 
100644 index 4a2dcd8..0000000 --- a/core/src/test/java/hivemall/smile/vm/StackMachineTest.java +++ /dev/null @@ -1,88 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -package hivemall.smile.vm; - -import static org.junit.Assert.assertEquals; -import hivemall.utils.io.IOUtils; - -import java.io.BufferedInputStream; -import java.io.IOException; -import java.io.InputStream; -import java.net.URL; -import java.text.ParseException; -import java.util.ArrayList; - -import org.apache.hadoop.hive.ql.metadata.HiveException; -import org.junit.Assert; -import org.junit.Test; - -public class StackMachineTest { - private static final boolean DEBUG = false; - - @Test - public void testFindInfinteLoop() throws IOException, ParseException, HiveException, - VMRuntimeException { - // Sample of machine code having infinite loop - ArrayList<String> opScript = new ArrayList<String>(); - opScript.add("push 2.0"); - opScript.add("push 1.0"); - opScript.add("iflt 0"); - opScript.add("push 1"); - opScript.add("call end"); - debugPrint(opScript); - double[] x = new double[0]; - StackMachine sm = new StackMachine(); - try { - sm.run(opScript, x); - Assert.fail("VMRuntimeException is expected"); - } catch (VMRuntimeException ex) 
{ - assertEquals("There is a infinite loop in the Machine code.", ex.getMessage()); - } - } - - @Test - public void testLargeOpcodes() throws IOException, ParseException, HiveException, - VMRuntimeException { - URL url = new URL( - "https://gist.githubusercontent.com/myui/b1a8e588f5750e3b658c/raw/a4074d37400dab2b13a2f43d81f5166188d3461a/vmtest01.txt"); - InputStream is = new BufferedInputStream(url.openStream()); - String opScript = IOUtils.toString(is); - - StackMachine sm = new StackMachine(); - sm.compile(opScript); - - double[] x1 = new double[] {36, 2, 1, 2, 0, 436, 1, 0, 0, 13, 0, 567, 1, 595, 2, 1}; - sm.eval(x1); - assertEquals(0.d, sm.getResult().doubleValue(), 0d); - - double[] x2 = {31, 2, 1, 2, 0, 354, 1, 0, 0, 30, 0, 502, 1, 9, 2, 2}; - sm.eval(x2); - assertEquals(1.d, sm.getResult().doubleValue(), 0d); - - double[] x3 = {39, 0, 0, 0, 0, 1756, 0, 0, 0, 3, 0, 939, 1, 0, 0, 0}; - sm.eval(x3); - assertEquals(0.d, sm.getResult().doubleValue(), 0d); - } - - private static void debugPrint(Object msg) { - if (DEBUG) { - System.out.println(msg); - } - } -} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/test/java/hivemall/utils/collections/DoubleArray3DTest.java ---------------------------------------------------------------------- diff --git a/core/src/test/java/hivemall/utils/collections/DoubleArray3DTest.java b/core/src/test/java/hivemall/utils/collections/DoubleArray3DTest.java deleted file mode 100644 index 177a345..0000000 --- a/core/src/test/java/hivemall/utils/collections/DoubleArray3DTest.java +++ /dev/null @@ -1,147 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -package hivemall.utils.collections; - -import java.util.Random; - -import org.junit.Assert; -import org.junit.Test; - -public class DoubleArray3DTest { - - @Test - public void test() { - final int size_i = 50, size_j = 50, size_k = 5; - - final DoubleArray3D mdarray = new DoubleArray3D(); - mdarray.configure(size_i, size_j, size_k); - - final Random rand = new Random(31L); - final double[][][] data = new double[size_i][size_j][size_j]; - for (int i = 0; i < size_i; i++) { - for (int j = 0; j < size_j; j++) { - for (int k = 0; k < size_k; k++) { - double v = rand.nextDouble(); - data[i][j][k] = v; - mdarray.set(i, j, k, v); - } - } - } - - Assert.assertEquals(size_i * size_j * size_k, mdarray.getSize()); - - for (int i = 0; i < size_i; i++) { - for (int j = 0; j < size_j; j++) { - for (int k = 0; k < size_k; k++) { - Assert.assertEquals(data[i][j][k], mdarray.get(i, j, k), 0.d); - } - } - } - } - - @Test - public void testConfigureExpand() { - int size_i = 50, size_j = 50, size_k = 5; - - final DoubleArray3D mdarray = new DoubleArray3D(); - mdarray.configure(size_i, size_j, size_k); - - final Random rand = new Random(31L); - for (int i = 0; i < size_i; i++) { - for (int j = 0; j < size_j; j++) { - for (int k = 0; k < size_k; k++) { - double v = rand.nextDouble(); - mdarray.set(i, j, k, v); - } - } - } - - size_i = 101; - size_j = 101; - size_k = 11; - mdarray.configure(size_i, size_j, size_k); - Assert.assertEquals(size_i * size_j * size_k, mdarray.getCapacity()); - Assert.assertEquals(size_i * size_j * size_k, mdarray.getSize()); - - final 
double[][][] data = new double[size_i][size_j][size_j]; - for (int i = 0; i < size_i; i++) { - for (int j = 0; j < size_j; j++) { - for (int k = 0; k < size_k; k++) { - double v = rand.nextDouble(); - data[i][j][k] = v; - mdarray.set(i, j, k, v); - } - } - } - - for (int i = 0; i < size_i; i++) { - for (int j = 0; j < size_j; j++) { - for (int k = 0; k < size_k; k++) { - Assert.assertEquals(data[i][j][k], mdarray.get(i, j, k), 0.d); - } - } - } - } - - @Test - public void testConfigureShrink() { - int size_i = 50, size_j = 50, size_k = 5; - - final DoubleArray3D mdarray = new DoubleArray3D(); - mdarray.configure(size_i, size_j, size_k); - - final Random rand = new Random(31L); - for (int i = 0; i < size_i; i++) { - for (int j = 0; j < size_j; j++) { - for (int k = 0; k < size_k; k++) { - double v = rand.nextDouble(); - mdarray.set(i, j, k, v); - } - } - } - - int capacity = mdarray.getCapacity(); - size_i = 49; - size_j = 49; - size_k = 4; - mdarray.configure(size_i, size_j, size_k); - Assert.assertEquals(capacity, mdarray.getCapacity()); - Assert.assertEquals(size_i * size_j * size_k, mdarray.getSize()); - - final double[][][] data = new double[size_i][size_j][size_j]; - for (int i = 0; i < size_i; i++) { - for (int j = 0; j < size_j; j++) { - for (int k = 0; k < size_k; k++) { - double v = rand.nextDouble(); - data[i][j][k] = v; - mdarray.set(i, j, k, v); - } - } - } - - for (int i = 0; i < size_i; i++) { - for (int j = 0; j < size_j; j++) { - for (int k = 0; k < size_k; k++) { - Assert.assertEquals(data[i][j][k], mdarray.get(i, j, k), 0.d); - } - } - } - } - -} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/test/java/hivemall/utils/collections/DoubleArrayTest.java ---------------------------------------------------------------------- diff --git a/core/src/test/java/hivemall/utils/collections/DoubleArrayTest.java b/core/src/test/java/hivemall/utils/collections/DoubleArrayTest.java deleted file mode 100644 index 72e76e8..0000000 
--- a/core/src/test/java/hivemall/utils/collections/DoubleArrayTest.java +++ /dev/null @@ -1,60 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -package hivemall.utils.collections; - -import org.junit.Assert; -import org.junit.Test; - -public class DoubleArrayTest { - - @Test - public void testSparseDoubleArrayToArray() { - SparseDoubleArray array = new SparseDoubleArray(3); - for (int i = 0; i < 10; i++) { - array.put(i, 10 + i); - } - Assert.assertEquals(10, array.size()); - Assert.assertEquals(10, array.toArray(false).length); - - double[] copied = array.toArray(true); - Assert.assertEquals(10, copied.length); - for (int i = 0; i < 10; i++) { - Assert.assertEquals(10 + i, copied[i], 0.d); - } - } - - @Test - public void testSparseDoubleArrayClear() { - SparseDoubleArray array = new SparseDoubleArray(3); - for (int i = 0; i < 10; i++) { - array.put(i, 10 + i); - } - array.clear(); - Assert.assertEquals(0, array.size()); - Assert.assertEquals(0, array.get(0), 0.d); - for (int i = 0; i < 5; i++) { - array.put(i, 100 + i); - } - Assert.assertEquals(5, array.size()); - for (int i = 0; i < 5; i++) { - Assert.assertEquals(100 + i, array.get(i), 0.d); - } - } - -} 
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/test/java/hivemall/utils/collections/Int2FloatOpenHashMapTest.java ---------------------------------------------------------------------- diff --git a/core/src/test/java/hivemall/utils/collections/Int2FloatOpenHashMapTest.java b/core/src/test/java/hivemall/utils/collections/Int2FloatOpenHashMapTest.java deleted file mode 100644 index 8a8a68d..0000000 --- a/core/src/test/java/hivemall/utils/collections/Int2FloatOpenHashMapTest.java +++ /dev/null @@ -1,96 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. 
- */ -package hivemall.utils.collections; - -import org.junit.Assert; -import org.junit.Test; - -public class Int2FloatOpenHashMapTest { - - @Test - public void testSize() { - Int2FloatOpenHashTable map = new Int2FloatOpenHashTable(16384); - map.put(1, 3.f); - Assert.assertEquals(3.f, map.get(1), 0.d); - map.put(1, 5.f); - Assert.assertEquals(5.f, map.get(1), 0.d); - Assert.assertEquals(1, map.size()); - } - - @Test - public void testDefaultReturnValue() { - Int2FloatOpenHashTable map = new Int2FloatOpenHashTable(16384); - Assert.assertEquals(0, map.size()); - Assert.assertEquals(-1.f, map.get(1), 0.d); - float ret = Float.MIN_VALUE; - map.defaultReturnValue(ret); - Assert.assertEquals(ret, map.get(1), 0.d); - } - - @Test - public void testPutAndGet() { - Int2FloatOpenHashTable map = new Int2FloatOpenHashTable(16384); - final int numEntries = 1000000; - for (int i = 0; i < numEntries; i++) { - Assert.assertEquals(-1.f, map.put(i, Float.valueOf(i + 0.1f)), 0.d); - } - Assert.assertEquals(numEntries, map.size()); - for (int i = 0; i < numEntries; i++) { - Float v = map.get(i); - Assert.assertEquals(i + 0.1f, v.floatValue(), 0.d); - } - } - - @Test - public void testIterator() { - Int2FloatOpenHashTable map = new Int2FloatOpenHashTable(1000); - Int2FloatOpenHashTable.IMapIterator itor = map.entries(); - Assert.assertFalse(itor.hasNext()); - - final int numEntries = 1000000; - for (int i = 0; i < numEntries; i++) { - Assert.assertEquals(-1.f, map.put(i, Float.valueOf(i + 0.1f)), 0.d); - } - Assert.assertEquals(numEntries, map.size()); - - itor = map.entries(); - Assert.assertTrue(itor.hasNext()); - while (itor.hasNext()) { - Assert.assertFalse(itor.next() == -1); - int k = itor.getKey(); - Float v = itor.getValue(); - Assert.assertEquals(k + 0.1f, v.floatValue(), 0.d); - } - Assert.assertEquals(-1, itor.next()); - } - - @Test - public void testIterator2() { - Int2FloatOpenHashTable map = new Int2FloatOpenHashTable(100); - map.put(33, 3.16f); - - 
Int2FloatOpenHashTable.IMapIterator itor = map.entries(); - Assert.assertTrue(itor.hasNext()); - Assert.assertNotEquals(-1, itor.next()); - Assert.assertEquals(33, itor.getKey()); - Assert.assertEquals(3.16f, itor.getValue(), 0.d); - Assert.assertEquals(-1, itor.next()); - } - -} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/test/java/hivemall/utils/collections/Int2LongOpenHashMapTest.java ---------------------------------------------------------------------- diff --git a/core/src/test/java/hivemall/utils/collections/Int2LongOpenHashMapTest.java b/core/src/test/java/hivemall/utils/collections/Int2LongOpenHashMapTest.java deleted file mode 100644 index 1186bdf..0000000 --- a/core/src/test/java/hivemall/utils/collections/Int2LongOpenHashMapTest.java +++ /dev/null @@ -1,105 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. 
- */ -package hivemall.utils.collections; - -import hivemall.utils.lang.ObjectUtils; - -import java.io.IOException; - -import org.junit.Assert; -import org.junit.Test; - -public class Int2LongOpenHashMapTest { - - @Test - public void testSize() { - Int2LongOpenHashTable map = new Int2LongOpenHashTable(16384); - map.put(1, 3L); - Assert.assertEquals(3L, map.get(1)); - map.put(1, 5L); - Assert.assertEquals(5L, map.get(1)); - Assert.assertEquals(1, map.size()); - } - - @Test - public void testDefaultReturnValue() { - Int2LongOpenHashTable map = new Int2LongOpenHashTable(16384); - Assert.assertEquals(0, map.size()); - Assert.assertEquals(-1L, map.get(1)); - long ret = Long.MIN_VALUE; - map.defaultReturnValue(ret); - Assert.assertEquals(ret, map.get(1)); - } - - @Test - public void testPutAndGet() { - Int2LongOpenHashTable map = new Int2LongOpenHashTable(16384); - final int numEntries = 1000000; - for (int i = 0; i < numEntries; i++) { - Assert.assertEquals(-1L, map.put(i, i)); - } - Assert.assertEquals(numEntries, map.size()); - for (int i = 0; i < numEntries; i++) { - long v = map.get(i); - Assert.assertEquals(i, v); - } - } - - @Test - public void testSerde() throws IOException, ClassNotFoundException { - Int2LongOpenHashTable map = new Int2LongOpenHashTable(16384); - final int numEntries = 1000000; - for (int i = 0; i < numEntries; i++) { - Assert.assertEquals(-1L, map.put(i, i)); - } - - byte[] b = ObjectUtils.toCompressedBytes(map); - map = new Int2LongOpenHashTable(16384); - ObjectUtils.readCompressedObject(b, map); - - Assert.assertEquals(numEntries, map.size()); - for (int i = 0; i < numEntries; i++) { - long v = map.get(i); - Assert.assertEquals(i, v); - } - } - - @Test - public void testIterator() { - Int2LongOpenHashTable map = new Int2LongOpenHashTable(1000); - Int2LongOpenHashTable.IMapIterator itor = map.entries(); - Assert.assertFalse(itor.hasNext()); - - final int numEntries = 1000000; - for (int i = 0; i < numEntries; i++) { - Assert.assertEquals(-1L, 
map.put(i, i)); - } - Assert.assertEquals(numEntries, map.size()); - - itor = map.entries(); - Assert.assertTrue(itor.hasNext()); - while (itor.hasNext()) { - Assert.assertFalse(itor.next() == -1); - int k = itor.getKey(); - long v = itor.getValue(); - Assert.assertEquals(k, v); - } - Assert.assertEquals(-1, itor.next()); - } -} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/test/java/hivemall/utils/collections/IntArrayTest.java ---------------------------------------------------------------------- diff --git a/core/src/test/java/hivemall/utils/collections/IntArrayTest.java b/core/src/test/java/hivemall/utils/collections/IntArrayTest.java deleted file mode 100644 index 42852ea..0000000 --- a/core/src/test/java/hivemall/utils/collections/IntArrayTest.java +++ /dev/null @@ -1,76 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. 
- */ -package hivemall.utils.collections; - -import org.junit.Assert; -import org.junit.Test; - -public class IntArrayTest { - - @Test - public void testFixedIntArrayToArray() { - FixedIntArray array = new FixedIntArray(11); - for (int i = 0; i < 10; i++) { - array.put(i, 10 + i); - } - Assert.assertEquals(11, array.size()); - Assert.assertEquals(11, array.toArray(false).length); - - int[] copied = array.toArray(true); - Assert.assertEquals(11, copied.length); - for (int i = 0; i < 10; i++) { - Assert.assertEquals(10 + i, copied[i]); - } - } - - @Test - public void testSparseIntArrayToArray() { - SparseIntArray array = new SparseIntArray(3); - for (int i = 0; i < 10; i++) { - array.put(i, 10 + i); - } - Assert.assertEquals(10, array.size()); - Assert.assertEquals(10, array.toArray(false).length); - - int[] copied = array.toArray(true); - Assert.assertEquals(10, copied.length); - for (int i = 0; i < 10; i++) { - Assert.assertEquals(10 + i, copied[i]); - } - } - - @Test - public void testSparseIntArrayClear() { - SparseIntArray array = new SparseIntArray(3); - for (int i = 0; i < 10; i++) { - array.put(i, 10 + i); - } - array.clear(); - Assert.assertEquals(0, array.size()); - Assert.assertEquals(0, array.get(0)); - for (int i = 0; i < 5; i++) { - array.put(i, 100 + i); - } - Assert.assertEquals(5, array.size()); - for (int i = 0; i < 5; i++) { - Assert.assertEquals(100 + i, array.get(i)); - } - } - -} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/test/java/hivemall/utils/collections/IntOpenHashMapTest.java ---------------------------------------------------------------------- diff --git a/core/src/test/java/hivemall/utils/collections/IntOpenHashMapTest.java b/core/src/test/java/hivemall/utils/collections/IntOpenHashMapTest.java deleted file mode 100644 index 29a5a81..0000000 --- a/core/src/test/java/hivemall/utils/collections/IntOpenHashMapTest.java +++ /dev/null @@ -1,73 +0,0 @@ -/* - * Licensed to the Apache Software Foundation 
(ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -package hivemall.utils.collections; - -import org.junit.Assert; -import org.junit.Test; - -public class IntOpenHashMapTest { - - @Test - public void testSize() { - IntOpenHashMap<Float> map = new IntOpenHashMap<Float>(16384); - map.put(1, Float.valueOf(3.f)); - Assert.assertEquals(Float.valueOf(3.f), map.get(1)); - map.put(1, Float.valueOf(5.f)); - Assert.assertEquals(Float.valueOf(5.f), map.get(1)); - Assert.assertEquals(1, map.size()); - } - - @Test - public void testPutAndGet() { - IntOpenHashMap<Integer> map = new IntOpenHashMap<Integer>(16384); - final int numEntries = 1000000; - for (int i = 0; i < numEntries; i++) { - Assert.assertNull(map.put(i, i)); - } - Assert.assertEquals(numEntries, map.size()); - for (int i = 0; i < numEntries; i++) { - Integer v = map.get(i); - Assert.assertEquals(i, v.intValue()); - } - } - - @Test - public void testIterator() { - IntOpenHashMap<Integer> map = new IntOpenHashMap<Integer>(1000); - IntOpenHashMap.IMapIterator<Integer> itor = map.entries(); - Assert.assertFalse(itor.hasNext()); - - final int numEntries = 1000000; - for (int i = 0; i < numEntries; i++) { - Assert.assertNull(map.put(i, i)); - } - Assert.assertEquals(numEntries, map.size()); - - itor = map.entries(); 
- Assert.assertTrue(itor.hasNext()); - while (itor.hasNext()) { - Assert.assertFalse(itor.next() == -1); - int k = itor.getKey(); - Integer v = itor.getValue(); - Assert.assertEquals(k, v.intValue()); - } - Assert.assertEquals(-1, itor.next()); - } - -} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/test/java/hivemall/utils/collections/IntOpenHashTableTest.java ---------------------------------------------------------------------- diff --git a/core/src/test/java/hivemall/utils/collections/IntOpenHashTableTest.java b/core/src/test/java/hivemall/utils/collections/IntOpenHashTableTest.java deleted file mode 100644 index 3babb3d..0000000 --- a/core/src/test/java/hivemall/utils/collections/IntOpenHashTableTest.java +++ /dev/null @@ -1,50 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. 
- */ -package hivemall.utils.collections; - -import org.junit.Assert; -import org.junit.Test; - -public class IntOpenHashTableTest { - - @Test - public void testSize() { - IntOpenHashTable<Float> map = new IntOpenHashTable<Float>(16384); - map.put(1, Float.valueOf(3.f)); - Assert.assertEquals(Float.valueOf(3.f), map.get(1)); - map.put(1, Float.valueOf(5.f)); - Assert.assertEquals(Float.valueOf(5.f), map.get(1)); - Assert.assertEquals(1, map.size()); - } - - @Test - public void testPutAndGet() { - IntOpenHashTable<Integer> map = new IntOpenHashTable<Integer>(16384); - final int numEntries = 1000000; - for (int i = 0; i < numEntries; i++) { - Assert.assertNull(map.put(i, i)); - } - Assert.assertEquals(numEntries, map.size()); - for (int i = 0; i < numEntries; i++) { - Integer v = map.get(i); - Assert.assertEquals(i, v.intValue()); - } - } - -} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/test/java/hivemall/utils/collections/OpenHashMapTest.java ---------------------------------------------------------------------- diff --git a/core/src/test/java/hivemall/utils/collections/OpenHashMapTest.java b/core/src/test/java/hivemall/utils/collections/OpenHashMapTest.java deleted file mode 100644 index e3cc018..0000000 --- a/core/src/test/java/hivemall/utils/collections/OpenHashMapTest.java +++ /dev/null @@ -1,91 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -package hivemall.utils.collections; - -import hivemall.utils.lang.mutable.MutableInt; - -import java.util.Map; - -import org.junit.Assert; -import org.junit.Test; - -public class OpenHashMapTest { - - @Test - public void testPutAndGet() { - Map<Object, Object> map = new OpenHashMap<Object, Object>(16384); - final int numEntries = 5000000; - for (int i = 0; i < numEntries; i++) { - map.put(Integer.toString(i), i); - } - Assert.assertEquals(numEntries, map.size()); - for (int i = 0; i < numEntries; i++) { - Object v = map.get(Integer.toString(i)); - Assert.assertEquals(i, v); - } - map.put(Integer.toString(1), Integer.MAX_VALUE); - Assert.assertEquals(Integer.MAX_VALUE, map.get(Integer.toString(1))); - Assert.assertEquals(numEntries, map.size()); - } - - @Test - public void testIterator() { - OpenHashMap<String, Integer> map = new OpenHashMap<String, Integer>(1000); - IMapIterator<String, Integer> itor = map.entries(); - Assert.assertFalse(itor.hasNext()); - - final int numEntries = 1000000; - for (int i = 0; i < numEntries; i++) { - map.put(Integer.toString(i), i); - } - - itor = map.entries(); - Assert.assertTrue(itor.hasNext()); - while (itor.hasNext()) { - Assert.assertFalse(itor.next() == -1); - String k = itor.getKey(); - Integer v = itor.getValue(); - Assert.assertEquals(Integer.valueOf(k), v); - } - Assert.assertEquals(-1, itor.next()); - } - - @Test - public void testIteratorGetProbe() { - OpenHashMap<String, MutableInt> map = new OpenHashMap<String, MutableInt>(100); - IMapIterator<String, MutableInt> itor = map.entries(); - 
Assert.assertFalse(itor.hasNext()); - - final int numEntries = 1000000; - for (int i = 0; i < numEntries; i++) { - map.put(Integer.toString(i), new MutableInt(i)); - } - - final MutableInt probe = new MutableInt(); - itor = map.entries(); - Assert.assertTrue(itor.hasNext()); - while (itor.hasNext()) { - Assert.assertFalse(itor.next() == -1); - String k = itor.getKey(); - itor.getValue(probe); - Assert.assertEquals(Integer.valueOf(k).intValue(), probe.intValue()); - } - Assert.assertEquals(-1, itor.next()); - } -} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/test/java/hivemall/utils/collections/OpenHashTableTest.java ---------------------------------------------------------------------- diff --git a/core/src/test/java/hivemall/utils/collections/OpenHashTableTest.java b/core/src/test/java/hivemall/utils/collections/OpenHashTableTest.java deleted file mode 100644 index d5a465c..0000000 --- a/core/src/test/java/hivemall/utils/collections/OpenHashTableTest.java +++ /dev/null @@ -1,138 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. 
- */ -package hivemall.utils.collections; - -import hivemall.utils.lang.ObjectUtils; -import hivemall.utils.lang.mutable.MutableInt; - -import java.io.IOException; - -import org.junit.Assert; -import org.junit.Test; - -public class OpenHashTableTest { - - @Test - public void testPutAndGet() { - OpenHashTable<Object, Object> map = new OpenHashTable<Object, Object>(16384); - final int numEntries = 5000000; - for (int i = 0; i < numEntries; i++) { - map.put(Integer.toString(i), i); - } - Assert.assertEquals(numEntries, map.size()); - for (int i = 0; i < numEntries; i++) { - Object v = map.get(Integer.toString(i)); - Assert.assertEquals(i, v); - } - map.put(Integer.toString(1), Integer.MAX_VALUE); - Assert.assertEquals(Integer.MAX_VALUE, map.get(Integer.toString(1))); - Assert.assertEquals(numEntries, map.size()); - } - - @Test - public void testIterator() { - OpenHashTable<String, Integer> map = new OpenHashTable<String, Integer>(1000); - IMapIterator<String, Integer> itor = map.entries(); - Assert.assertFalse(itor.hasNext()); - - final int numEntries = 1000000; - for (int i = 0; i < numEntries; i++) { - map.put(Integer.toString(i), i); - } - - itor = map.entries(); - Assert.assertTrue(itor.hasNext()); - while (itor.hasNext()) { - Assert.assertFalse(itor.next() == -1); - String k = itor.getKey(); - Integer v = itor.getValue(); - Assert.assertEquals(Integer.valueOf(k), v); - } - Assert.assertEquals(-1, itor.next()); - } - - @Test - public void testIteratorGetProbe() { - OpenHashTable<String, MutableInt> map = new OpenHashTable<String, MutableInt>(100); - IMapIterator<String, MutableInt> itor = map.entries(); - Assert.assertFalse(itor.hasNext()); - - final int numEntries = 1000000; - for (int i = 0; i < numEntries; i++) { - map.put(Integer.toString(i), new MutableInt(i)); - } - - final MutableInt probe = new MutableInt(); - itor = map.entries(); - Assert.assertTrue(itor.hasNext()); - while (itor.hasNext()) { - Assert.assertFalse(itor.next() == -1); - String k = 
itor.getKey(); - itor.getValue(probe); - Assert.assertEquals(Integer.valueOf(k).intValue(), probe.intValue()); - } - Assert.assertEquals(-1, itor.next()); - } - - @Test - public void testSerDe() throws IOException, ClassNotFoundException { - OpenHashTable<Object, Object> map = new OpenHashTable<Object, Object>(16384); - final int numEntries = 100000; - for (int i = 0; i < numEntries; i++) { - map.put(Integer.toString(i), i); - } - - byte[] serialized = ObjectUtils.toBytes(map); - map = new OpenHashTable<Object, Object>(); - ObjectUtils.readObject(serialized, map); - - Assert.assertEquals(numEntries, map.size()); - for (int i = 0; i < numEntries; i++) { - Object v = map.get(Integer.toString(i)); - Assert.assertEquals(i, v); - } - map.put(Integer.toString(1), Integer.MAX_VALUE); - Assert.assertEquals(Integer.MAX_VALUE, map.get(Integer.toString(1))); - Assert.assertEquals(numEntries, map.size()); - } - - - @Test - public void testCompressedSerDe() throws IOException, ClassNotFoundException { - OpenHashTable<Object, Object> map = new OpenHashTable<Object, Object>(16384); - final int numEntries = 100000; - for (int i = 0; i < numEntries; i++) { - map.put(Integer.toString(i), i); - } - - byte[] serialized = ObjectUtils.toCompressedBytes(map); - map = new OpenHashTable<Object, Object>(); - ObjectUtils.readCompressedObject(serialized, map); - - Assert.assertEquals(numEntries, map.size()); - for (int i = 0; i < numEntries; i++) { - Object v = map.get(Integer.toString(i)); - Assert.assertEquals(i, v); - } - map.put(Integer.toString(1), Integer.MAX_VALUE); - Assert.assertEquals(Integer.MAX_VALUE, map.get(Integer.toString(1))); - Assert.assertEquals(numEntries, map.size()); - } - -} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/test/java/hivemall/utils/collections/SparseIntArrayTest.java ---------------------------------------------------------------------- diff --git a/core/src/test/java/hivemall/utils/collections/SparseIntArrayTest.java 
b/core/src/test/java/hivemall/utils/collections/SparseIntArrayTest.java deleted file mode 100644 index 68d0f6d..0000000 --- a/core/src/test/java/hivemall/utils/collections/SparseIntArrayTest.java +++ /dev/null @@ -1,61 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. 
- */ -package hivemall.utils.collections; - -import java.util.Random; - -import org.junit.Assert; -import org.junit.Test; - -public class SparseIntArrayTest { - - @Test - public void testDense() { - int size = 1000; - Random rand = new Random(31); - int[] expected = new int[size]; - IntArray actual = new SparseIntArray(10); - for (int i = 0; i < size; i++) { - int r = rand.nextInt(size); - expected[i] = r; - actual.put(i, r); - } - for (int i = 0; i < size; i++) { - Assert.assertEquals(expected[i], actual.get(i)); - } - } - - @Test - public void testSparse() { - int size = 1000; - Random rand = new Random(31); - int[] expected = new int[size]; - SparseIntArray actual = new SparseIntArray(10); - for (int i = 0; i < size; i++) { - int key = rand.nextInt(size); - int v = rand.nextInt(); - expected[key] = v; - actual.put(key, v); - } - for (int i = 0; i < actual.size(); i++) { - int key = actual.keyAt(i); - Assert.assertEquals(expected[key], actual.get(key, 0)); - } - } -} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/test/java/hivemall/utils/collections/arrays/DoubleArray3DTest.java ---------------------------------------------------------------------- diff --git a/core/src/test/java/hivemall/utils/collections/arrays/DoubleArray3DTest.java b/core/src/test/java/hivemall/utils/collections/arrays/DoubleArray3DTest.java new file mode 100644 index 0000000..4fdb43e --- /dev/null +++ b/core/src/test/java/hivemall/utils/collections/arrays/DoubleArray3DTest.java @@ -0,0 +1,149 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package hivemall.utils.collections.arrays; + +import hivemall.utils.collections.arrays.DoubleArray3D; + +import java.util.Random; + +import org.junit.Assert; +import org.junit.Test; + +public class DoubleArray3DTest { + + @Test + public void test() { + final int size_i = 50, size_j = 50, size_k = 5; + + final DoubleArray3D mdarray = new DoubleArray3D(); + mdarray.configure(size_i, size_j, size_k); + + final Random rand = new Random(31L); + final double[][][] data = new double[size_i][size_j][size_j]; + for (int i = 0; i < size_i; i++) { + for (int j = 0; j < size_j; j++) { + for (int k = 0; k < size_k; k++) { + double v = rand.nextDouble(); + data[i][j][k] = v; + mdarray.set(i, j, k, v); + } + } + } + + Assert.assertEquals(size_i * size_j * size_k, mdarray.getSize()); + + for (int i = 0; i < size_i; i++) { + for (int j = 0; j < size_j; j++) { + for (int k = 0; k < size_k; k++) { + Assert.assertEquals(data[i][j][k], mdarray.get(i, j, k), 0.d); + } + } + } + } + + @Test + public void testConfigureExpand() { + int size_i = 50, size_j = 50, size_k = 5; + + final DoubleArray3D mdarray = new DoubleArray3D(); + mdarray.configure(size_i, size_j, size_k); + + final Random rand = new Random(31L); + for (int i = 0; i < size_i; i++) { + for (int j = 0; j < size_j; j++) { + for (int k = 0; k < size_k; k++) { + double v = rand.nextDouble(); + mdarray.set(i, j, k, v); + } + } + } + + size_i = 101; + size_j = 101; + size_k = 11; + mdarray.configure(size_i, size_j, size_k); + Assert.assertEquals(size_i * size_j * size_k, mdarray.getCapacity()); + 
Assert.assertEquals(size_i * size_j * size_k, mdarray.getSize()); + + final double[][][] data = new double[size_i][size_j][size_j]; + for (int i = 0; i < size_i; i++) { + for (int j = 0; j < size_j; j++) { + for (int k = 0; k < size_k; k++) { + double v = rand.nextDouble(); + data[i][j][k] = v; + mdarray.set(i, j, k, v); + } + } + } + + for (int i = 0; i < size_i; i++) { + for (int j = 0; j < size_j; j++) { + for (int k = 0; k < size_k; k++) { + Assert.assertEquals(data[i][j][k], mdarray.get(i, j, k), 0.d); + } + } + } + } + + @Test + public void testConfigureShrink() { + int size_i = 50, size_j = 50, size_k = 5; + + final DoubleArray3D mdarray = new DoubleArray3D(); + mdarray.configure(size_i, size_j, size_k); + + final Random rand = new Random(31L); + for (int i = 0; i < size_i; i++) { + for (int j = 0; j < size_j; j++) { + for (int k = 0; k < size_k; k++) { + double v = rand.nextDouble(); + mdarray.set(i, j, k, v); + } + } + } + + int capacity = mdarray.getCapacity(); + size_i = 49; + size_j = 49; + size_k = 4; + mdarray.configure(size_i, size_j, size_k); + Assert.assertEquals(capacity, mdarray.getCapacity()); + Assert.assertEquals(size_i * size_j * size_k, mdarray.getSize()); + + final double[][][] data = new double[size_i][size_j][size_j]; + for (int i = 0; i < size_i; i++) { + for (int j = 0; j < size_j; j++) { + for (int k = 0; k < size_k; k++) { + double v = rand.nextDouble(); + data[i][j][k] = v; + mdarray.set(i, j, k, v); + } + } + } + + for (int i = 0; i < size_i; i++) { + for (int j = 0; j < size_j; j++) { + for (int k = 0; k < size_k; k++) { + Assert.assertEquals(data[i][j][k], mdarray.get(i, j, k), 0.d); + } + } + } + } + +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/test/java/hivemall/utils/collections/arrays/DoubleArrayTest.java ---------------------------------------------------------------------- diff --git a/core/src/test/java/hivemall/utils/collections/arrays/DoubleArrayTest.java 
b/core/src/test/java/hivemall/utils/collections/arrays/DoubleArrayTest.java new file mode 100644 index 0000000..ab52717 --- /dev/null +++ b/core/src/test/java/hivemall/utils/collections/arrays/DoubleArrayTest.java @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package hivemall.utils.collections.arrays; + +import hivemall.utils.collections.arrays.SparseDoubleArray; + +import org.junit.Assert; +import org.junit.Test; + +/** Unit tests for {@link SparseDoubleArray}: conversion to a primitive array and reuse after clear(). */ public class DoubleArrayTest { + + /** Puts the values 10..19 at indices 0..9 and checks size(), the length of toArray(false), and the length and contents of toArray(true). */ @Test + public void testSparseDoubleArrayToArray() { + /* NOTE(review): the constructor argument 3 is presumably an initial capacity smaller than the 10 entries put below - confirm. */ SparseDoubleArray array = new SparseDoubleArray(3); + for (int i = 0; i < 10; i++) { + array.put(i, 10 + i); + } + Assert.assertEquals(10, array.size()); + Assert.assertEquals(10, array.toArray(false).length); + + double[] copied = array.toArray(true); + Assert.assertEquals(10, copied.length); + for (int i = 0; i < 10; i++) { + Assert.assertEquals(10 + i, copied[i], 0.d); + } + } + + /** Checks that after clear() size() is 0 and get(0) returns 0, and that five values put afterwards are counted by size() and returned by get(). */ @Test + public void testSparseDoubleArrayClear() { + SparseDoubleArray array = new SparseDoubleArray(3); + for (int i = 0; i < 10; i++) { + array.put(i, 10 + i); + } + array.clear(); + Assert.assertEquals(0, array.size()); + /* Index 0 held 10 before clear(); the test expects 0 to come back now. */ Assert.assertEquals(0, array.get(0), 0.d); + for (int i = 0; i < 5; i++) { + array.put(i, 100 + i); + } + Assert.assertEquals(5, array.size()); + for (int i = 0; i < 5; i++) { + Assert.assertEquals(100 + i, array.get(i), 0.d); + } + } + +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/test/java/hivemall/utils/collections/arrays/IntArrayTest.java ---------------------------------------------------------------------- diff --git a/core/src/test/java/hivemall/utils/collections/arrays/IntArrayTest.java b/core/src/test/java/hivemall/utils/collections/arrays/IntArrayTest.java new file mode 100644 index 0000000..0ce3912 --- /dev/null +++ b/core/src/test/java/hivemall/utils/collections/arrays/IntArrayTest.java @@ -0,0 +1,79 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package hivemall.utils.collections.arrays; + +import hivemall.utils.collections.arrays.DenseIntArray; +import hivemall.utils.collections.arrays.SparseIntArray; + +import org.junit.Assert; +import org.junit.Test; + +/** Unit tests for {@link DenseIntArray} and {@link SparseIntArray}: conversion to a primitive array and reuse after clear(). */ public class IntArrayTest { + + /** Puts the values 10..19 at indices 0..9 of a DenseIntArray constructed with argument 11 and checks that size() and both toArray variants report 11 elements, the first ten of which match what was put. */ @Test + public void testFixedIntArrayToArray() { + DenseIntArray array = new DenseIntArray(11); + for (int i = 0; i < 10; i++) { + array.put(i, 10 + i); + } + /* Only ten values were put, yet the test expects the size given to the constructor (11). */ Assert.assertEquals(11, array.size()); + Assert.assertEquals(11, array.toArray(false).length); + + int[] copied = array.toArray(true); + Assert.assertEquals(11, copied.length); + for (int i = 0; i < 10; i++) { + Assert.assertEquals(10 + i, copied[i]); + } + } + + /** Puts the values 10..19 at indices 0..9 of a SparseIntArray constructed with argument 3 and checks size(), the length of toArray(false), and the length and contents of toArray(true). */ @Test + public void testSparseIntArrayToArray() { + SparseIntArray array = new SparseIntArray(3); + for (int i = 0; i < 10; i++) { + array.put(i, 10 + i); + } + Assert.assertEquals(10, array.size()); + Assert.assertEquals(10, array.toArray(false).length); + + int[] copied = array.toArray(true); + Assert.assertEquals(10, copied.length); + for (int i = 0; i < 10; i++) { + Assert.assertEquals(10 + i, copied[i]); + } + } + + /** Checks that after clear() size() is 0 and get(0) returns 0, and that five values put afterwards are counted by size() and returned by get(). */ @Test + public void testSparseIntArrayClear() { + SparseIntArray array = new SparseIntArray(3); + for (int i = 0; i < 10; i++) { + array.put(i, 10 + i); + } + array.clear(); + Assert.assertEquals(0, array.size()); + Assert.assertEquals(0, array.get(0)); + for (int i = 0; i < 5; i++) { + array.put(i, 100 + 
i); + } + Assert.assertEquals(5, array.size()); + for (int i = 0; i < 5; i++) { + Assert.assertEquals(100 + i, array.get(i)); + } + } + +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/test/java/hivemall/utils/collections/arrays/SparseIntArrayTest.java ---------------------------------------------------------------------- diff --git a/core/src/test/java/hivemall/utils/collections/arrays/SparseIntArrayTest.java b/core/src/test/java/hivemall/utils/collections/arrays/SparseIntArrayTest.java new file mode 100644 index 0000000..db3c8eb --- /dev/null +++ b/core/src/test/java/hivemall/utils/collections/arrays/SparseIntArrayTest.java @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package hivemall.utils.collections.arrays; + +import hivemall.utils.collections.arrays.IntArray; +import hivemall.utils.collections.arrays.SparseIntArray; + +import java.util.Random; + +import org.junit.Assert; +import org.junit.Test; + +/** Unit tests for {@link SparseIntArray}, once with every index populated and once with pseudo-randomly chosen keys. */ public class SparseIntArrayTest { + + /** Puts a pseudo-random value (seed 31) at every index 0..999 through the IntArray interface and checks that get(i) returns it. */ @Test + public void testDense() { + int size = 1000; + Random rand = new Random(31); + int[] expected = new int[size]; + IntArray actual = new SparseIntArray(10); + for (int i = 0; i < size; i++) { + int r = rand.nextInt(size); + expected[i] = r; + actual.put(i, r); + } + for (int i = 0; i < size; i++) { + Assert.assertEquals(expected[i], actual.get(i)); + } + } + + /** Puts 1000 pseudo-random values at pseudo-random keys in [0, 1000), then walks the stored keys via keyAt(i) and checks get(key, 0) against the value recorded for that key. */ @Test + public void testSparse() { + int size = 1000; + Random rand = new Random(31); + int[] expected = new int[size]; + SparseIntArray actual = new SparseIntArray(10); + for (int i = 0; i < size; i++) { + int key = rand.nextInt(size); + int v = rand.nextInt(); + /* A key drawn more than once keeps only its latest value here, which is what the assertions below require of the array as well. */ expected[key] = v; + actual.put(key, v); + } + for (int i = 0; i < actual.size(); i++) { + int key = actual.keyAt(i); + Assert.assertEquals(expected[key], actual.get(key, 0)); + } + } +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/test/java/hivemall/utils/collections/lists/LongArrayListTest.java ---------------------------------------------------------------------- diff --git a/core/src/test/java/hivemall/utils/collections/lists/LongArrayListTest.java b/core/src/test/java/hivemall/utils/collections/lists/LongArrayListTest.java new file mode 100644 index 0000000..c40ea7e --- /dev/null +++ b/core/src/test/java/hivemall/utils/collections/lists/LongArrayListTest.java @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package hivemall.utils.collections.lists; + + +import org.junit.Assert; +import org.junit.Test; + +/** Unit test for LongArrayList, which is referenced without an import from this package. */ public class LongArrayListTest { + + /** Checks that remove(index) returns the removed element and leaves the remaining elements in order, and that elements added afterwards appear at the end. */ @Test + public void testRemoveIndex() { + LongArrayList list = new LongArrayList(); + list.add(0).add(1).add(2).add(3); + /* Remove from the middle, then the last index, then index 0, checking the contents after each step. */ Assert.assertEquals(1, list.remove(1)); + Assert.assertEquals(3, list.size()); + Assert.assertArrayEquals(new long[] {0, 2, 3}, list.toArray()); + Assert.assertEquals(3, list.remove(2)); + Assert.assertArrayEquals(new long[] {0, 2}, list.toArray()); + Assert.assertEquals(0, list.remove(0)); + Assert.assertArrayEquals(new long[] {2}, list.toArray()); + list.add(0).add(1); + Assert.assertEquals(3, list.size()); + Assert.assertArrayEquals(new long[] {2, 0, 1}, list.toArray()); + } + +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/test/java/hivemall/utils/collections/maps/Int2FloatOpenHashMapTest.java ---------------------------------------------------------------------- diff --git a/core/src/test/java/hivemall/utils/collections/maps/Int2FloatOpenHashMapTest.java b/core/src/test/java/hivemall/utils/collections/maps/Int2FloatOpenHashMapTest.java new file mode 100644 index 0000000..6a2ff96 --- /dev/null +++ b/core/src/test/java/hivemall/utils/collections/maps/Int2FloatOpenHashMapTest.java @@ -0,0 +1,98 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor 
license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package hivemall.utils.collections.maps; + +import hivemall.utils.collections.maps.Int2FloatOpenHashTable; + +import org.junit.Assert; +import org.junit.Test; + +/** Unit tests for {@link Int2FloatOpenHashTable}; note that the test class is named after HashMap while the class under test is the hash table. */ public class Int2FloatOpenHashMapTest { + + /** Checks that a second put() under the same key replaces the value and leaves size() at 1. */ @Test + public void testSize() { + Int2FloatOpenHashTable map = new Int2FloatOpenHashTable(16384); + map.put(1, 3.f); + Assert.assertEquals(3.f, map.get(1), 0.d); + map.put(1, 5.f); + Assert.assertEquals(5.f, map.get(1), 0.d); + Assert.assertEquals(1, map.size()); + } + + /** Checks that get() for a missing key returns -1.f on a fresh table, and returns the value passed to defaultReturnValue() once that has been called. */ @Test + public void testDefaultReturnValue() { + Int2FloatOpenHashTable map = new Int2FloatOpenHashTable(16384); + Assert.assertEquals(0, map.size()); + Assert.assertEquals(-1.f, map.get(1), 0.d); + float ret = Float.MIN_VALUE; + map.defaultReturnValue(ret); + Assert.assertEquals(ret, map.get(1), 0.d); + } + + /** Inserts 1,000,000 distinct keys into a table constructed with argument 16384, checking that every put() returns -1.f, then checks size() and every stored value. */ @Test + public void testPutAndGet() { + Int2FloatOpenHashTable map = new Int2FloatOpenHashTable(16384); + final int numEntries = 1000000; + for (int i = 0; i < numEntries; i++) { + /* Each key is new, so put() is expected to return -1.f, the same value get() returns for a missing key in testDefaultReturnValue. */ Assert.assertEquals(-1.f, map.put(i, Float.valueOf(i + 0.1f)), 0.d); + } + Assert.assertEquals(numEntries, map.size()); + for (int i = 0; i < numEntries; i++) { + Float v = map.get(i); + Assert.assertEquals(i + 0.1f, v.floatValue(), 0.d); + } + } + + /** Checks that entries() on an empty table has no next element, and that after 1,000,000 puts the iterator yields matching key/value pairs and next() returns -1 once hasNext() is false. */ @Test + public void testIterator() { + 
Int2FloatOpenHashTable map = new Int2FloatOpenHashTable(1000); + Int2FloatOpenHashTable.IMapIterator itor = map.entries(); + Assert.assertFalse(itor.hasNext()); + + final int numEntries = 1000000; + for (int i = 0; i < numEntries; i++) { + Assert.assertEquals(-1.f, map.put(i, Float.valueOf(i + 0.1f)), 0.d); + } + Assert.assertEquals(numEntries, map.size()); + + /* Obtain a fresh iterator; the one created above was taken while the table was empty. */ itor = map.entries(); + Assert.assertTrue(itor.hasNext()); + while (itor.hasNext()) { + Assert.assertFalse(itor.next() == -1); + int k = itor.getKey(); + Float v = itor.getValue(); + /* Every value was stored as key + 0.1f, so each pair can be checked without knowing the iteration order. */ Assert.assertEquals(k + 0.1f, v.floatValue(), 0.d); + } + Assert.assertEquals(-1, itor.next()); + } + + /** Checks iteration over a single entry: the first next() is not -1, getKey() and getValue() return that entry, and a further next() returns -1. */ @Test + public void testIterator2() { + Int2FloatOpenHashTable map = new Int2FloatOpenHashTable(100); + map.put(33, 3.16f); + + Int2FloatOpenHashTable.IMapIterator itor = map.entries(); + Assert.assertTrue(itor.hasNext()); + Assert.assertNotEquals(-1, itor.next()); + Assert.assertEquals(33, itor.getKey()); + Assert.assertEquals(3.16f, itor.getValue(), 0.d); + Assert.assertEquals(-1, itor.next()); + } + +}