Repository: ignite Updated Branches: refs/heads/master 25f2d1865 -> 44098bc6e
IGNITE-9064: [ML] Decision tree optimization this closes #4436 Project: http://git-wip-us.apache.org/repos/asf/ignite/repo Commit: http://git-wip-us.apache.org/repos/asf/ignite/commit/44098bc6 Tree: http://git-wip-us.apache.org/repos/asf/ignite/tree/44098bc6 Diff: http://git-wip-us.apache.org/repos/asf/ignite/diff/44098bc6 Branch: refs/heads/master Commit: 44098bc6e38ce9bbd4c191c3314b2123a60739d4 Parents: 25f2d18 Author: Alexey Platonov <[email protected]> Authored: Fri Aug 3 14:17:07 2018 +0300 Committer: Yury Babak <[email protected]> Committed: Fri Aug 3 14:17:07 2018 +0300 ---------------------------------------------------------------------- .../ml/composition/boosting/GDBTrainer.java | 2 +- .../dataset/impl/cache/CacheBasedDataset.java | 1 + .../dataset/impl/cache/util/ComputeUtils.java | 11 ++ .../org/apache/ignite/ml/tree/DecisionTree.java | 13 +- .../tree/DecisionTreeClassificationTrainer.java | 13 +- .../ml/tree/DecisionTreeRegressionTrainer.java | 14 +- .../GDBBinaryClassifierOnTreesTrainer.java | 17 +- .../boosting/GDBRegressionOnTreesTrainer.java | 17 +- .../ignite/ml/tree/data/DecisionTreeData.java | 51 ++++- .../ml/tree/data/DecisionTreeDataBuilder.java | 9 +- .../ignite/ml/tree/data/TreeDataIndex.java | 184 +++++++++++++++++++ .../impurity/ImpurityMeasureCalculator.java | 67 ++++++- .../gini/GiniImpurityMeasureCalculator.java | 67 ++++--- .../mse/MSEImpurityMeasureCalculator.java | 86 ++++++--- .../RandomForestClassifierTrainer.java | 13 +- .../RandomForestRegressionTrainer.java | 13 +- .../tree/randomforest/RandomForestTrainer.java | 4 + .../DecisionTreeClassificationTrainerTest.java | 25 ++- .../tree/DecisionTreeRegressionTrainerTest.java | 18 +- .../ml/tree/data/DecisionTreeDataTest.java | 21 ++- .../ignite/ml/tree/data/TreeDataIndexTest.java | 159 ++++++++++++++++ .../gini/GiniImpurityMeasureCalculatorTest.java | 27 ++- .../mse/MSEImpurityMeasureCalculatorTest.java | 21 ++- 23 files changed, 757 insertions(+), 96 deletions(-) 
---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/ignite/blob/44098bc6/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/GDBTrainer.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/GDBTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/GDBTrainer.java index 6726892..8663d3d 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/GDBTrainer.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/GDBTrainer.java @@ -165,7 +165,7 @@ abstract class GDBTrainer extends DatasetTrainer<Model<Vector, Double>, Double> try (Dataset<EmptyContext, DecisionTreeData> dataset = builder.build( new EmptyContextBuilder<>(), - new DecisionTreeDataBuilder<>(featureExtractor, lbExtractor) + new DecisionTreeDataBuilder<>(featureExtractor, lbExtractor, false) )) { IgniteBiTuple<Double, Long> meanTuple = dataset.compute( data -> { http://git-wip-us.apache.org/repos/asf/ignite/blob/44098bc6/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/cache/CacheBasedDataset.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/cache/CacheBasedDataset.java b/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/cache/CacheBasedDataset.java index 67e0d56..e5eb483 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/cache/CacheBasedDataset.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/cache/CacheBasedDataset.java @@ -144,6 +144,7 @@ public class CacheBasedDataset<K, V, C extends Serializable, D extends AutoClose /** {@inheritDoc} */ @Override public void close() { datasetCache.destroy(); + ComputeUtils.removeData(ignite, datasetId); } /** 
http://git-wip-us.apache.org/repos/asf/ignite/blob/44098bc6/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/cache/util/ComputeUtils.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/cache/util/ComputeUtils.java b/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/cache/util/ComputeUtils.java index 39b3703..a5cdd3b 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/cache/util/ComputeUtils.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/cache/util/ComputeUtils.java @@ -184,6 +184,17 @@ public class ComputeUtils { } /** + * Remove data from local cache by Dataset ID. + * + * @param ignite Ingnite instance. + * @param datasetId Dataset ID. + */ + public static void removeData(Ignite ignite, UUID datasetId) { + ignite.cluster().nodeLocalMap().remove(String.format(DATA_STORAGE_KEY_TEMPLATE, datasetId)); + } + + + /** * Initializes partition {@code context} by loading it from a partition {@code upstream}. * * @param ignite Ignite instance. http://git-wip-us.apache.org/repos/asf/ignite/blob/44098bc6/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTree.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTree.java b/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTree.java index c1e3abf..270f14a 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTree.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTree.java @@ -53,6 +53,9 @@ public abstract class DecisionTree<T extends ImpurityMeasure<T>> extends Dataset /** Decision tree leaf builder. */ private final DecisionTreeLeafBuilder decisionTreeLeafBuilder; + /** Use index structure instead of using sorting while learning. 
*/ + protected boolean useIndex = true; + /** * Constructs a new distributed decision tree trainer. * @@ -74,7 +77,7 @@ public abstract class DecisionTree<T extends ImpurityMeasure<T>> extends Dataset IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, Double> lbExtractor) { try (Dataset<EmptyContext, DecisionTreeData> dataset = datasetBuilder.build( new EmptyContextBuilder<>(), - new DecisionTreeDataBuilder<>(featureExtractor, lbExtractor) + new DecisionTreeDataBuilder<>(featureExtractor, lbExtractor, useIndex) )) { return split(dataset, e -> true, 0, getImpurityMeasureCalculator(dataset)); } @@ -105,7 +108,7 @@ public abstract class DecisionTree<T extends ImpurityMeasure<T>> extends Dataset if (deep >= maxDeep) return decisionTreeLeafBuilder.createLeafNode(dataset, filter); - StepFunction<T>[] criterionFunctions = calculateImpurityForAllColumns(dataset, filter, impurityCalc); + StepFunction<T>[] criterionFunctions = calculateImpurityForAllColumns(dataset, filter, impurityCalc, deep); if (criterionFunctions == null) return decisionTreeLeafBuilder.createLeafNode(dataset, filter); @@ -132,14 +135,14 @@ public abstract class DecisionTree<T extends ImpurityMeasure<T>> extends Dataset * @return Array of impurity measure functions for all columns. 
*/ private StepFunction<T>[] calculateImpurityForAllColumns(Dataset<EmptyContext, DecisionTreeData> dataset, - TreeFilter filter, ImpurityMeasureCalculator<T> impurityCalc) { + TreeFilter filter, ImpurityMeasureCalculator<T> impurityCalc, int depth) { StepFunction<T>[] result = dataset.compute( part -> { if (compressor != null) - return compressor.compress(impurityCalc.calculate(part.filter(filter))); + return compressor.compress(impurityCalc.calculate(part, filter, depth)); else - return impurityCalc.calculate(part.filter(filter)); + return impurityCalc.calculate(part, filter, depth); }, this::reduce ); http://git-wip-us.apache.org/repos/asf/ignite/blob/44098bc6/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTreeClassificationTrainer.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTreeClassificationTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTreeClassificationTrainer.java index 71e387f..f371334 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTreeClassificationTrainer.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTreeClassificationTrainer.java @@ -85,12 +85,13 @@ public class DecisionTreeClassificationTrainer extends DecisionTree<GiniImpurity } /** - * Set up the step function compressor of decision tree. - * @param compressor The parameter value. - * @return Trainer with new compressor parameter value. + * Sets useIndex parameter and returns trainer instance. + * + * @param useIndex Use index. + * @return Decision tree trainer. 
*/ - public DecisionTreeClassificationTrainer withCompressor(StepFunctionCompressor compressor){ - this.compressor = compressor; + public DecisionTreeClassificationTrainer withUseIndex(boolean useIndex) { + this.useIndex = useIndex; return this; } @@ -126,6 +127,6 @@ public class DecisionTreeClassificationTrainer extends DecisionTree<GiniImpurity for (Double lb : labels) encoder.put(lb, idx++); - return new GiniImpurityMeasureCalculator(encoder); + return new GiniImpurityMeasureCalculator(encoder, useIndex); } } http://git-wip-us.apache.org/repos/asf/ignite/blob/44098bc6/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTreeRegressionTrainer.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTreeRegressionTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTreeRegressionTrainer.java index 2bf09d3..7446237 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTreeRegressionTrainer.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTreeRegressionTrainer.java @@ -52,9 +52,21 @@ public class DecisionTreeRegressionTrainer extends DecisionTree<MSEImpurityMeasu super(maxDeep, minImpurityDecrease, compressor, new MeanDecisionTreeLeafBuilder()); } + /** + * Sets useIndex parameter and returns trainer instance. + * + * @param useIndex Use index. + * @return Decision tree trainer. 
+ */ + public DecisionTreeRegressionTrainer withUseIndex(boolean useIndex) { + this.useIndex = useIndex; + return this; + } + /** {@inheritDoc} */ @Override ImpurityMeasureCalculator<MSEImpurityMeasure> getImpurityMeasureCalculator( Dataset<EmptyContext, DecisionTreeData> dataset) { - return new MSEImpurityMeasureCalculator(); + + return new MSEImpurityMeasureCalculator(useIndex); } } http://git-wip-us.apache.org/repos/asf/ignite/blob/44098bc6/modules/ml/src/main/java/org/apache/ignite/ml/tree/boosting/GDBBinaryClassifierOnTreesTrainer.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/tree/boosting/GDBBinaryClassifierOnTreesTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/tree/boosting/GDBBinaryClassifierOnTreesTrainer.java index 3789588..631e848 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/tree/boosting/GDBBinaryClassifierOnTreesTrainer.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/tree/boosting/GDBBinaryClassifierOnTreesTrainer.java @@ -30,9 +30,13 @@ import org.jetbrains.annotations.NotNull; public class GDBBinaryClassifierOnTreesTrainer extends GDBBinaryClassifierTrainer { /** Max depth. */ private final int maxDepth; + /** Min impurity decrease. */ private final double minImpurityDecrease; + /** Use index structure instead of using sorting while learning. */ + private boolean useIndex = true; + /** * Constructs instance of GDBBinaryClassifierOnTreesTrainer. * @@ -51,6 +55,17 @@ public class GDBBinaryClassifierOnTreesTrainer extends GDBBinaryClassifierTraine /** {@inheritDoc} */ @NotNull @Override protected DatasetTrainer<? extends Model<Vector, Double>, Double> buildBaseModelTrainer() { - return new DecisionTreeRegressionTrainer(maxDepth, minImpurityDecrease); + return new DecisionTreeRegressionTrainer(maxDepth, minImpurityDecrease).withUseIndex(useIndex); + } + + /** + * Sets useIndex parameter and returns trainer instance. 
+ * + * @param useIndex Use index. + * @return Decision tree trainer. + */ + public GDBBinaryClassifierOnTreesTrainer withUseIndex(boolean useIndex) { + this.useIndex = useIndex; + return this; } } http://git-wip-us.apache.org/repos/asf/ignite/blob/44098bc6/modules/ml/src/main/java/org/apache/ignite/ml/tree/boosting/GDBRegressionOnTreesTrainer.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/tree/boosting/GDBRegressionOnTreesTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/tree/boosting/GDBRegressionOnTreesTrainer.java index 50c5f8d..450dae3 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/tree/boosting/GDBRegressionOnTreesTrainer.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/tree/boosting/GDBRegressionOnTreesTrainer.java @@ -30,9 +30,13 @@ import org.jetbrains.annotations.NotNull; public class GDBRegressionOnTreesTrainer extends GDBRegressionTrainer { /** Max depth. */ private final int maxDepth; + /** Min impurity decrease. */ private final double minImpurityDecrease; + /** Use index structure instead of using sorting while learning. */ + private boolean useIndex = true; + /** * Constructs instance of GDBRegressionOnTreesTrainer. * @@ -51,6 +55,17 @@ public class GDBRegressionOnTreesTrainer extends GDBRegressionTrainer { /** {@inheritDoc} */ @NotNull @Override protected DatasetTrainer<? extends Model<Vector, Double>, Double> buildBaseModelTrainer() { - return new DecisionTreeRegressionTrainer(maxDepth, minImpurityDecrease); + return new DecisionTreeRegressionTrainer(maxDepth, minImpurityDecrease).withUseIndex(useIndex); + } + + /** + * Sets useIndex parameter and returns trainer instance. + * + * @param useIndex Use index. + * @return Decision tree trainer. 
+ */ + public GDBRegressionOnTreesTrainer withUseIndex(boolean useIndex) { + this.useIndex = useIndex; + return this; } } http://git-wip-us.apache.org/repos/asf/ignite/blob/44098bc6/modules/ml/src/main/java/org/apache/ignite/ml/tree/data/DecisionTreeData.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/tree/data/DecisionTreeData.java b/modules/ml/src/main/java/org/apache/ignite/ml/tree/data/DecisionTreeData.java index 34deb46..c017e5c 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/tree/data/DecisionTreeData.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/tree/data/DecisionTreeData.java @@ -17,6 +17,8 @@ package org.apache.ignite.ml.tree.data; +import java.util.ArrayList; +import java.util.List; import org.apache.ignite.ml.tree.TreeFilter; /** @@ -29,17 +31,29 @@ public class DecisionTreeData implements AutoCloseable { /** Vector with labels. */ private final double[] labels; + /** Indexes cache. */ + private final List<TreeDataIndex> indexesCache; + + /** Build index. */ + private final boolean buildIndex; + /** * Constructs a new instance of decision tree data. * * @param features Matrix with features. * @param labels Vector with labels. + * @param buildIdx Build index. 
*/ - public DecisionTreeData(double[][] features, double[] labels) { + public DecisionTreeData(double[][] features, double[] labels, boolean buildIdx) { assert features.length == labels.length : "Features and labels have to be the same length"; this.features = features; this.labels = labels; + this.buildIndex = buildIdx; + + indexesCache = new ArrayList<>(); + if (buildIdx) + indexesCache.add(new TreeDataIndex(features, labels)); } /** @@ -69,7 +83,7 @@ public class DecisionTreeData implements AutoCloseable { } } - return new DecisionTreeData(newFeatures, newLabels); + return new DecisionTreeData(newFeatures, newLabels, buildIndex); } /** @@ -89,8 +103,10 @@ public class DecisionTreeData implements AutoCloseable { int i = from, j = to; while (i <= j) { - while (features[i][col] < pivot) i++; - while (features[j][col] > pivot) j--; + while (features[i][col] < pivot) + i++; + while (features[j][col] > pivot) + j--; if (i <= j) { double[] tmpFeature = features[i]; @@ -125,4 +141,31 @@ public class DecisionTreeData implements AutoCloseable { @Override public void close() { // Do nothing, GC will clean up. } + + /** + * Builds index in according to current tree depth and cached indexes in upper levels. Uses depth as key of cached + * index and replaces cached index with same key. + * + * @param depth Tree Depth. + * @param filter Filter. 
+ */ + public TreeDataIndex createIndexByFilter(int depth, TreeFilter filter) { + assert depth >= 0 && depth <= indexesCache.size(); + + if (depth > 0 && depth <= indexesCache.size() - 1) { + for (int i = indexesCache.size() - 1; i >= depth; i--) + indexesCache.remove(i); + } + + if (depth == indexesCache.size()) { + if (depth == 0) + indexesCache.add(new TreeDataIndex(features, labels)); + else { + TreeDataIndex lastIndex = indexesCache.get(depth - 1); + indexesCache.add(lastIndex.filter(filter)); + } + } + + return indexesCache.get(depth); + } } http://git-wip-us.apache.org/repos/asf/ignite/blob/44098bc6/modules/ml/src/main/java/org/apache/ignite/ml/tree/data/DecisionTreeDataBuilder.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/tree/data/DecisionTreeDataBuilder.java b/modules/ml/src/main/java/org/apache/ignite/ml/tree/data/DecisionTreeDataBuilder.java index 0ff2012..6678218 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/tree/data/DecisionTreeDataBuilder.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/tree/data/DecisionTreeDataBuilder.java @@ -42,16 +42,21 @@ public class DecisionTreeDataBuilder<K, V, C extends Serializable> /** Function that extracts labels from an {@code upstream} data. */ private final IgniteBiFunction<K, V, Double> lbExtractor; + /** Build index. */ + private final boolean buildIndex; + /** * Constructs a new instance of decision tree data builder. * * @param featureExtractor Function that extracts features from an {@code upstream} data. * @param lbExtractor Function that extracts labels from an {@code upstream} data. + * @param buildIdx Build index. 
*/ public DecisionTreeDataBuilder(IgniteBiFunction<K, V, Vector> featureExtractor, - IgniteBiFunction<K, V, Double> lbExtractor) { + IgniteBiFunction<K, V, Double> lbExtractor, boolean buildIdx) { this.featureExtractor = featureExtractor; this.lbExtractor = lbExtractor; + this.buildIndex = buildIdx; } /** {@inheritDoc} */ @@ -70,6 +75,6 @@ public class DecisionTreeDataBuilder<K, V, C extends Serializable> ptr++; } - return new DecisionTreeData(features, labels); + return new DecisionTreeData(features, labels, buildIndex); } } http://git-wip-us.apache.org/repos/asf/ignite/blob/44098bc6/modules/ml/src/main/java/org/apache/ignite/ml/tree/data/TreeDataIndex.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/tree/data/TreeDataIndex.java b/modules/ml/src/main/java/org/apache/ignite/ml/tree/data/TreeDataIndex.java new file mode 100644 index 0000000..88ce190 --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/tree/data/TreeDataIndex.java @@ -0,0 +1,184 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.ignite.ml.tree.data; + +import java.util.Arrays; +import org.apache.ignite.ml.tree.TreeFilter; + +/** + * Index for representing sorted dataset rows for each features. + * It may be reused while decision tree learning at several levels through filter method. + */ +public class TreeDataIndex { + /** Index containing IDs of rows as if they is sorted by feature values. */ + private final int[][] index; + + /** Original features table. */ + private final double[][] features; + + /** Original labels. */ + private final double[] labels; + + /** + * Constructs an instance of TreeDataIndex. + * + * @param features Features. + * @param labels Labels. + */ + public TreeDataIndex(double[][] features, double[] labels) { + this.features = features; + this.labels = labels; + + int rows = features.length; + int cols = features.length == 0 ? 0 : features[0].length; + + double[][] featuresCp = new double[rows][cols]; + index = new int[rows][cols]; + for (int row = 0; row < rows; row++) { + Arrays.fill(index[row], row); + featuresCp[row] = Arrays.copyOf(features[row], cols); + } + + for (int col = 0; col < cols; col++) + sortIndex(featuresCp, col, 0, rows - 1); + } + + /** + * Constructs an instance of TreeDataIndex + * + * @param indexProj Index projection. + * @param features Features. + * @param labels Labels. + */ + private TreeDataIndex(int[][] indexProj, double[][] features, double[] labels) { + this.index = indexProj; + this.features = features; + this.labels = labels; + } + + /** + * Returns label for kth order statistic for target feature. + * + * @param k K. + * @param featureId Feature id. + * @return Label value. + */ + public double labelInSortedOrder(int k, int featureId) { + return labels[index[k][featureId]]; + } + + /** + * Returns vector of original features for kth order statistic for target feature. + * + * @param k K. + * @param featureId Feature id. + * @return Features vector. 
+ */ + public double[] featuresInSortedOrder(int k, int featureId) { + return features[index[k][featureId]]; + } + + /** + * Returns feature value for kth order statistic for target feature. + * + * @param k K. + * @param featureId Feature id. + * @return Feature value. + */ + public double featureInSortedOrder(int k, int featureId) { + return featuresInSortedOrder(k, featureId)[featureId]; + } + + /** + * Creates projection of current index in according to {@link TreeFilter}. + * + * @param filter Filter. + * @return Projection of current index onto smaller index in according to rows filter. + */ + public TreeDataIndex filter(TreeFilter filter) { + int projSize = 0; + for (int i = 0; i < rowsCount(); i++) { + if (filter.test(featuresInSortedOrder(i, 0))) + projSize++; + } + + int[][] projection = new int[projSize][columnsCount()]; + for(int feature = 0; feature < columnsCount(); feature++) { + int ptr = 0; + for(int row = 0; row < rowsCount(); row++) { + if(filter.test(featuresInSortedOrder(row, feature))) + projection[ptr++][feature] = index[row][feature]; + } + } + + return new TreeDataIndex(projection, features, labels); + } + + /** + * @return count of rows in current index. + */ + public int rowsCount() { + return index.length; + } + + /** + * @return count of columns in current index. + */ + public int columnsCount() { + return rowsCount() == 0 ? 0 : index[0].length ; + } + + /** + * Constructs index structure in according to features table. + * + * @param features Features. + * @param col Column. + * @param from From. + * @param to To. 
+ */ + private void sortIndex(double[][] features, int col, int from, int to) { + if (from < to) { + double pivot = features[(from + to) / 2][col]; + + int i = from, j = to; + + while (i <= j) { + while (features[i][col] < pivot) + i++; + while (features[j][col] > pivot) + j--; + + if (i <= j) { + double tmpFeature = features[i][col]; + features[i][col] = features[j][col]; + features[j][col] = tmpFeature; + + int tmpLb = index[i][col]; + index[i][col] = index[j][col]; + index[j][col] = tmpLb; + + i++; + j--; + } + } + + sortIndex(features, col, from, j); + sortIndex(features, col, i, to); + } + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/44098bc6/modules/ml/src/main/java/org/apache/ignite/ml/tree/impurity/ImpurityMeasureCalculator.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/tree/impurity/ImpurityMeasureCalculator.java b/modules/ml/src/main/java/org/apache/ignite/ml/tree/impurity/ImpurityMeasureCalculator.java index 2b69356..709f68e 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/tree/impurity/ImpurityMeasureCalculator.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/tree/impurity/ImpurityMeasureCalculator.java @@ -18,7 +18,9 @@ package org.apache.ignite.ml.tree.impurity; import java.io.Serializable; +import org.apache.ignite.ml.tree.TreeFilter; import org.apache.ignite.ml.tree.data.DecisionTreeData; +import org.apache.ignite.ml.tree.data.TreeDataIndex; import org.apache.ignite.ml.tree.impurity.util.StepFunction; /** @@ -26,7 +28,19 @@ import org.apache.ignite.ml.tree.impurity.util.StepFunction; * * @param <T> Type of impurity measure. */ -public interface ImpurityMeasureCalculator<T extends ImpurityMeasure<T>> extends Serializable { +public abstract class ImpurityMeasureCalculator<T extends ImpurityMeasure<T>> implements Serializable { + /** Use index structure instead of using sorting while learning. 
*/ + protected final boolean useIndex; + + /** + * Constructs an instance of ImpurityMeasureCalculator. + * + * @param useIndex Use index. + */ + public ImpurityMeasureCalculator(boolean useIndex) { + this.useIndex = useIndex; + } + /** * Calculates all impurity measures required required to find a best split and returns them as an array of * {@link StepFunction} (for every column). @@ -34,5 +48,54 @@ public interface ImpurityMeasureCalculator<T extends ImpurityMeasure<T>> extends * @param data Features and labels. * @return Impurity measures as an array of {@link StepFunction} (for every column). */ - public StepFunction<T>[] calculate(DecisionTreeData data); + public abstract StepFunction<T>[] calculate(DecisionTreeData data, TreeFilter filter, int depth); + + + /** + * Returns columns count in current dataset. + * + * @param data Data. + * @param idx Index. + * @return Columns count in current dataset. + */ + protected int columnsCount(DecisionTreeData data, TreeDataIndex idx) { + return useIndex ? idx.columnsCount() : data.getFeatures()[0].length; + } + + /** + * Returns rows count in current dataset. + * + * @param data Data. + * @param idx Index. + * @return rows count in current dataset + */ + protected int rowsCount(DecisionTreeData data, TreeDataIndex idx) { + return useIndex ? idx.rowsCount() : data.getFeatures().length; + } + + /** + * Returns label value in according to kth order statistic. + * + * @param data Data. + * @param idx Index. + * @param featureId Feature id. + * @param k K-th statistic. + * @return label value in according to kth order statistic + */ + protected double getLabelValue(DecisionTreeData data, TreeDataIndex idx, int featureId, int k) { + return useIndex ? idx.labelInSortedOrder(k, featureId) : data.getLabels()[k]; + } + + /** + * Returns feature value in according to kth order statistic. + * + * @param data Data. + * @param idx Index. + * @param featureId Feature id. + * @param k K-th statistic. 
+ * @return feature value in according to kth order statistic. + */ + protected double getFeatureValue(DecisionTreeData data, TreeDataIndex idx, int featureId, int k) { + return useIndex ? idx.featureInSortedOrder(k, featureId) : data.getFeatures()[k][featureId]; + } } http://git-wip-us.apache.org/repos/asf/ignite/blob/44098bc6/modules/ml/src/main/java/org/apache/ignite/ml/tree/impurity/gini/GiniImpurityMeasureCalculator.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/tree/impurity/gini/GiniImpurityMeasureCalculator.java b/modules/ml/src/main/java/org/apache/ignite/ml/tree/impurity/gini/GiniImpurityMeasureCalculator.java index 0dd0a10..38b3097 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/tree/impurity/gini/GiniImpurityMeasureCalculator.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/tree/impurity/gini/GiniImpurityMeasureCalculator.java @@ -19,14 +19,16 @@ package org.apache.ignite.ml.tree.impurity.gini; import java.util.Arrays; import java.util.Map; +import org.apache.ignite.ml.tree.TreeFilter; import org.apache.ignite.ml.tree.data.DecisionTreeData; +import org.apache.ignite.ml.tree.data.TreeDataIndex; import org.apache.ignite.ml.tree.impurity.ImpurityMeasureCalculator; import org.apache.ignite.ml.tree.impurity.util.StepFunction; /** * Gini impurity measure calculator. */ -public class GiniImpurityMeasureCalculator implements ImpurityMeasureCalculator<GiniImpurityMeasure> { +public class GiniImpurityMeasureCalculator extends ImpurityMeasureCalculator<GiniImpurityMeasure> { /** */ private static final long serialVersionUID = -522995134128519679L; @@ -37,51 +39,70 @@ public class GiniImpurityMeasureCalculator implements ImpurityMeasureCalculator< * Constructs a new instance of Gini impurity measure calculator. * * @param lbEncoder Label encoder which defines integer value for every label class. + * @param useIndex Use index while calculate. 
*/ - public GiniImpurityMeasureCalculator(Map<Double, Integer> lbEncoder) { + public GiniImpurityMeasureCalculator(Map<Double, Integer> lbEncoder, boolean useIndex) { + super(useIndex); this.lbEncoder = lbEncoder; } /** {@inheritDoc} */ @SuppressWarnings("unchecked") - @Override public StepFunction<GiniImpurityMeasure>[] calculate(DecisionTreeData data) { - double[][] features = data.getFeatures(); - double[] labels = data.getLabels(); + @Override public StepFunction<GiniImpurityMeasure>[] calculate(DecisionTreeData data, TreeFilter filter, int depth) { + TreeDataIndex index = null; + boolean canCalculate = false; - if (features.length > 0) { - StepFunction<GiniImpurityMeasure>[] res = new StepFunction[features[0].length]; + if (useIndex) { + index = data.createIndexByFilter(depth, filter); + canCalculate = index.rowsCount() > 0; + } + else { + data = data.filter(filter); + canCalculate = data.getFeatures().length > 0; + } - for (int col = 0; col < res.length; col++) { - data.sort(col); + if (canCalculate) { + int rowsCnt = rowsCount(data, index); + int colsCnt = columnsCount(data, index); - double[] x = new double[features.length + 1]; - GiniImpurityMeasure[] y = new GiniImpurityMeasure[features.length + 1]; + StepFunction<GiniImpurityMeasure>[] res = new StepFunction[colsCnt]; - int xPtr = 0, yPtr = 0; + long right[] = new long[lbEncoder.size()]; + for (int i = 0; i < rowsCnt; i++) { + double lb = getLabelValue(data, index, 0, i); + right[getLabelCode(lb)]++; + } - long[] left = new long[lbEncoder.size()]; - long[] right = new long[lbEncoder.size()]; + for (int col = 0; col < res.length; col++) { + if(!useIndex) + data.sort(col); - for (int i = 0; i < labels.length; i++) - right[getLabelCode(labels[i])]++; + double[] x = new double[rowsCnt + 1]; + GiniImpurityMeasure[] y = new GiniImpurityMeasure[rowsCnt + 1]; + long[] left = new long[lbEncoder.size()]; + long[] rightCopy = Arrays.copyOf(right, right.length); + + int xPtr = 0, yPtr = 0; x[xPtr++] = 
Double.NEGATIVE_INFINITY; y[yPtr++] = new GiniImpurityMeasure( Arrays.copyOf(left, left.length), - Arrays.copyOf(right, right.length) + Arrays.copyOf(rightCopy, rightCopy.length) ); - for (int i = 0; i < features.length; i++) { - left[getLabelCode(labels[i])]++; - right[getLabelCode(labels[i])]--; + for (int i = 0; i < rowsCnt; i++) { + double lb = getLabelValue(data, index, col, i); + left[getLabelCode(lb)]++; + rightCopy[getLabelCode(lb)]--; - if (i < (features.length - 1) && features[i + 1][col] == features[i][col]) + double featureVal = getFeatureValue(data, index, col, i); + if (i < (rowsCnt - 1) && getFeatureValue(data, index, col, i + 1) == featureVal) continue; - x[xPtr++] = features[i][col]; + x[xPtr++] = featureVal; y[yPtr++] = new GiniImpurityMeasure( Arrays.copyOf(left, left.length), - Arrays.copyOf(right, right.length) + Arrays.copyOf(rightCopy, rightCopy.length) ); } http://git-wip-us.apache.org/repos/asf/ignite/blob/44098bc6/modules/ml/src/main/java/org/apache/ignite/ml/tree/impurity/mse/MSEImpurityMeasureCalculator.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/tree/impurity/mse/MSEImpurityMeasureCalculator.java b/modules/ml/src/main/java/org/apache/ignite/ml/tree/impurity/mse/MSEImpurityMeasureCalculator.java index cb5019c..1788737 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/tree/impurity/mse/MSEImpurityMeasureCalculator.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/tree/impurity/mse/MSEImpurityMeasureCalculator.java @@ -17,56 +17,92 @@ package org.apache.ignite.ml.tree.impurity.mse; +import org.apache.ignite.ml.tree.TreeFilter; import org.apache.ignite.ml.tree.data.DecisionTreeData; +import org.apache.ignite.ml.tree.data.TreeDataIndex; import org.apache.ignite.ml.tree.impurity.ImpurityMeasureCalculator; import org.apache.ignite.ml.tree.impurity.util.StepFunction; /** * Meas squared error (variance) impurity measure calculator. 
*/ -public class MSEImpurityMeasureCalculator implements ImpurityMeasureCalculator<MSEImpurityMeasure> { +public class MSEImpurityMeasureCalculator extends ImpurityMeasureCalculator<MSEImpurityMeasure> { /** */ private static final long serialVersionUID = 288747414953756824L; + /** + * Constructs an instance of MSEImpurityMeasureCalculator. + * + * @param useIndex Use index while calculating. + */ + public MSEImpurityMeasureCalculator(boolean useIndex) { + super(useIndex); + } + /** {@inheritDoc} */ - @SuppressWarnings("unchecked") - @Override public StepFunction<MSEImpurityMeasure>[] calculate(DecisionTreeData data) { - double[][] features = data.getFeatures(); - double[] labels = data.getLabels(); + @Override public StepFunction<MSEImpurityMeasure>[] calculate(DecisionTreeData data, TreeFilter filter, int depth) { + TreeDataIndex index = null; + boolean canCalculate = false; + + if (useIndex) { + index = data.createIndexByFilter(depth, filter); + canCalculate = index.rowsCount() > 0; + } + else { + data = data.filter(filter); + canCalculate = data.getFeatures().length > 0; + } + + if (canCalculate) { + int rowsCnt = rowsCount(data, index); + int colsCnt = columnsCount(data, index); - if (features.length > 0) { - StepFunction<MSEImpurityMeasure>[] res = new StepFunction[features[0].length]; + @SuppressWarnings("unchecked") + StepFunction<MSEImpurityMeasure>[] res = new StepFunction[colsCnt]; + + double rightYOriginal = 0; + double rightY2Original = 0; + for (int i = 0; i < rowsCnt; i++) { + double lbVal = getLabelValue(data, index, 0, i); + + rightYOriginal += lbVal; + rightY2Original += Math.pow(lbVal, 2); + } for (int col = 0; col < res.length; col++) { - data.sort(col); + if (!useIndex) + data.sort(col); - double[] x = new double[features.length + 1]; - MSEImpurityMeasure[] y = new MSEImpurityMeasure[features.length + 1]; + double[] x = new double[rowsCnt + 1]; + MSEImpurityMeasure[] y = new MSEImpurityMeasure[rowsCnt + 1]; x[0] = Double.NEGATIVE_INFINITY; - for 
(int leftSize = 0; leftSize <= features.length; leftSize++) { - double leftY = 0; - double leftY2 = 0; - double rightY = 0; - double rightY2 = 0; + double leftY = 0; + double leftY2 = 0; + double rightY = rightYOriginal; + double rightY2 = rightY2Original; - for (int i = 0; i < leftSize; i++) { - leftY += labels[i]; - leftY2 += Math.pow(labels[i], 2); - } + int leftSize = 0; + for (int i = 0; i <= rowsCnt; i++) { + if (leftSize > 0) { + double lblVal = getLabelValue(data, index, col, i - 1); + + leftY += lblVal; + leftY2 += Math.pow(lblVal, 2); - for (int i = leftSize; i < features.length; i++) { - rightY += labels[i]; - rightY2 += Math.pow(labels[i], 2); + rightY -= lblVal; + rightY2 -= Math.pow(lblVal, 2); } - if (leftSize < features.length) - x[leftSize + 1] = features[leftSize][col]; + if (leftSize < rowsCnt) + x[leftSize + 1] = getFeatureValue(data, index, col, i); y[leftSize] = new MSEImpurityMeasure( - leftY, leftY2, leftSize, rightY, rightY2, features.length - leftSize + leftY, leftY2, leftSize, rightY, rightY2, rowsCnt - leftSize ); + + leftSize++; } res[col] = new StepFunction<>(x, y); http://git-wip-us.apache.org/repos/asf/ignite/blob/44098bc6/modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/RandomForestClassifierTrainer.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/RandomForestClassifierTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/RandomForestClassifierTrainer.java index daba4fa..bbbb2a9 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/RandomForestClassifierTrainer.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/RandomForestClassifierTrainer.java @@ -72,6 +72,17 @@ public class RandomForestClassifierTrainer extends RandomForestTrainer { /** {@inheritDoc} */ @Override protected DatasetTrainer<DecisionTreeNode, Double> buildDatasetTrainerForModel() 
{ - return new DecisionTreeClassificationTrainer(maxDeep, minImpurityDecrease); + return new DecisionTreeClassificationTrainer(maxDeep, minImpurityDecrease).withUseIndex(useIndex); + } + + /** + * Sets useIndex parameter and returns trainer instance. + * + * @param useIndex Use index. + * @return Decision tree trainer. + */ + public RandomForestClassifierTrainer withUseIndex(boolean useIndex) { + this.useIndex = useIndex; + return this; } } http://git-wip-us.apache.org/repos/asf/ignite/blob/44098bc6/modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/RandomForestRegressionTrainer.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/RandomForestRegressionTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/RandomForestRegressionTrainer.java index 5b41b2c..009fff2 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/RandomForestRegressionTrainer.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/RandomForestRegressionTrainer.java @@ -73,6 +73,17 @@ public class RandomForestRegressionTrainer extends RandomForestTrainer { /** {@inheritDoc} */ @Override protected DatasetTrainer<DecisionTreeNode, Double> buildDatasetTrainerForModel() { - return new DecisionTreeRegressionTrainer(maxDeep, minImpurityDecrease); + return new DecisionTreeRegressionTrainer(maxDeep, minImpurityDecrease).withUseIndex(useIndex); + } + + /** + * Sets useIndex parameter and returns trainer instance. + * + * @param useIndex Use index. + * @return Decision tree trainer. 
+ */ + public RandomForestRegressionTrainer withUseIndex(boolean useIndex) { + this.useIndex = useIndex; + return this; } } http://git-wip-us.apache.org/repos/asf/ignite/blob/44098bc6/modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/RandomForestTrainer.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/RandomForestTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/RandomForestTrainer.java index b5ecaed..8608f09 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/RandomForestTrainer.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/RandomForestTrainer.java @@ -26,9 +26,13 @@ import org.apache.ignite.ml.composition.predictionsaggregator.PredictionsAggrega public abstract class RandomForestTrainer extends BaggingModelTrainer { /** Max decision tree deep. */ protected final int maxDeep; + /** Min impurity decrease. */ protected final double minImpurityDecrease; + /** Use index structure instead of using sorting while decision tree learning. */ + protected boolean useIndex = false; + /** * Constructs new instance of BaggingModelTrainer. 
* http://git-wip-us.apache.org/repos/asf/ignite/blob/44098bc6/modules/ml/src/test/java/org/apache/ignite/ml/tree/DecisionTreeClassificationTrainerTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/tree/DecisionTreeClassificationTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/tree/DecisionTreeClassificationTrainerTest.java index de40b48..c84da12 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/tree/DecisionTreeClassificationTrainerTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/tree/DecisionTreeClassificationTrainerTest.java @@ -40,15 +40,21 @@ public class DecisionTreeClassificationTrainerTest { private static final int[] partsToBeTested = new int[] {1, 2, 3, 4, 5, 7}; /** Number of partitions. */ - @Parameterized.Parameter + @Parameterized.Parameter(0) public int parts; + /** Use index [= 1 if true]. */ + @Parameterized.Parameter(1) + public int useIndex; - @Parameterized.Parameters(name = "Data divided on {0} partitions") + /** Test parameters. */ + @Parameterized.Parameters(name = "Data divided on {0} partitions. Use index = {1}.") public static Iterable<Integer[]> data() { List<Integer[]> res = new ArrayList<>(); - for (int part : partsToBeTested) - res.add(new Integer[] {part}); + for (int i = 0; i < 2; i++) { + for (int part : partsToBeTested) + res.add(new Integer[] {part, i}); + } return res; } @@ -63,10 +69,11 @@ public class DecisionTreeClassificationTrainerTest { Random rnd = new Random(0); for (int i = 0; i < size; i++) { double x = rnd.nextDouble() - 0.5; - data.put(i, new double[]{x, x > 0 ? 1 : 0}); + data.put(i, new double[] {x, x > 0 ? 
1 : 0}); } - DecisionTreeClassificationTrainer trainer = new DecisionTreeClassificationTrainer(1, 0); + DecisionTreeClassificationTrainer trainer = new DecisionTreeClassificationTrainer(1, 0) + .withUseIndex(useIndex == 1); DecisionTreeNode tree = trainer.fit( data, @@ -77,15 +84,15 @@ public class DecisionTreeClassificationTrainerTest { assertTrue(tree instanceof DecisionTreeConditionalNode); - DecisionTreeConditionalNode node = (DecisionTreeConditionalNode) tree; + DecisionTreeConditionalNode node = (DecisionTreeConditionalNode)tree; assertEquals(0, node.getThreshold(), 1e-3); assertTrue(node.getThenNode() instanceof DecisionTreeLeafNode); assertTrue(node.getElseNode() instanceof DecisionTreeLeafNode); - DecisionTreeLeafNode thenNode = (DecisionTreeLeafNode) node.getThenNode(); - DecisionTreeLeafNode elseNode = (DecisionTreeLeafNode) node.getElseNode(); + DecisionTreeLeafNode thenNode = (DecisionTreeLeafNode)node.getThenNode(); + DecisionTreeLeafNode elseNode = (DecisionTreeLeafNode)node.getElseNode(); assertEquals(1, thenNode.getVal(), 1e-10); assertEquals(0, elseNode.getVal(), 1e-10); http://git-wip-us.apache.org/repos/asf/ignite/blob/44098bc6/modules/ml/src/test/java/org/apache/ignite/ml/tree/DecisionTreeRegressionTrainerTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/tree/DecisionTreeRegressionTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/tree/DecisionTreeRegressionTrainerTest.java index f69da4f..4e64925 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/tree/DecisionTreeRegressionTrainerTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/tree/DecisionTreeRegressionTrainerTest.java @@ -40,14 +40,21 @@ public class DecisionTreeRegressionTrainerTest { private static final int[] partsToBeTested = new int[] {1, 2, 3, 4, 5, 7}; /** Number of partitions. 
*/ - @Parameterized.Parameter + @Parameterized.Parameter(0) public int parts; - @Parameterized.Parameters(name = "Data divided on {0} partitions") + /** Use index [= 1 if true]. */ + @Parameterized.Parameter(1) + public int useIndex; + + /** Test parameters. */ + @Parameterized.Parameters(name = "Data divided on {0} partitions. Use index = {1}.") public static Iterable<Integer[]> data() { List<Integer[]> res = new ArrayList<>(); - for (int part : partsToBeTested) - res.add(new Integer[] {part}); + for (int i = 0; i < 2; i++) { + for (int part : partsToBeTested) + res.add(new Integer[] {part, i}); + } return res; } @@ -65,7 +72,8 @@ public class DecisionTreeRegressionTrainerTest { data.put(i, new double[]{x, x > 0 ? 1 : 0}); } - DecisionTreeRegressionTrainer trainer = new DecisionTreeRegressionTrainer(1, 0); + DecisionTreeRegressionTrainer trainer = new DecisionTreeRegressionTrainer(1, 0) + .withUseIndex(useIndex == 1); DecisionTreeNode tree = trainer.fit( data, http://git-wip-us.apache.org/repos/asf/ignite/blob/44098bc6/modules/ml/src/test/java/org/apache/ignite/ml/tree/data/DecisionTreeDataTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/tree/data/DecisionTreeDataTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/tree/data/DecisionTreeDataTest.java index 0c89d4e..4ee717a 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/tree/data/DecisionTreeDataTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/tree/data/DecisionTreeDataTest.java @@ -17,21 +17,38 @@ package org.apache.ignite.ml.tree.data; +import java.util.Arrays; import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; import static org.junit.Assert.assertArrayEquals; /** * Tests for {@link DecisionTreeData}. */ +@RunWith(Parameterized.class) public class DecisionTreeDataTest { + /** Parameters. 
*/ + @Parameterized.Parameters(name = "Use index {0}") + public static Iterable<Boolean[]> data() { + return Arrays.asList( + new Boolean[] {true}, + new Boolean[] {false} + ); + } + + /** Use index. */ + @Parameterized.Parameter + public boolean useIndex; + /** */ @Test public void testFilter() { double[][] features = new double[][]{{0}, {1}, {2}, {3}, {4}, {5}}; double[] labels = new double[]{0, 1, 2, 3, 4, 5}; - DecisionTreeData data = new DecisionTreeData(features, labels); + DecisionTreeData data = new DecisionTreeData(features, labels, useIndex); DecisionTreeData filteredData = data.filter(obj -> obj[0] > 2); assertArrayEquals(new double[][]{{3}, {4}, {5}}, filteredData.getFeatures()); @@ -44,7 +61,7 @@ public class DecisionTreeDataTest { double[][] features = new double[][]{{4, 1}, {3, 3}, {2, 0}, {1, 4}, {0, 2}}; double[] labels = new double[]{0, 1, 2, 3, 4}; - DecisionTreeData data = new DecisionTreeData(features, labels); + DecisionTreeData data = new DecisionTreeData(features, labels, useIndex); data.sort(0); http://git-wip-us.apache.org/repos/asf/ignite/blob/44098bc6/modules/ml/src/test/java/org/apache/ignite/ml/tree/data/TreeDataIndexTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/tree/data/TreeDataIndexTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/tree/data/TreeDataIndexTest.java new file mode 100644 index 0000000..78bdfdf --- /dev/null +++ b/modules/ml/src/test/java/org/apache/ignite/ml/tree/data/TreeDataIndexTest.java @@ -0,0 +1,159 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.tree.data; + +import org.apache.ignite.ml.tree.TreeFilter; +import org.junit.Test; + +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; + +/** + * Test for {@link TreeDataIndex}. + */ +public class TreeDataIndexTest { + /** */ + private double[][] features = { + {1., 2., 3., 4.}, + {2., 3., 4., 1.}, + {3., 4., 1., 2.}, + {4., 1., 2., 3.} + }; + + /** */ + private double[] labels = {1., 2., 3, 4.}; + + /** */ + private double[][] labelsInSortedOrder = { + {1., 4., 3., 2.}, + {2., 1., 4., 3.}, + {3., 2., 1., 4.}, + {4., 3., 2., 1.} + }; + + /** */ + private double[][][] featuresInSortedOrder = { + { + {1., 2., 3., 4.}, + {4., 1., 2., 3.}, + {3., 4., 1., 2.}, + {2., 3., 4., 1.}, + }, + { + {2., 3., 4., 1.}, + {1., 2., 3., 4.}, + {4., 1., 2., 3.}, + {3., 4., 1., 2.}, + }, + { + {3., 4., 1., 2.}, + {2., 3., 4., 1.}, + {1., 2., 3., 4.}, + {4., 1., 2., 3.}, + }, + { + {4., 1., 2., 3.}, + {3., 4., 1., 2.}, + {2., 3., 4., 1.}, + {1., 2., 3., 4.}, + } + }; + + /** */ + private TreeDataIndex index = new TreeDataIndex(features, labels); + + /** */ + @Test + public void labelInSortedOrderTest() { + assertEquals(features.length, index.rowsCount()); + assertEquals(features[0].length, index.columnsCount()); + + for (int k = 0; k < index.rowsCount(); k++) { + for (int featureId = 0; featureId < index.columnsCount(); featureId++) + assertEquals(labelsInSortedOrder[k][featureId], index.labelInSortedOrder(k, featureId), 0.01); + } + } + + /** */ + @Test + public void 
featuresInSortedOrderTest() { + assertEquals(features.length, index.rowsCount()); + assertEquals(features[0].length, index.columnsCount()); + + for (int k = 0; k < index.rowsCount(); k++) { + for (int featureId = 0; featureId < index.columnsCount(); featureId++) + assertArrayEquals(featuresInSortedOrder[k][featureId], index.featuresInSortedOrder(k, featureId), 0.01); + } + } + + /** */ + @Test + public void featureInSortedOrderTest() { + assertEquals(features.length, index.rowsCount()); + assertEquals(features[0].length, index.columnsCount()); + + for (int k = 0; k < index.rowsCount(); k++) { + for (int featureId = 0; featureId < index.columnsCount(); featureId++) + assertEquals((double)k + 1, index.featureInSortedOrder(k, featureId), 0.01); + } + } + + /** */ + @Test + public void filterTest() { + TreeFilter filter1 = features -> features[0] > 2; + TreeFilter filter2 = features -> features[1] > 2; + TreeFilter filterAnd = filter1.and(features -> features[1] > 2); + + TreeDataIndex filtered1 = index.filter(filter1); + TreeDataIndex filtered2 = filtered1.filter(filter2); + TreeDataIndex filtered3 = index.filter(filterAnd); + + assertEquals(2, filtered1.rowsCount()); + assertEquals(4, filtered1.columnsCount()); + assertEquals(1, filtered2.rowsCount()); + assertEquals(4, filtered2.columnsCount()); + assertEquals(1, filtered3.rowsCount()); + assertEquals(4, filtered3.columnsCount()); + + double[] obj1 = {3, 4, 1, 2}; + double[] obj2 = {4, 1, 2, 3}; + double[][] restObjs = new double[][] {obj1, obj2}; + int[][] restObjIndxInSortedOrderPerFeatures = new int[][] { + {0, 1}, //feature 0 + {1, 0}, //feature 1 + {0, 1}, //feature 2 + {0, 1}, //feature 3 + }; + + for (int featureId = 0; featureId < filtered1.columnsCount(); featureId++) { + for (int k = 0; k < filtered1.rowsCount(); k++) { + int objId = restObjIndxInSortedOrderPerFeatures[featureId][k]; + double[] obj = restObjs[objId]; + assertArrayEquals(obj, filtered1.featuresInSortedOrder(k, featureId), 0.01); + } + } + + 
for (int featureId = 0; featureId < filtered2.columnsCount(); featureId++) { + for (int k = 0; k < filtered2.rowsCount(); k++) { + assertArrayEquals(obj1, filtered2.featuresInSortedOrder(k, featureId), 0.01); + assertArrayEquals(obj1, filtered3.featuresInSortedOrder(k, featureId), 0.01); + } + } + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/44098bc6/modules/ml/src/test/java/org/apache/ignite/ml/tree/impurity/gini/GiniImpurityMeasureCalculatorTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/tree/impurity/gini/GiniImpurityMeasureCalculatorTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/tree/impurity/gini/GiniImpurityMeasureCalculatorTest.java index afd81e8..a328bd7 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/tree/impurity/gini/GiniImpurityMeasureCalculatorTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/tree/impurity/gini/GiniImpurityMeasureCalculatorTest.java @@ -17,11 +17,14 @@ package org.apache.ignite.ml.tree.impurity.gini; +import java.util.Arrays; import java.util.HashMap; import java.util.Map; import org.apache.ignite.ml.tree.data.DecisionTreeData; import org.apache.ignite.ml.tree.impurity.util.StepFunction; import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; import static junit.framework.TestCase.assertEquals; import static org.junit.Assert.assertArrayEquals; @@ -29,7 +32,21 @@ import static org.junit.Assert.assertArrayEquals; /** * Tests for {@link GiniImpurityMeasureCalculator}. */ +@RunWith(Parameterized.class) public class GiniImpurityMeasureCalculatorTest { + /** Parameters. */ + @Parameterized.Parameters(name = "Use index {0}") + public static Iterable<Boolean[]> data() { + return Arrays.asList( + new Boolean[] {true}, + new Boolean[] {false} + ); + } + + /** Use index. 
*/ + @Parameterized.Parameter + public boolean useIndex; + /** */ @Test public void testCalculate() { @@ -39,9 +56,9 @@ public class GiniImpurityMeasureCalculatorTest { Map<Double, Integer> encoder = new HashMap<>(); encoder.put(0.0, 0); encoder.put(1.0, 1); - GiniImpurityMeasureCalculator calculator = new GiniImpurityMeasureCalculator(encoder); + GiniImpurityMeasureCalculator calculator = new GiniImpurityMeasureCalculator(encoder, useIndex); - StepFunction<GiniImpurityMeasure>[] impurity = calculator.calculate(new DecisionTreeData(data, labels)); + StepFunction<GiniImpurityMeasure>[] impurity = calculator.calculate(new DecisionTreeData(data, labels, useIndex), fs -> true, 0); assertEquals(2, impurity.length); @@ -71,9 +88,9 @@ public class GiniImpurityMeasureCalculatorTest { Map<Double, Integer> encoder = new HashMap<>(); encoder.put(0.0, 0); encoder.put(1.0, 1); - GiniImpurityMeasureCalculator calculator = new GiniImpurityMeasureCalculator(encoder); + GiniImpurityMeasureCalculator calculator = new GiniImpurityMeasureCalculator(encoder, useIndex); - StepFunction<GiniImpurityMeasure>[] impurity = calculator.calculate(new DecisionTreeData(data, labels)); + StepFunction<GiniImpurityMeasure>[] impurity = calculator.calculate(new DecisionTreeData(data, labels, useIndex), fs -> true, 0); assertEquals(1, impurity.length); @@ -94,7 +111,7 @@ public class GiniImpurityMeasureCalculatorTest { encoder.put(1.0, 1); encoder.put(2.0, 2); - GiniImpurityMeasureCalculator calculator = new GiniImpurityMeasureCalculator(encoder); + GiniImpurityMeasureCalculator calculator = new GiniImpurityMeasureCalculator(encoder, useIndex); assertEquals(0, calculator.getLabelCode(0.0)); assertEquals(1, calculator.getLabelCode(1.0)); http://git-wip-us.apache.org/repos/asf/ignite/blob/44098bc6/modules/ml/src/test/java/org/apache/ignite/ml/tree/impurity/mse/MSEImpurityMeasureCalculatorTest.java ---------------------------------------------------------------------- diff --git 
a/modules/ml/src/test/java/org/apache/ignite/ml/tree/impurity/mse/MSEImpurityMeasureCalculatorTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/tree/impurity/mse/MSEImpurityMeasureCalculatorTest.java index 510c18f..82b3805 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/tree/impurity/mse/MSEImpurityMeasureCalculatorTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/tree/impurity/mse/MSEImpurityMeasureCalculatorTest.java @@ -17,9 +17,12 @@ package org.apache.ignite.ml.tree.impurity.mse; +import java.util.Arrays; import org.apache.ignite.ml.tree.data.DecisionTreeData; import org.apache.ignite.ml.tree.impurity.util.StepFunction; import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; import static junit.framework.TestCase.assertEquals; import static org.junit.Assert.assertArrayEquals; @@ -27,16 +30,30 @@ import static org.junit.Assert.assertArrayEquals; /** * Tests for {@link MSEImpurityMeasureCalculator}. */ +@RunWith(Parameterized.class) public class MSEImpurityMeasureCalculatorTest { + /** Parameters. */ + @Parameterized.Parameters(name = "Use index {0}") + public static Iterable<Boolean[]> data() { + return Arrays.asList( + new Boolean[] {true}, + new Boolean[] {false} + ); + } + + /** Use index. */ + @Parameterized.Parameter + public boolean useIndex; + /** */ @Test public void testCalculate() { double[][] data = new double[][]{{0, 2}, {1, 1}, {2, 0}, {3, 3}}; double[] labels = new double[]{1, 2, 2, 1}; - MSEImpurityMeasureCalculator calculator = new MSEImpurityMeasureCalculator(); + MSEImpurityMeasureCalculator calculator = new MSEImpurityMeasureCalculator(useIndex); - StepFunction<MSEImpurityMeasure>[] impurity = calculator.calculate(new DecisionTreeData(data, labels)); + StepFunction<MSEImpurityMeasure>[] impurity = calculator.calculate(new DecisionTreeData(data, labels, useIndex), fs -> true, 0); assertEquals(2, impurity.length);
