http://git-wip-us.apache.org/repos/asf/ignite/blob/db7697b1/modules/ml/src/main/java/org/apache/ignite/ml/util/MnistUtils.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/util/MnistUtils.java b/modules/ml/src/main/java/org/apache/ignite/ml/util/MnistUtils.java
new file mode 100644
index 0000000..03e3198
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/util/MnistUtils.java
@@ -0,0 +1,137 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.util;
+
+import java.io.BufferedInputStream;
+import java.io.DataInputStream;
+import java.io.FileInputStream;
+import java.io.FileWriter;
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Random;
+import java.util.stream.Stream;
+import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector;
+
+/**
+ * Utility class for reading MNIST dataset.
+ */
+public class MnistUtils {
+    /**
+     * Read random {@code count} samples from MNIST dataset from two files (images and labels) into a stream of labeled vectors.
+     * Each vector holds the pixels of an image (pixel byte {@code b} is mapped to {@code (128 - b) / 128}) followed
+     * by the label of the image as the last coordinate.
+     *
+     * @param imagesPath Path to the file with images.
+     * @param labelsPath Path to the file with labels.
+     * @param rnd Random numbers generator.
+     * @param count Count of samples to read.
+     * @return Stream of MNIST samples.
+     * @throws IOException If one of the files can not be read or ends prematurely.
+     * @throws IllegalArgumentException If {@code count} is negative or greater than the number of images.
+     */
+    public static Stream<DenseLocalOnHeapVector> mnist(String imagesPath, String labelsPath, Random rnd, int count) throws IOException {
+        double[][] vecs;
+
+        // Both streams are closed even if reading fails half way.
+        try (DataInputStream isImages = open(imagesPath);
+             DataInputStream isLabels = open(labelsPath)) {
+            int magic = isImages.readInt();
+            int numOfImages = isImages.readInt();
+            int imgHeight = isImages.readInt();
+            int imgWidth = isImages.readInt();
+
+            isLabels.readInt(); // Skip magic number.
+            isLabels.readInt(); // Skip number of labels.
+
+            int numOfPixels = imgHeight * imgWidth;
+
+            System.out.println("Magic: " + magic);
+            System.out.println("Num of images: " + numOfImages);
+            System.out.println("Num of pixels: " + numOfPixels);
+
+            // Fail fast instead of reading the whole dataset and failing in subList() afterwards.
+            if (count < 0 || count > numOfImages)
+                throw new IllegalArgumentException("Count of samples should be in range [0, " + numOfImages + "]: " + count);
+
+            // The extra last coordinate of every vector holds the label.
+            vecs = new double[numOfImages][numOfPixels + 1];
+
+            for (int imgNum = 0; imgNum < numOfImages; imgNum++) {
+                vecs[imgNum][numOfPixels] = isLabels.readUnsignedByte();
+
+                for (int p = 0; p < numOfPixels; p++) {
+                    int c = 128 - isImages.readUnsignedByte();
+                    vecs[imgNum][p] = (double)c / 128;
+                }
+            }
+        }
+
+        List<double[]> lst = Arrays.asList(vecs);
+        Collections.shuffle(lst, rnd);
+
+        return lst.subList(0, count).stream().map(DenseLocalOnHeapVector::new);
+    }
+
+    /**
+     * Convert random {@code count} samples from MNIST dataset from two files (images and labels) into libsvm format.
+     *
+     * @param imagesPath Path to the file with images.
+     * @param labelsPath Path to the file with labels.
+     * @param outPath Path to the output file.
+     * @param rnd Random numbers generator.
+     * @param count Count of samples to read.
+     * @throws IOException If input files can not be read or the output file can not be written.
+     */
+    public static void asLIBSVM(String imagesPath, String labelsPath, String outPath, Random rnd, int count) throws IOException {
+        try (FileWriter fos = new FileWriter(outPath)) {
+            // Iterate explicitly (not via forEach) so that write errors propagate to the caller instead of being swallowed.
+            Iterator<DenseLocalOnHeapVector> it = mnist(imagesPath, labelsPath, rnd, count).iterator();
+
+            while (it.hasNext()) {
+                DenseLocalOnHeapVector vec = it.next();
+
+                // Label is stored in the last coordinate.
+                fos.write((int)vec.get(vec.size() - 1) + " ");
+
+                // Sparse output with 1-based feature indexes: zero features are omitted.
+                for (int i = 0; i < vec.size() - 1; i++) {
+                    double val = vec.get(i);
+
+                    if (val != 0)
+                        fos.write((i + 1) + ":" + val + " ");
+                }
+
+                fos.write("\n");
+            }
+        }
+    }
+
+    /**
+     * Open a buffered stream for reading binary data from a file.
+     *
+     * @param path Path to the file.
+     * @return Input stream; its {@code readInt()} decodes 4 bytes most significant byte first, as MNIST files require.
+     * @throws IOException If the file can not be opened.
+     */
+    private static DataInputStream open(String path) throws IOException {
+        return new DataInputStream(new BufferedInputStream(new FileInputStream(path)));
+    }
+}
http://git-wip-us.apache.org/repos/asf/ignite/blob/db7697b1/modules/ml/src/main/java/org/apache/ignite/ml/util/Utils.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/util/Utils.java b/modules/ml/src/main/java/org/apache/ignite/ml/util/Utils.java
new file mode 100644
index 0000000..b7669be
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/util/Utils.java
@@ -0,0 +1,57 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.util;
+
+import java.io.ByteArrayInputStream;
+import java.io.ByteArrayOutputStream;
+import java.io.IOException;
+import java.io.ObjectInputStream;
+import java.io.ObjectOutputStream;
+
+/**
+ * Class with various utility methods.
+ */
+public class Utils {
+    /**
+     * Perform deep copy of an object by serializing it and deserializing the result.
+     *
+     * @param orig Original object; it and everything reachable from it must be serializable.
+     * @param <T> Class of original object.
+     * @return Deep copy of original object.
+     * @throws IllegalStateException If the object can not be serialized or deserialized.
+     */
+    @SuppressWarnings("unchecked")
+    public static <T> T copy(T orig) {
+        ByteArrayOutputStream baos = new ByteArrayOutputStream();
+
+        // Closing the object stream flushes all buffered data into baos.
+        try (ObjectOutputStream out = new ObjectOutputStream(baos)) {
+            out.writeObject(orig);
+        }
+        catch (IOException e) {
+            throw new IllegalStateException("Failed to serialize object for deep copy.", e);
+        }
+
+        try (ObjectInputStream in = new ObjectInputStream(new ByteArrayInputStream(baos.toByteArray()))) {
+            return (T)in.readObject();
+        }
+        catch (IOException | ClassNotFoundException e) {
+            throw new IllegalStateException("Failed to deserialize object for deep copy.", e);
+        }
+    }
+}
http://git-wip-us.apache.org/repos/asf/ignite/blob/db7697b1/modules/ml/src/test/java/org/apache/ignite/ml/IgniteMLTestSuite.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/IgniteMLTestSuite.java b/modules/ml/src/test/java/org/apache/ignite/ml/IgniteMLTestSuite.java
index 5ac7443..47910c8 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/IgniteMLTestSuite.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/IgniteMLTestSuite.java
@@ -20,6 +20,7 @@ package org.apache.ignite.ml;
 import org.apache.ignite.ml.clustering.ClusteringTestSuite;
 import org.apache.ignite.ml.math.MathImplMainTestSuite;
 import org.apache.ignite.ml.regressions.RegressionsTestSuite;
+import org.apache.ignite.ml.trees.DecisionTreesTestSuite;
 import org.junit.runner.RunWith;
 import org.junit.runners.Suite;
 
@@ -30,7 +31,8 @@ import org.junit.runners.Suite;
 @Suite.SuiteClasses({
     MathImplMainTestSuite.class,
     RegressionsTestSuite.class,
-    ClusteringTestSuite.class
+    ClusteringTestSuite.class,
+    DecisionTreesTestSuite.class
 })
 public class IgniteMLTestSuite {
     // No-op.
http://git-wip-us.apache.org/repos/asf/ignite/blob/db7697b1/modules/ml/src/test/java/org/apache/ignite/ml/TestUtils.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/TestUtils.java b/modules/ml/src/test/java/org/apache/ignite/ml/TestUtils.java
index 62fdf2c..d094813 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/TestUtils.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/TestUtils.java
@@ -23,6 +23,8 @@ import org.apache.ignite.ml.math.Precision;
 import org.apache.ignite.ml.math.Vector;
 import org.junit.Assert;
 
+import static org.junit.Assert.assertTrue;
+
 /** */
 public class TestUtils {
     /**
@@ -245,4 +247,34 @@ public class TestUtils {
     public static double maximumAbsoluteRowSum(Matrix mtx) {
         return IntStream.range(0, mtx.rowSize()).mapToObj(mtx::viewRow).map(v -> Math.abs(v.sum())).reduce(Math::max).get();
     }
+
+    /**
+     * Check that vectors with equal indexes in two arrays are closer to each other than {@code epsilon}, the
+     * distance being measured as {@code kNorm(2)} of their difference.
+     *
+     * @param v1s First array of vectors.
+     * @param v2s Second array of vectors; should have the same length as {@code v1s}.
+     * @param epsilon Exclusive upper bound for the distance.
+     */
+    public static void checkIsInEpsilonNeighbourhood(Vector[] v1s, Vector[] v2s, double epsilon) {
+        // Otherwise extra vectors in v2s would be silently ignored and missing ones would cause an obscure AIOOBE.
+        Assert.assertEquals("Arrays of vectors have different lengths", v1s.length, v2s.length);
+
+        for (int i = 0; i < v1s.length; i++) {
+            assertTrue("Not in epsilon neighbourhood (index " + i + ") ",
+                v1s[i].minus(v2s[i]).kNorm(2) < epsilon);
+        }
+    }
+
+    /**
+     * Check that two vectors are closer to each other than {@code epsilon}, the distance being measured as
+     * {@code kNorm(2)} of their difference.
+     *
+     * @param v1 First vector.
+     * @param v2 Second vector.
+     * @param epsilon Exclusive upper bound for the distance.
+     */
+    public static void checkIsInEpsilonNeighbourhood(Vector v1, Vector v2, double epsilon) {
+        checkIsInEpsilonNeighbourhood(new Vector[] {v1}, new Vector[] {v2}, epsilon);
+    }
 }
http://git-wip-us.apache.org/repos/asf/ignite/blob/db7697b1/modules/ml/src/test/java/org/apache/ignite/ml/math/impls/matrix/SparseDistributedBlockMatrixTest.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/math/impls/matrix/SparseDistributedBlockMatrixTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/math/impls/matrix/SparseDistributedBlockMatrixTest.java
index 2943bc0..fd6ed78 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/math/impls/matrix/SparseDistributedBlockMatrixTest.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/math/impls/matrix/SparseDistributedBlockMatrixTest.java
@@ -24,6 +24,7 @@ import java.io.ObjectInputStream;
 import java.io.ObjectOutputStream;
 import java.util.Collection;
 import java.util.Set;
+import java.util.UUID;
 import org.apache.ignite.Ignite;
 import org.apache.ignite.IgniteCache;
 import org.apache.ignite.internal.util.IgniteUtils;
http://git-wip-us.apache.org/repos/asf/ignite/blob/db7697b1/modules/ml/src/test/java/org/apache/ignite/ml/trees/BaseDecisionTreeTest.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/trees/BaseDecisionTreeTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/trees/BaseDecisionTreeTest.java
new file mode 100644
index 0000000..65f0ae4
--- /dev/null
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/trees/BaseDecisionTreeTest.java
@@ -0,0 +1,70 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.trees;
+
+import java.util.Arrays;
+import org.apache.ignite.Ignite;
+import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector;
+import org.apache.ignite.ml.structures.LabeledVectorDouble;
+import org.apache.ignite.testframework.junits.common.GridCommonAbstractTest;
+
+/**
+ * Base class for decision trees test.
+ */
+public class BaseDecisionTreeTest extends GridCommonAbstractTest {
+    /** Count of nodes. */
+    private static final int NODE_COUNT = 4;
+
+    /** Grid instance. */
+    protected Ignite ignite;
+
+    /** Default constructor. */
+    public BaseDecisionTreeTest() {
+        super(false);
+    }
+
+    /** {@inheritDoc} */
+    @Override protected void beforeTest() throws Exception {
+        ignite = grid(NODE_COUNT);
+    }
+
+    /** {@inheritDoc} */
+    @Override protected void beforeTestsStarted() throws Exception {
+        for (int i = 1; i <= NODE_COUNT; i++)
+            startGrid(i);
+    }
+
+    /** {@inheritDoc} */
+    @Override protected void afterTestsStopped() throws Exception {
+        stopAllGrids();
+    }
+
+    /**
+     * Convert double array to {@link LabeledVectorDouble}: all elements except the last one form the vector and
+     * the last element is the label.
+     *
+     * @param arr Array for conversion; should contain at least one element (the label).
+     * @return LabeledVectorDouble.
+     */
+    protected static LabeledVectorDouble<DenseLocalOnHeapVector> asLabeledVector(double[] arr) {
+        if (arr.length == 0)
+            throw new IllegalArgumentException("Array should contain at least one element (the label).");
+
+        return new LabeledVectorDouble<>(new DenseLocalOnHeapVector(Arrays.copyOf(arr, arr.length - 1)), arr[arr.length - 1]);
+    }
+}
http://git-wip-us.apache.org/repos/asf/ignite/blob/db7697b1/modules/ml/src/test/java/org/apache/ignite/ml/trees/ColumnDecisionTreeTrainerTest.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/trees/ColumnDecisionTreeTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/trees/ColumnDecisionTreeTrainerTest.java
new file mode 100644
index 0000000..2b03b47
--- /dev/null
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/trees/ColumnDecisionTreeTrainerTest.java
@@ -0,0 +1,190 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.trees;
+
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+import java.util.stream.Collectors;
+import java.util.stream.DoubleStream;
+import org.apache.ignite.internal.util.IgniteUtils;
+import org.apache.ignite.lang.IgniteBiTuple;
+import org.apache.ignite.ml.math.StorageConstants;
+import org.apache.ignite.ml.math.Tracer;
+import org.apache.ignite.ml.math.functions.IgniteFunction;
+import org.apache.ignite.ml.math.impls.matrix.SparseDistributedMatrix;
+import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector;
+import org.apache.ignite.ml.structures.LabeledVectorDouble;
+import org.apache.ignite.ml.trees.models.DecisionTreeModel;
+import org.apache.ignite.ml.trees.trainers.columnbased.ColumnDecisionTreeTrainer;
+import org.apache.ignite.ml.trees.trainers.columnbased.ColumnDecisionTreeTrainerInput;
+import org.apache.ignite.ml.trees.trainers.columnbased.MatrixColumnDecisionTreeTrainerInput;
+import org.apache.ignite.ml.trees.trainers.columnbased.contsplitcalcs.ContinuousSplitCalculators;
+import org.apache.ignite.ml.trees.trainers.columnbased.regcalcs.RegionCalculators;
+
+/** Tests behaviour of ColumnDecisionTreeTrainer. */
+public class ColumnDecisionTreeTrainerTest extends BaseDecisionTreeTest {
+    /**
+     * Test {@link ColumnDecisionTreeTrainer} for mixed (continuous and categorical) data with Gini impurity.
+     */
+    public void testCacheMixedGini() {
+        IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName());
+        int totalPts = 1 << 10;
+        int featCnt = 2;
+
+        HashMap<Integer, Integer> catsInfo = new HashMap<>();
+        catsInfo.put(1, 3);
+
+        Random rnd = new Random(12349L);
+
+        SplitDataGenerator<DenseLocalOnHeapVector> gen = new SplitDataGenerator<>(
+            featCnt, catsInfo, () -> new DenseLocalOnHeapVector(featCnt + 1), rnd).
+            split(0, 1, new int[] {0, 2}).
+            split(1, 0, -10.0);
+
+        testByGen(totalPts, catsInfo, gen, ContinuousSplitCalculators.GINI.apply(ignite), RegionCalculators.GINI, RegionCalculators.MEAN, rnd);
+    }
+
+    /**
+     * Test {@link ColumnDecisionTreeTrainer} for mixed (continuous and categorical) data with Variance impurity.
+     */
+    public void testCacheMixed() {
+        IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName());
+        int totalPts = 1 << 10;
+        int featCnt = 2;
+
+        HashMap<Integer, Integer> catsInfo = new HashMap<>();
+        catsInfo.put(1, 3);
+
+        Random rnd = new Random(12349L);
+
+        SplitDataGenerator<DenseLocalOnHeapVector> gen = new SplitDataGenerator<>(
+            featCnt, catsInfo, () -> new DenseLocalOnHeapVector(featCnt + 1), rnd).
+            split(0, 1, new int[] {0, 2}).
+            split(1, 0, -10.0);
+
+        testByGen(totalPts, catsInfo, gen, ContinuousSplitCalculators.VARIANCE, RegionCalculators.VARIANCE, RegionCalculators.MEAN, rnd);
+    }
+
+    /**
+     * Test {@link ColumnDecisionTreeTrainer} for continuous data with Variance impurity.
+     */
+    public void testCacheCont() {
+        IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName());
+        int totalPts = 1 << 10;
+        int featCnt = 12;
+
+        HashMap<Integer, Integer> catsInfo = new HashMap<>();
+
+        Random rnd = new Random(12349L);
+
+        SplitDataGenerator<DenseLocalOnHeapVector> gen = new SplitDataGenerator<>(
+            featCnt, catsInfo, () -> new DenseLocalOnHeapVector(featCnt + 1), rnd).
+            split(0, 0, -10.0).
+            split(1, 0, 0.0).
+            split(1, 1, 2.0).
+            split(3, 7, 50.0);
+
+        testByGen(totalPts, catsInfo, gen, ContinuousSplitCalculators.VARIANCE, RegionCalculators.VARIANCE, RegionCalculators.MEAN, rnd);
+    }
+
+    /**
+     * Test {@link ColumnDecisionTreeTrainer} for continuous data with Gini impurity.
+     */
+    public void testCacheContGini() {
+        IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName());
+        int totalPts = 1 << 10;
+        int featCnt = 12;
+
+        HashMap<Integer, Integer> catsInfo = new HashMap<>();
+
+        Random rnd = new Random(12349L);
+
+        SplitDataGenerator<DenseLocalOnHeapVector> gen = new SplitDataGenerator<>(
+            featCnt, catsInfo, () -> new DenseLocalOnHeapVector(featCnt + 1), rnd).
+            split(0, 0, -10.0).
+            split(1, 0, 0.0).
+            split(1, 1, 2.0).
+            split(3, 7, 50.0);
+
+        testByGen(totalPts, catsInfo, gen, ContinuousSplitCalculators.GINI.apply(ignite), RegionCalculators.GINI, RegionCalculators.MEAN, rnd);
+    }
+
+    /**
+     * Test {@link ColumnDecisionTreeTrainer} for categorical data with Variance impurity.
+     */
+    public void testCacheCat() {
+        IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName());
+        int totalPts = 1 << 10;
+        int featCnt = 12;
+
+        HashMap<Integer, Integer> catsInfo = new HashMap<>();
+        catsInfo.put(5, 7);
+
+        Random rnd = new Random(12349L);
+
+        SplitDataGenerator<DenseLocalOnHeapVector> gen = new SplitDataGenerator<>(
+            featCnt, catsInfo, () -> new DenseLocalOnHeapVector(featCnt + 1), rnd).
+            split(0, 5, new int[] {0, 2, 5});
+
+        testByGen(totalPts, catsInfo, gen, ContinuousSplitCalculators.VARIANCE, RegionCalculators.VARIANCE, RegionCalculators.MEAN, rnd);
+    }
+
+    /**
+     * Generate points, train a {@link ColumnDecisionTreeTrainer} on them and check that for one point of every
+     * region the trained model predicts the label of that point.
+     *
+     * @param totalPts Count of points to generate.
+     * @param catsInfo Information about categorical features in form of map (feature index -> categories count).
+     * @param gen Generator of points.
+     * @param calc Function producing continuous split calculator, passed to the trainer.
+     * @param catImpCalc Function producing categorical impurity calculator, passed to the trainer.
+     * @param regCalc Region value calculator, passed to the trainer.
+     * @param rnd Random numbers generator used to shuffle points.
+     * @param <D> Type of continuous region information.
+     */
+    private <D extends ContinuousRegionInfo> void testByGen(int totalPts, HashMap<Integer, Integer> catsInfo,
+        SplitDataGenerator<DenseLocalOnHeapVector> gen,
+        IgniteFunction<ColumnDecisionTreeTrainerInput, ? extends ContinuousSplitCalculator<D>> calc,
+        IgniteFunction<ColumnDecisionTreeTrainerInput, IgniteFunction<DoubleStream, Double>> catImpCalc,
+        IgniteFunction<DoubleStream, Double> regCalc, Random rnd) {
+
+        List<IgniteBiTuple<Integer, DenseLocalOnHeapVector>> lst = gen.
+            points(totalPts, (i, rn) -> i).
+            collect(Collectors.toList());
+
+        int featCnt = gen.featuresCnt();
+
+        Collections.shuffle(lst, rnd);
+
+        SparseDistributedMatrix m = new SparseDistributedMatrix(totalPts, featCnt + 1, StorageConstants.COLUMN_STORAGE_MODE, StorageConstants.RANDOM_ACCESS_MODE);
+
+        // Points grouped by region; the key is the first element of the generated tuple.
+        Map<Integer, List<LabeledVectorDouble<DenseLocalOnHeapVector>>> byRegion = new HashMap<>();
+
+        int i = 0;
+        for (IgniteBiTuple<Integer, DenseLocalOnHeapVector> bt : lst) {
+            byRegion.computeIfAbsent(bt.get1(), k -> new LinkedList<>()).add(asLabeledVector(bt.get2().getStorage().data()));
+            m.setRow(i, bt.get2().getStorage().data());
+            i++;
+        }
+
+        ColumnDecisionTreeTrainer<D> trainer =
+            new ColumnDecisionTreeTrainer<>(3, calc, catImpCalc, regCalc, ignite);
+
+        DecisionTreeModel mdl = trainer.train(new MatrixColumnDecisionTreeTrainerInput(m, catsInfo));
+
+        for (List<LabeledVectorDouble<DenseLocalOnHeapVector>> regPts : byRegion.values()) {
+            LabeledVectorDouble<DenseLocalOnHeapVector> sp = regPts.get(0);
+            double pred = mdl.predict(sp.vector());
+
+            Tracer.showAscii(sp.vector());
+            System.out.println("Act: " + sp.label() + " " + " pred: " + pred);
+
+            // JUnit assertion is used because the 'assert' statement is a no-op unless the JVM runs with -ea.
+            assertEquals("Wrong prediction for a point of the region.", sp.doubleLabel(), pred, 0.0);
+        }
+    }
+}
http://git-wip-us.apache.org/repos/asf/ignite/blob/db7697b1/modules/ml/src/test/java/org/apache/ignite/ml/trees/DecisionTreesTestSuite.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/trees/DecisionTreesTestSuite.java b/modules/ml/src/test/java/org/apache/ignite/ml/trees/DecisionTreesTestSuite.java
new file mode 100644
index 0000000..3343503
--- /dev/null
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/trees/DecisionTreesTestSuite.java
@@ -0,0 +1,33 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.trees;
+
+import org.junit.runner.RunWith;
+import org.junit.runners.Suite;
+
+/**
+ * Test suite for all tests located in org.apache.ignite.ml.trees package.
+ */
+@RunWith(Suite.class)
+@Suite.SuiteClasses({
+    ColumnDecisionTreeTrainerTest.class,
+    GiniSplitCalculatorTest.class,
+    VarianceSplitCalculatorTest.class
+})
+public class DecisionTreesTestSuite {
+}
http://git-wip-us.apache.org/repos/asf/ignite/blob/db7697b1/modules/ml/src/test/java/org/apache/ignite/ml/trees/GiniSplitCalculatorTest.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/trees/GiniSplitCalculatorTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/trees/GiniSplitCalculatorTest.java
new file mode 100644
index 0000000..c92b4f5
--- /dev/null
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/trees/GiniSplitCalculatorTest.java
@@ -0,0 +1,141 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.
You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.trees;
+
+import java.util.stream.DoubleStream;
+import org.apache.ignite.ml.trees.trainers.columnbased.contsplitcalcs.GiniSplitCalculator;
+import org.apache.ignite.ml.trees.trainers.columnbased.vectors.SplitInfo;
+import org.junit.Test;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNull;
+
+/**
+ * Test of {@link GiniSplitCalculator}.
+ *
+ * JUnit assertions are used instead of the {@code assert} statement, which is a no-op unless the JVM runs with
+ * {@code -ea}.
+ */
+public class GiniSplitCalculatorTest {
+    /** Precision of comparison of impurities. */
+    private static final double EPS = 1E-5;
+
+    /** Test calculation of region info consisting of one point. */
+    @Test
+    public void testCalculateRegionInfoSimple() {
+        double[] labels = {0.0};
+
+        assertEquals(0.0, new GiniSplitCalculator(labels).calculateRegionInfo(DoubleStream.of(labels), 0).impurity(), EPS);
+    }
+
+    /** Test calculation of region info consisting of two distinct classes. */
+    @Test
+    public void testCalculateRegionInfoTwoClasses() {
+        double[] labels = {0.0, 1.0};
+
+        assertEquals(0.5, new GiniSplitCalculator(labels).calculateRegionInfo(DoubleStream.of(labels), 0).impurity(), EPS);
+    }
+
+    /** Test calculation of region info consisting of three distinct classes. */
+    @Test
+    public void testCalculateRegionInfoThreeClasses() {
+        double[] labels = {0.0, 1.0, 2.0};
+
+        assertEquals(2.0 / 3, new GiniSplitCalculator(labels).calculateRegionInfo(DoubleStream.of(labels), 0).impurity(), EPS);
+    }
+
+    /** Test calculation of split of region consisting of one point. */
+    @Test
+    public void testSplitSimple() {
+        double[] labels = {0.0};
+        double[] values = {0.0};
+        Integer[] samples = {0};
+
+        int[] cnts = {1};
+
+        GiniSplitCalculator.GiniData data = new GiniSplitCalculator.GiniData(0.0, 1, cnts, 1);
+
+        // Region consisting of one point can not be split.
+        assertNull(new GiniSplitCalculator(labels).splitRegion(samples, values, labels, 0, data));
+    }
+
+    /** Test calculation of split of region consisting of two points. */
+    @Test
+    public void testSplitTwoClassesTwoPoints() {
+        double[] labels = {0.0, 1.0};
+        double[] values = {0.0, 1.0};
+        Integer[] samples = {0, 1};
+
+        int[] cnts = {1, 1};
+
+        GiniSplitCalculator.GiniData data = new GiniSplitCalculator.GiniData(0.5, 2, cnts, 1.0 * 1.0 + 1.0 * 1.0);
+
+        SplitInfo<GiniSplitCalculator.GiniData> split = new GiniSplitCalculator(labels).splitRegion(samples, values, labels, 0, data);
+
+        assertEquals(0.0, split.leftData().impurity(), EPS);
+        assertEquals(1, split.leftData().counts()[0]);
+        assertEquals(0, split.leftData().counts()[1]);
+        assertEquals(1, split.leftData().getSize());
+
+        assertEquals(0.0, split.rightData().impurity(), EPS);
+        assertEquals(0, split.rightData().counts()[0]);
+        assertEquals(1, split.rightData().counts()[1]);
+        assertEquals(1, split.rightData().getSize());
+    }
+
+    /** Test calculation of split of region consisting of four distinct values. */
+    @Test
+    public void testSplitTwoClassesFourPoints() {
+        double[] labels = {0.0, 0.0, 1.0, 1.0};
+        double[] values = {0.0, 1.0, 2.0, 3.0};
+
+        Integer[] samples = {0, 1, 2, 3};
+
+        int[] cnts = {2, 2};
+
+        GiniSplitCalculator.GiniData data = new GiniSplitCalculator.GiniData(0.5, 4, cnts, 2.0 * 2.0 + 2.0 * 2.0);
+
+        SplitInfo<GiniSplitCalculator.GiniData> split = new GiniSplitCalculator(labels).splitRegion(samples, values, labels, 0, data);
+
+        assertEquals(0.0, split.leftData().impurity(), EPS);
+        assertEquals(2, split.leftData().counts()[0]);
+        assertEquals(0, split.leftData().counts()[1]);
+        assertEquals(2, split.leftData().getSize());
+
+        assertEquals(0.0, split.rightData().impurity(), EPS);
+        assertEquals(0, split.rightData().counts()[0]);
+        assertEquals(2, split.rightData().counts()[1]);
+        assertEquals(2, split.rightData().getSize());
+    }
+
+    /** Test calculation of split of region consisting of three distinct values. */
+    @Test
+    public void testSplitThreePoints() {
+        double[] labels = {0.0, 1.0, 2.0};
+        double[] values = {0.0, 1.0, 2.0};
+        Integer[] samples = {0, 1, 2};
+
+        int[] cnts = {1, 1, 1};
+
+        GiniSplitCalculator.GiniData data = new GiniSplitCalculator.GiniData(2.0 / 3, 3, cnts, 1.0 * 1.0 + 1.0 * 1.0 + 1.0 * 1.0);
+
+        SplitInfo<GiniSplitCalculator.GiniData> split = new GiniSplitCalculator(labels).splitRegion(samples, values, labels, 0, data);
+
+        assertEquals(0.0, split.leftData().impurity(), EPS);
+        assertEquals(1, split.leftData().counts()[0]);
+        assertEquals(0, split.leftData().counts()[1]);
+        assertEquals(0, split.leftData().counts()[2]);
+        assertEquals(1, split.leftData().getSize());
+
+        assertEquals(0.5, split.rightData().impurity(), EPS);
+        assertEquals(0, split.rightData().counts()[0]);
+        assertEquals(1, split.rightData().counts()[1]);
+        assertEquals(1, split.rightData().counts()[2]);
+        assertEquals(2, split.rightData().getSize());
+    }
+}
http://git-wip-us.apache.org/repos/asf/ignite/blob/db7697b1/modules/ml/src/test/java/org/apache/ignite/ml/trees/SplitDataGenerator.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/trees/SplitDataGenerator.java b/modules/ml/src/test/java/org/apache/ignite/ml/trees/SplitDataGenerator.java new file mode 100644 index 0000000..279e685 --- /dev/null +++ b/modules/ml/src/test/java/org/apache/ignite/ml/trees/SplitDataGenerator.java @@ -0,0 +1,390 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.ignite.ml.trees; + +import java.io.Serializable; +import java.util.Arrays; +import java.util.BitSet; +import java.util.HashMap; +import java.util.LinkedList; +import java.util.List; +import java.util.Map; +import java.util.Random; +import java.util.function.BiFunction; +import java.util.function.Function; +import java.util.function.Supplier; +import java.util.stream.Collectors; +import java.util.stream.IntStream; +import java.util.stream.Stream; +import org.apache.ignite.lang.IgniteBiTuple; +import org.apache.ignite.ml.math.Vector; +import org.apache.ignite.ml.math.exceptions.MathIllegalArgumentException; +import org.apache.ignite.ml.util.Utils; + +/** + * Utility class for generating data which has binary tree split structure. + * + * @param <V> + */ +public class SplitDataGenerator<V extends Vector> { + /** */ + private static final double DELTA = 100.0; + + /** Map of the form of (is categorical -> list of region indexes). */ + private final Map<Boolean, List<Integer>> di; + + /** List of regions. */ + private final List<Region> regs; + + /** Data of bounds of regions. */ + private final Map<Integer, IgniteBiTuple<Double, Double>> boundsData; + + /** Random numbers generator. */ + private final Random rnd; + + /** Supplier of vectors. */ + private final Supplier<V> supplier; + + /** Features count. */ + private final int featCnt; + + /** + * Create SplitDataGenerator. + * + * @param featCnt Features count. + * @param catFeaturesInfo Information about categorical features in form of map (feature index -> categories + * count). + * @param supplier Supplier of vectors. + * @param rnd Random numbers generator. 
+ */ + public SplitDataGenerator(int featCnt, Map<Integer, Integer> catFeaturesInfo, Supplier<V> supplier, Random rnd) { + regs = new LinkedList<>(); + boundsData = new HashMap<>(); + this.rnd = rnd; + this.supplier = supplier; + this.featCnt = featCnt; + + // Divide indexes into indexes of categorical coordinates and indexes of continuous coordinates. + di = IntStream.range(0, featCnt). + boxed(). + collect(Collectors.partitioningBy(catFeaturesInfo::containsKey)); + + // Categorical coordinates info. + Map<Integer, CatCoordInfo> catCoords = new HashMap<>(); + di.get(true).forEach(i -> { + BitSet bs = new BitSet(); + bs.set(0, catFeaturesInfo.get(i)); + catCoords.put(i, new CatCoordInfo(bs)); + }); + + // Continuous coordinates info. + Map<Integer, ContCoordInfo> contCoords = new HashMap<>(); + di.get(false).forEach(i -> { + contCoords.put(i, new ContCoordInfo()); + boundsData.put(i, new IgniteBiTuple<>(-1.0, 1.0)); + }); + + Region firstReg = new Region(catCoords, contCoords, 0); + regs.add(firstReg); + } + + /** + * Categorical coordinate info. + */ + private static class CatCoordInfo implements Serializable { + /** + * Defines categories which are included in this region + */ + private final BitSet bs; + + /** + * Construct CatCoordInfo. + * + * @param bs Bitset. + */ + CatCoordInfo(BitSet bs) { + this.bs = bs; + } + + /** {@inheritDoc} */ + @Override public String toString() { + return "CatCoordInfo [" + + "bs=" + bs + + ']'; + } + } + + /** + * Continuous coordinate info. + */ + private static class ContCoordInfo implements Serializable { + /** + * Left (min) bound of region. + */ + private double left; + + /** + * Right (max) bound of region. + */ + private double right; + + /** + * Construct ContCoordInfo. 
+ */ + ContCoordInfo() { + left = Double.NEGATIVE_INFINITY; + right = Double.POSITIVE_INFINITY; + } + + /** {@inheritDoc} */ + @Override public String toString() { + return "ContCoordInfo [" + + "left=" + left + + ", right=" + right + + ']'; + } + } + + /** + * Class representing information about region. + */ + private static class Region implements Serializable { + /** + * Information about categorical coordinates restrictions of this region in form of + * (coordinate index -> restriction) + */ + private final Map<Integer, CatCoordInfo> catCoords; + + /** + * Information about continuous coordinates restrictions of this region in form of + * (coordinate index -> restriction) + */ + private final Map<Integer, ContCoordInfo> contCoords; + + /** + * Region should contain {@code 1/2^twoPow * totalPoints} points. + */ + private int twoPow; + + /** + * Construct region by information about restrictions on coordinates (features) values. + * + * @param catCoords Restrictions on categorical coordinates. + * @param contCoords Restrictions on continuous coordinates + * @param twoPow Region should contain {@code 1/2^twoPow * totalPoints} points. + */ + Region(Map<Integer, CatCoordInfo> catCoords, Map<Integer, ContCoordInfo> contCoords, int twoPow) { + this.catCoords = catCoords; + this.contCoords = contCoords; + this.twoPow = twoPow; + } + + /** */ + int divideBy() { + return 1 << twoPow; + } + + /** */ + void incTwoPow() { + twoPow++; + } + + /** {@inheritDoc} */ + @Override public String toString() { + return "Region [" + + "catCoords=" + catCoords + + ", contCoords=" + contCoords + + ", twoPow=" + twoPow + + ']'; + } + + /** + * Generate continuous coordinate for this region. + * + * @param coordIdx Coordinate index. + * @param boundsData Data with bounds + * @param rnd Random numbers generator. + * @return Categorical coordinate value. 
+ */ + double generateContCoord(int coordIdx, Map<Integer, IgniteBiTuple<Double, Double>> boundsData, + Random rnd) { + ContCoordInfo cci = contCoords.get(coordIdx); + double left = cci.left; + double right = cci.right; + + if (left == Double.NEGATIVE_INFINITY) + left = boundsData.get(coordIdx).get1() - DELTA; + + if (right == Double.POSITIVE_INFINITY) + right = boundsData.get(coordIdx).get2() + DELTA; + + double size = right - left; + + return left + rnd.nextDouble() * size; + } + + /** + * Generate categorical coordinate value for this region. + * + * @param coordIdx Coordinate index. + * @param rnd Random numbers generator. + * @return Categorical coordinate value. + */ + double generateCatCoord(int coordIdx, Random rnd) { + // Pick random bit. + BitSet bs = catCoords.get(coordIdx).bs; + int j = rnd.nextInt(bs.length()); + + int i = 0; + int bn = 0; + int bnp = 0; + + while ((bn = bs.nextSetBit(bn)) != -1 && i <= j) { + i++; + bnp = bn; + bn++; + } + + return bnp; + } + + /** + * Generate points for this region. + * + * @param ptsCnt Count of points to generate. + * @param val Label for all points in this region. + * @param boundsData Data about bounds of continuous coordinates. + * @param catCont Data about which categories can be in this region in the form (coordinate index -> list of + * categories indexes). + * @param s Vectors supplier. + * @param rnd Random numbers generator. + * @param <V> Type of vectors. + * @return Stream of generated points for this region. 
+ */ + <V extends Vector> Stream<V> generatePoints(int ptsCnt, double val, + Map<Integer, IgniteBiTuple<Double, Double>> boundsData, Map<Boolean, List<Integer>> catCont, + Supplier<V> s, + Random rnd) { + return IntStream.range(0, ptsCnt / divideBy()).mapToObj(i -> { + V v = s.get(); + int coordsCnt = v.size(); + catCont.get(false).forEach(ci -> v.setX(ci, generateContCoord(ci, boundsData, rnd))); + catCont.get(true).forEach(ci -> v.setX(ci, generateCatCoord(ci, rnd))); + + v.setX(coordsCnt - 1, val); + return v; + }); + } + } + + /** + * Split region by continuous coordinate.using given threshold. + * + * @param regIdx Region index. + * @param coordIdx Coordinate index. + * @param threshold Threshold. + * @return {@code this}. + */ + public SplitDataGenerator<V> split(int regIdx, int coordIdx, double threshold) { + Region regToSplit = regs.get(regIdx); + ContCoordInfo cci = regToSplit.contCoords.get(coordIdx); + + double left = cci.left; + double right = cci.right; + + if (threshold < left || threshold > right) + throw new MathIllegalArgumentException("Threshold is out of region bounds."); + + regToSplit.incTwoPow(); + + Region newReg = Utils.copy(regToSplit); + newReg.contCoords.get(coordIdx).left = threshold; + + regs.add(regIdx + 1, newReg); + cci.right = threshold; + + IgniteBiTuple<Double, Double> bounds = boundsData.get(coordIdx); + double min = bounds.get1(); + double max = bounds.get2(); + boundsData.put(coordIdx, new IgniteBiTuple<>(Math.min(threshold, min), Math.max(max, threshold))); + + return this; + } + + /** + * Split region by categorical coordinate. + * + * @param regIdx Region index. + * @param coordIdx Coordinate index. + * @param cats Categories allowed for the left sub region. + * @return {@code this}. 
+ */ + public SplitDataGenerator<V> split(int regIdx, int coordIdx, int[] cats) { + BitSet subset = new BitSet(); + Arrays.stream(cats).forEach(subset::set); + Region regToSplit = regs.get(regIdx); + CatCoordInfo cci = regToSplit.catCoords.get(coordIdx); + + BitSet ssc = (BitSet)subset.clone(); + BitSet set = cci.bs; + ssc.and(set); + if (ssc.length() != subset.length()) + throw new MathIllegalArgumentException("Splitter set is not a subset of a parent subset."); + + ssc.xor(set); + set.and(subset); + + regToSplit.incTwoPow(); + Region newReg = Utils.copy(regToSplit); + newReg.catCoords.put(coordIdx, new CatCoordInfo(ssc)); + + regs.add(regIdx + 1, newReg); + + return this; + } + + /** + * Get stream of points generated by this generator. + * + * @param ptsCnt Points count. + */ + public Stream<IgniteBiTuple<Integer, V>> points(int ptsCnt, BiFunction<Double, Random, Double> f) { + regs.forEach(System.out::println); + + return IntStream.range(0, regs.size()). + boxed(). + map(i -> regs.get(i).generatePoints(ptsCnt, f.apply((double)i, rnd), boundsData, di, supplier, rnd).map(v -> new IgniteBiTuple<>(i, v))).flatMap(Function.identity()); + } + + /** + * Count of regions. + * + * @return Count of regions. + */ + public int regsCount() { + return regs.size(); + } + + /** + * Get features count. + * + * @return Features count. 
+ */ + public int featuresCnt() { + return featCnt; + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/db7697b1/modules/ml/src/test/java/org/apache/ignite/ml/trees/VarianceSplitCalculatorTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/trees/VarianceSplitCalculatorTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/trees/VarianceSplitCalculatorTest.java new file mode 100644 index 0000000..d67cbc6 --- /dev/null +++ b/modules/ml/src/test/java/org/apache/ignite/ml/trees/VarianceSplitCalculatorTest.java @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.trees; + +import java.util.stream.DoubleStream; +import org.apache.ignite.ml.trees.trainers.columnbased.contsplitcalcs.VarianceSplitCalculator; +import org.apache.ignite.ml.trees.trainers.columnbased.vectors.SplitInfo; +import org.junit.Test; + +/** + * Test for {@link VarianceSplitCalculator}. + */ +public class VarianceSplitCalculatorTest { + /** Test calculation of region info consisting from one point. 
*/ + @Test + public void testCalculateRegionInfoSimple() { + double labels[] = new double[] {0.0}; + + assert new VarianceSplitCalculator().calculateRegionInfo(DoubleStream.of(labels), 1).impurity() == 0.0; + } + + /** Test calculation of region info consisting from two classes. */ + @Test + public void testCalculateRegionInfoTwoClasses() { + double labels[] = new double[] {0.0, 1.0}; + + assert new VarianceSplitCalculator().calculateRegionInfo(DoubleStream.of(labels), 2).impurity() == 0.25; + } + + /** Test calculation of region info consisting from three classes. */ + @Test + public void testCalculateRegionInfoThreeClasses() { + double labels[] = new double[] {1.0, 2.0, 3.0}; + + assert Math.abs(new VarianceSplitCalculator().calculateRegionInfo(DoubleStream.of(labels), 3).impurity() - 2.0 / 3) < 1E-10; + } + + /** Test calculation of split of region consisting from one point. */ + @Test + public void testSplitSimple() { + double labels[] = new double[] {0.0}; + double values[] = new double[] {0.0}; + Integer[] samples = new Integer[] {0}; + + VarianceSplitCalculator.VarianceData data = new VarianceSplitCalculator.VarianceData(0.0, 1, 0.0); + + assert new VarianceSplitCalculator().splitRegion(samples, values, labels, 0, data) == null; + } + + /** Test calculation of split of region consisting from two classes. 
*/ + @Test + public void testSplitTwoClassesTwoPoints() { + double labels[] = new double[] {0.0, 1.0}; + double values[] = new double[] {0.0, 1.0}; + Integer[] samples = new Integer[] {0, 1}; + + VarianceSplitCalculator.VarianceData data = new VarianceSplitCalculator.VarianceData(0.25, 2, 0.5); + + SplitInfo<VarianceSplitCalculator.VarianceData> split = new VarianceSplitCalculator().splitRegion(samples, values, labels, 0, data); + + assert split.leftData().impurity() == 0; + assert split.leftData().mean() == 0; + assert split.leftData().getSize() == 1; + + assert split.rightData().impurity() == 0; + assert split.rightData().mean() == 1; + assert split.rightData().getSize() == 1; + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/db7697b1/modules/ml/src/test/java/org/apache/ignite/ml/trees/performance/ColumnDecisionTreeTrainerBenchmark.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/trees/performance/ColumnDecisionTreeTrainerBenchmark.java b/modules/ml/src/test/java/org/apache/ignite/ml/trees/performance/ColumnDecisionTreeTrainerBenchmark.java new file mode 100644 index 0000000..4e7cc24 --- /dev/null +++ b/modules/ml/src/test/java/org/apache/ignite/ml/trees/performance/ColumnDecisionTreeTrainerBenchmark.java @@ -0,0 +1,455 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.trees.performance; + +import it.unimi.dsi.fastutil.ints.Int2DoubleOpenHashMap; +import java.io.IOException; +import java.io.InputStream; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.Iterator; +import java.util.LinkedList; +import java.util.List; +import java.util.Map; +import java.util.Properties; +import java.util.Random; +import java.util.UUID; +import java.util.function.Function; +import java.util.stream.Collectors; +import java.util.stream.DoubleStream; +import java.util.stream.IntStream; +import java.util.stream.Stream; +import org.apache.ignite.IgniteCache; +import org.apache.ignite.IgniteDataStreamer; +import org.apache.ignite.Ignition; +import org.apache.ignite.cache.CacheAtomicityMode; +import org.apache.ignite.cache.CacheMode; +import org.apache.ignite.cache.CacheWriteSynchronizationMode; +import org.apache.ignite.configuration.CacheConfiguration; +import org.apache.ignite.configuration.IgniteConfiguration; +import org.apache.ignite.internal.processors.cache.GridCacheProcessor; +import org.apache.ignite.internal.util.IgniteUtils; +import org.apache.ignite.lang.IgniteBiTuple; +import org.apache.ignite.ml.Model; +import org.apache.ignite.ml.estimators.Estimators; +import org.apache.ignite.ml.math.StorageConstants; +import org.apache.ignite.ml.math.Tracer; +import org.apache.ignite.ml.math.Vector; +import org.apache.ignite.ml.math.distributed.keys.impl.SparseMatrixKey; +import org.apache.ignite.ml.math.functions.IgniteFunction; +import 
org.apache.ignite.ml.math.functions.IgniteTriFunction; +import org.apache.ignite.ml.math.impls.matrix.SparseDistributedMatrix; +import org.apache.ignite.ml.math.impls.storage.matrix.SparseDistributedMatrixStorage; +import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector; +import org.apache.ignite.ml.structures.LabeledVectorDouble; +import org.apache.ignite.ml.trees.BaseDecisionTreeTest; +import org.apache.ignite.ml.trees.SplitDataGenerator; +import org.apache.ignite.ml.trees.models.DecisionTreeModel; +import org.apache.ignite.ml.trees.trainers.columnbased.BiIndex; +import org.apache.ignite.ml.trees.trainers.columnbased.BiIndexedCacheColumnDecisionTreeTrainerInput; +import org.apache.ignite.ml.trees.trainers.columnbased.ColumnDecisionTreeTrainer; +import org.apache.ignite.ml.trees.trainers.columnbased.MatrixColumnDecisionTreeTrainerInput; +import org.apache.ignite.ml.trees.trainers.columnbased.caches.ContextCache; +import org.apache.ignite.ml.trees.trainers.columnbased.caches.FeaturesCache; +import org.apache.ignite.ml.trees.trainers.columnbased.caches.ProjectionsCache; +import org.apache.ignite.ml.trees.trainers.columnbased.caches.SplitCache; +import org.apache.ignite.ml.trees.trainers.columnbased.contsplitcalcs.ContinuousSplitCalculators; +import org.apache.ignite.ml.trees.trainers.columnbased.contsplitcalcs.GiniSplitCalculator; +import org.apache.ignite.ml.trees.trainers.columnbased.contsplitcalcs.VarianceSplitCalculator; +import org.apache.ignite.ml.trees.trainers.columnbased.regcalcs.RegionCalculators; +import org.apache.ignite.ml.util.MnistUtils; +import org.apache.ignite.stream.StreamTransformer; +import org.apache.ignite.testframework.junits.IgniteTestResources; +import org.apache.log4j.Level; +import org.junit.Assert; + +/** + * Various benchmarks for hand runs. + */ +public class ColumnDecisionTreeTrainerBenchmark extends BaseDecisionTreeTest { + /** Name of the property specifying path to training set images. 
*/ + private static final String PROP_TRAINING_IMAGES = "mnist.training.images"; + + /** Name of property specifying path to training set labels. */ + private static final String PROP_TRAINING_LABELS = "mnist.training.labels"; + + /** Name of property specifying path to test set images. */ + private static final String PROP_TEST_IMAGES = "mnist.test.images"; + + /** Name of property specifying path to test set labels. */ + private static final String PROP_TEST_LABELS = "mnist.test.labels"; + + /** Function to approximate. */ + private static final Function<Vector, Double> f1 = v -> v.get(0) * v.get(0) + 2 * Math.sin(v.get(1)) + v.get(2); + + /** {@inheritDoc} */ + @Override protected long getTestTimeout() { + return 6000000; + } + + /** {@inheritDoc} */ + @Override protected IgniteConfiguration getConfiguration(String igniteInstanceName, + IgniteTestResources rsrcs) throws Exception { + IgniteConfiguration configuration = super.getConfiguration(igniteInstanceName, rsrcs); + // We do not need any extra event types. + configuration.setIncludeEventTypes(); + configuration.setPeerClassLoadingEnabled(false); + + resetLog4j(Level.INFO, false, GridCacheProcessor.class.getPackage().getName()); + + return configuration; + } + + /** + * This test is for manual run only. + * To run this test rename this method so it starts from 'test'. + */ + public void tstCacheMixed() { + IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName()); + int ptsPerReg = 150; + int featCnt = 10; + + HashMap<Integer, Integer> catsInfo = new HashMap<>(); + catsInfo.put(1, 3); + + Random rnd = new Random(12349L); + + SplitDataGenerator<DenseLocalOnHeapVector> gen = new SplitDataGenerator<>( + featCnt, catsInfo, () -> new DenseLocalOnHeapVector(featCnt + 1), rnd). + split(0, 1, new int[] {0, 2}). + split(1, 0, -10.0). 
+ split(0, 0, 0.0); + + testByGenStreamerLoad(ptsPerReg, catsInfo, gen, rnd); + } + + /** + * Run decision tree classifier on MNIST using bi-indexed cache as a storage for dataset. + * To run this test rename this method so it starts from 'test'. + * + * @throws IOException In case of loading MNIST dataset errors. + */ + public void tstMNISTBiIndexedCache() throws IOException { + IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName()); + + int ptsCnt = 40_000; + int featCnt = 28 * 28; + + Properties props = loadMNISTProperties(); + + Stream<DenseLocalOnHeapVector> trainingMnistStream = MnistUtils.mnist(props.getProperty(PROP_TRAINING_IMAGES), props.getProperty(PROP_TRAINING_LABELS), new Random(123L), ptsCnt); + Stream<DenseLocalOnHeapVector> testMnistStream = MnistUtils.mnist(props.getProperty(PROP_TEST_IMAGES), props.getProperty(PROP_TEST_LABELS), new Random(123L), 10_000); + + IgniteCache<BiIndex, Double> cache = createBiIndexedCache(); + + loadVectorsIntoBiIndexedCache(cache.getName(), trainingMnistStream.iterator(), featCnt + 1); + + ColumnDecisionTreeTrainer<GiniSplitCalculator.GiniData> trainer = + new ColumnDecisionTreeTrainer<>(10, ContinuousSplitCalculators.GINI.apply(ignite), RegionCalculators.GINI, RegionCalculators.MOST_COMMON, ignite); + + System.out.println(">>> Training started"); + long before = System.currentTimeMillis(); + DecisionTreeModel mdl = trainer.train(new BiIndexedCacheColumnDecisionTreeTrainerInput(cache, new HashMap<>(), ptsCnt, featCnt)); + System.out.println(">>> Training finished in " + (System.currentTimeMillis() - before)); + + IgniteTriFunction<Model<Vector, Double>, Stream<IgniteBiTuple<Vector, Double>>, Function<Double, Double>, Double> mse = Estimators.errorsPercentage(); + Double accuracy = mse.apply(mdl, testMnistStream.map(v -> new IgniteBiTuple<>(v.viewPart(0, featCnt), v.getX(featCnt))), Function.identity()); + System.out.println(">>> Errs percentage: " + accuracy); + + Assert.assertEquals(0, 
SplitCache.getOrCreate(ignite).size()); + Assert.assertEquals(0, FeaturesCache.getOrCreate(ignite).size()); + Assert.assertEquals(0, ContextCache.getOrCreate(ignite).size()); + Assert.assertEquals(0, ProjectionsCache.getOrCreate(ignite).size()); + } + + /** + * Run decision tree classifier on MNIST using sparse distributed matrix as a storage for dataset. + * To run this test rename this method so it starts from 'test'. + * + * @throws IOException In case of loading MNIST dataset errors. + */ + public void tstMNISTSparseDistributedMatrix() throws IOException { + IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName()); + + int ptsCnt = 30_000; + int featCnt = 28 * 28; + + Properties props = loadMNISTProperties(); + + Stream<DenseLocalOnHeapVector> trainingMnistStream = MnistUtils.mnist(props.getProperty(PROP_TRAINING_IMAGES), props.getProperty(PROP_TRAINING_LABELS), new Random(123L), ptsCnt); + Stream<DenseLocalOnHeapVector> testMnistStream = MnistUtils.mnist(props.getProperty(PROP_TEST_IMAGES), props.getProperty(PROP_TEST_LABELS), new Random(123L), 10_000); + + SparseDistributedMatrix m = new SparseDistributedMatrix(ptsCnt, featCnt + 1, StorageConstants.COLUMN_STORAGE_MODE, StorageConstants.RANDOM_ACCESS_MODE); + + SparseDistributedMatrixStorage sto = (SparseDistributedMatrixStorage)m.getStorage(); + + loadVectorsIntoSparseDistributedMatrixCache(sto.cache().getName(), sto.getUUID(), trainingMnistStream.iterator(), featCnt + 1); + + ColumnDecisionTreeTrainer<GiniSplitCalculator.GiniData> trainer = + new ColumnDecisionTreeTrainer<>(10, ContinuousSplitCalculators.GINI.apply(ignite), RegionCalculators.GINI, RegionCalculators.MOST_COMMON, ignite); + + System.out.println(">>> Training started"); + long before = System.currentTimeMillis(); + DecisionTreeModel mdl = trainer.train(new MatrixColumnDecisionTreeTrainerInput(m, new HashMap<>())); + System.out.println(">>> Training finished in " + (System.currentTimeMillis() - before)); + + 
IgniteTriFunction<Model<Vector, Double>, Stream<IgniteBiTuple<Vector, Double>>, Function<Double, Double>, Double> mse = Estimators.errorsPercentage(); + Double accuracy = mse.apply(mdl, testMnistStream.map(v -> new IgniteBiTuple<>(v.viewPart(0, featCnt), v.getX(featCnt))), Function.identity()); + System.out.println(">>> Errs percentage: " + accuracy); + + Assert.assertEquals(0, SplitCache.getOrCreate(ignite).size()); + Assert.assertEquals(0, FeaturesCache.getOrCreate(ignite).size()); + Assert.assertEquals(0, ContextCache.getOrCreate(ignite).size()); + Assert.assertEquals(0, ProjectionsCache.getOrCreate(ignite).size()); + } + + /** Load properties for MNIST tests. */ + private static Properties loadMNISTProperties() throws IOException { + Properties res = new Properties(); + + InputStream is = ColumnDecisionTreeTrainerBenchmark.class.getClassLoader().getResourceAsStream("manualrun/trees/columntrees.manualrun.properties"); + + res.load(is); + + return res; + } + + /** */ + private void testByGenStreamerLoad(int ptsPerReg, HashMap<Integer, Integer> catsInfo, + SplitDataGenerator<DenseLocalOnHeapVector> gen, Random rnd) { + + List<IgniteBiTuple<Integer, DenseLocalOnHeapVector>> lst = gen. + points(ptsPerReg, (i, rn) -> i). + collect(Collectors.toList()); + + int featCnt = gen.featuresCnt(); + + Collections.shuffle(lst, rnd); + + int numRegs = gen.regsCount(); + + SparseDistributedMatrix m = new SparseDistributedMatrix(numRegs * ptsPerReg, featCnt + 1, StorageConstants.COLUMN_STORAGE_MODE, StorageConstants.RANDOM_ACCESS_MODE); + + IgniteFunction<DoubleStream, Double> regCalc = s -> s.average().orElse(0.0); + + Map<Integer, List<LabeledVectorDouble>> byRegion = new HashMap<>(); + + SparseDistributedMatrixStorage sto = (SparseDistributedMatrixStorage)m.getStorage(); + long before = System.currentTimeMillis(); + System.out.println(">>> Batch loading started..."); + loadVectorsIntoSparseDistributedMatrixCache(sto.cache().getName(), sto.getUUID(), gen. 
+ points(ptsPerReg, (i, rn) -> i).map(IgniteBiTuple::get2).iterator(), featCnt + 1); + System.out.println(">>> Batch loading took " + (System.currentTimeMillis() - before) + " ms."); + + for (IgniteBiTuple<Integer, DenseLocalOnHeapVector> bt : lst) { + byRegion.putIfAbsent(bt.get1(), new LinkedList<>()); + byRegion.get(bt.get1()).add(asLabeledVector(bt.get2().getStorage().data())); + } + + ColumnDecisionTreeTrainer<VarianceSplitCalculator.VarianceData> trainer = + new ColumnDecisionTreeTrainer<>(2, ContinuousSplitCalculators.VARIANCE, RegionCalculators.VARIANCE, regCalc, ignite); + + before = System.currentTimeMillis(); + DecisionTreeModel mdl = trainer.train(new MatrixColumnDecisionTreeTrainerInput(m, catsInfo)); + + System.out.println(">>> Took time(ms): " + (System.currentTimeMillis() - before)); + + byRegion.keySet().forEach(k -> { + LabeledVectorDouble sp = byRegion.get(k).get(0); + Tracer.showAscii(sp.vector()); + System.out.println("Prediction: " + mdl.predict(sp.vector()) + "label: " + sp.doubleLabel()); + assert mdl.predict(sp.vector()) == sp.doubleLabel(); + }); + } + + /** + * Test decision tree regression. + * To run this test rename this method so it starts from 'test'. 
+ */ + public void tstF1() { + IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName()); + int ptsCnt = 10000; + Map<Integer, double[]> ranges = new HashMap<>(); + + ranges.put(0, new double[] {-100.0, 100.0}); + ranges.put(1, new double[] {-100.0, 100.0}); + ranges.put(2, new double[] {-100.0, 100.0}); + + int featCnt = 100; + double[] defRng = {-1.0, 1.0}; + + Vector[] trainVectors = vecsFromRanges(ranges, featCnt, defRng, new Random(123L), ptsCnt, f1); + + SparseDistributedMatrix m = new SparseDistributedMatrix(ptsCnt, featCnt + 1, StorageConstants.COLUMN_STORAGE_MODE, StorageConstants.RANDOM_ACCESS_MODE); + + SparseDistributedMatrixStorage sto = (SparseDistributedMatrixStorage)m.getStorage(); + + loadVectorsIntoSparseDistributedMatrixCache(sto.cache().getName(), sto.getUUID(), Arrays.stream(trainVectors).iterator(), featCnt + 1); + + IgniteFunction<DoubleStream, Double> regCalc = s -> s.average().orElse(0.0); + + ColumnDecisionTreeTrainer<VarianceSplitCalculator.VarianceData> trainer = + new ColumnDecisionTreeTrainer<>(10, ContinuousSplitCalculators.VARIANCE, RegionCalculators.VARIANCE, regCalc, ignite); + + System.out.println(">>> Training started"); + long before = System.currentTimeMillis(); + DecisionTreeModel mdl = trainer.train(new MatrixColumnDecisionTreeTrainerInput(m, new HashMap<>())); + System.out.println(">>> Training finished in " + (System.currentTimeMillis() - before)); + + Vector[] testVectors = vecsFromRanges(ranges, featCnt, defRng, new Random(123L), 20, f1); + + IgniteTriFunction<Model<Vector, Double>, Stream<IgniteBiTuple<Vector, Double>>, Function<Double, Double>, Double> mse = Estimators.MSE(); + Double accuracy = mse.apply(mdl, Arrays.stream(testVectors).map(v -> new IgniteBiTuple<>(v.viewPart(0, featCnt), v.getX(featCnt))), Function.identity()); + System.out.println(">>> MSE: " + accuracy); + } + + /** + * Load vectors into sparse distributed matrix. + * + * @param cacheName Name of cache where matrix is stored. 
+ * @param uuid UUID of matrix. + * @param iter Iterator over vectors. + * @param vectorSize size of vectors. + */ + private void loadVectorsIntoSparseDistributedMatrixCache(String cacheName, UUID uuid, + Iterator<? extends org.apache.ignite.ml.math.Vector> iter, int vectorSize) { + try (IgniteDataStreamer<SparseMatrixKey, Map<Integer, Double>> streamer = + Ignition.localIgnite().dataStreamer(cacheName)) { + int sampleIdx = 0; + streamer.allowOverwrite(true); + + streamer.receiver(StreamTransformer.from((e, arg) -> { + Map<Integer, Double> val = e.getValue(); + + if (val == null) + val = new Int2DoubleOpenHashMap(); + + val.putAll((Map<Integer, Double>)arg[0]); + + e.setValue(val); + + return null; + })); + + // Feature index -> (sample index -> value) + Map<Integer, Map<Integer, Double>> batch = new HashMap<>(); + IntStream.range(0, vectorSize).forEach(i -> batch.put(i, new HashMap<>())); + int batchSize = 1000; + + while (iter.hasNext()) { + org.apache.ignite.ml.math.Vector next = iter.next(); + + for (int i = 0; i < vectorSize; i++) + batch.get(i).put(sampleIdx, next.getX(i)); + + System.out.println(sampleIdx); + if (sampleIdx % batchSize == 0) { + batch.keySet().forEach(fi -> streamer.addData(new SparseMatrixKey(fi, uuid, fi), batch.get(fi))); + IntStream.range(0, vectorSize).forEach(i -> batch.put(i, new HashMap<>())); + } + sampleIdx++; + } + if (sampleIdx % batchSize != 0) { + batch.keySet().forEach(fi -> streamer.addData(new SparseMatrixKey(fi, uuid, fi), batch.get(fi))); + IntStream.range(0, vectorSize).forEach(i -> batch.put(i, new HashMap<>())); + } + } + } + + /** + * Load vectors into bi-indexed cache. + * + * @param cacheName Name of cache. + * @param iter Iterator over vectors. + * @param vectorSize size of vectors. + */ + private void loadVectorsIntoBiIndexedCache(String cacheName, + Iterator<? 
extends org.apache.ignite.ml.math.Vector> iter, int vectorSize) { + try (IgniteDataStreamer<BiIndex, Double> streamer = + Ignition.localIgnite().dataStreamer(cacheName)) { + int sampleIdx = 0; + + streamer.perNodeBufferSize(10000); + + while (iter.hasNext()) { + org.apache.ignite.ml.math.Vector next = iter.next(); + + for (int i = 0; i < vectorSize; i++) + streamer.addData(new BiIndex(sampleIdx, i), next.getX(i)); + + sampleIdx++; + + if (sampleIdx % 1000 == 0) + System.out.println(">>> Loaded " + sampleIdx + " vectors."); + } + } + } + + /** + * Create bi-indexed cache for tests. + * + * @return Bi-indexed cache. + */ + private IgniteCache<BiIndex, Double> createBiIndexedCache() { + CacheConfiguration<BiIndex, Double> cfg = new CacheConfiguration<>(); + + // Write to primary. + cfg.setWriteSynchronizationMode(CacheWriteSynchronizationMode.PRIMARY_SYNC); + + // Atomic transactions only. + cfg.setAtomicityMode(CacheAtomicityMode.ATOMIC); + + // No eviction. + cfg.setEvictionPolicy(null); + + // No copying of values. + cfg.setCopyOnRead(false); + + // Cache is partitioned. 
+ cfg.setCacheMode(CacheMode.PARTITIONED); + + cfg.setBackups(0); + + cfg.setName("TMP_BI_INDEXED_CACHE"); + + return Ignition.localIgnite().getOrCreateCache(cfg); + } + + /** */ + private Vector[] vecsFromRanges(Map<Integer, double[]> ranges, int featCnt, double[] defRng, Random rnd, int ptsCnt, + Function<Vector, Double> f) { + int vs = featCnt + 1; + DenseLocalOnHeapVector[] res = new DenseLocalOnHeapVector[ptsCnt]; + for (int pt = 0; pt < ptsCnt; pt++) { + DenseLocalOnHeapVector v = new DenseLocalOnHeapVector(vs); + for (int i = 0; i < featCnt; i++) { + double[] range = ranges.getOrDefault(i, defRng); + double from = range[0]; + double to = range[1]; + double rng = to - from; + + v.setX(i, rnd.nextDouble() * rng); + } + v.setX(featCnt, f.apply(v)); + res[pt] = v; + } + + return res; + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/db7697b1/modules/ml/src/test/resources/manualrun/trees/columntrees.manualrun.properties ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/resources/manualrun/trees/columntrees.manualrun.properties b/modules/ml/src/test/resources/manualrun/trees/columntrees.manualrun.properties new file mode 100644 index 0000000..7040010 --- /dev/null +++ b/modules/ml/src/test/resources/manualrun/trees/columntrees.manualrun.properties @@ -0,0 +1,5 @@ +# Paths to mnist dataset parts. +mnist.training.images=/path/to/train-images-idx3-ubyte +mnist.training.labels=/path/to/train-labels-idx1-ubyte +mnist.test.images=/path/to/t10k-images-idx3-ubyte +mnist.test.labels=/path/to/t10k-labels-idx1-ubyte \ No newline at end of file
