IGNITE-7007: Decision tree code cleanup. This closes #3084.
Project: http://git-wip-us.apache.org/repos/asf/ignite/repo Commit: http://git-wip-us.apache.org/repos/asf/ignite/commit/a29fe352 Tree: http://git-wip-us.apache.org/repos/asf/ignite/tree/a29fe352 Diff: http://git-wip-us.apache.org/repos/asf/ignite/diff/a29fe352 Branch: refs/heads/ignite-zk Commit: a29fe352de4fa3f66f471a4315fff097fe06c786 Parents: 3979e6a Author: artemmalykh <[email protected]> Authored: Fri Dec 1 20:54:59 2017 +0300 Committer: Yury Babak <[email protected]> Committed: Fri Dec 1 20:54:59 2017 +0300 ---------------------------------------------------------------------- .../ignite/ml/math/distributed/CacheUtils.java | 2 -- .../columnbased/ColumnDecisionTreeTrainer.java | 33 +++++++++++++------- .../org/apache/ignite/ml/util/MnistUtils.java | 17 +++++----- .../java/org/apache/ignite/ml/util/Utils.java | 6 ++-- .../ml/trees/ColumnDecisionTreeTrainerTest.java | 3 +- .../ColumnDecisionTreeTrainerBenchmark.java | 31 +++++++++--------- 6 files changed, 51 insertions(+), 41 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/ignite/blob/a29fe352/modules/ml/src/main/java/org/apache/ignite/ml/math/distributed/CacheUtils.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/math/distributed/CacheUtils.java b/modules/ml/src/main/java/org/apache/ignite/ml/math/distributed/CacheUtils.java index 6baa865..9ca167c 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/math/distributed/CacheUtils.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/math/distributed/CacheUtils.java @@ -484,9 +484,7 @@ public class CacheUtils { m.put(k, v); } - long before = System.currentTimeMillis(); cache.putAll(m); - System.out.println("PutAll took: " + (System.currentTimeMillis() - before)); }); } 
http://git-wip-us.apache.org/repos/asf/ignite/blob/a29fe352/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/ColumnDecisionTreeTrainer.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/ColumnDecisionTreeTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/ColumnDecisionTreeTrainer.java index 32e33f3..fec0a83 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/ColumnDecisionTreeTrainer.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/ColumnDecisionTreeTrainer.java @@ -26,6 +26,7 @@ import java.util.HashMap; import java.util.LinkedList; import java.util.List; import java.util.Map; +import java.util.Optional; import java.util.Set; import java.util.UUID; import java.util.concurrent.ConcurrentHashMap; @@ -37,12 +38,12 @@ import java.util.stream.Stream; import javax.cache.Cache; import org.apache.ignite.Ignite; import org.apache.ignite.IgniteCache; +import org.apache.ignite.IgniteLogger; import org.apache.ignite.Ignition; import org.apache.ignite.cache.CachePeekMode; import org.apache.ignite.cache.affinity.Affinity; import org.apache.ignite.cluster.ClusterNode; import org.apache.ignite.internal.processors.cache.CacheEntryImpl; -import org.apache.ignite.internal.util.typedef.X; import org.apache.ignite.lang.IgniteBiTuple; import org.apache.ignite.ml.Trainer; import org.apache.ignite.ml.math.Vector; @@ -115,6 +116,9 @@ public class ColumnDecisionTreeTrainer<D extends ContinuousRegionInfo> implement /** Ignite instance. */ private final Ignite ignite; + /** Logger */ + private final IgniteLogger log; + /** * Construct {@link ColumnDecisionTreeTrainer}. 
* @@ -135,6 +139,7 @@ public class ColumnDecisionTreeTrainer<D extends ContinuousRegionInfo> implement this.categoricalCalculatorProvider = categoricalCalculatorProvider; this.regCalc = regCalc; this.ignite = ignite; + this.log = ignite.log(); } /** @@ -329,7 +334,8 @@ public class ColumnDecisionTreeTrainer<D extends ContinuousRegionInfo> implement regsCnt++; - X.println(">>> Globally best: " + best.info + " idx time: " + findBestRegIdx + ", calculate best: " + findBestSplit + " fi: " + best.featureIdx + ", regs: " + regsCnt); + if (log.isDebugEnabled()) + log.debug("Globally best: " + best.info + " idx time: " + findBestRegIdx + ", calculate best: " + findBestSplit + " fi: " + best.featureIdx + ", regs: " + regsCnt); // Request bitset for split region. int ind = best.info.regionIndex(); @@ -361,8 +367,10 @@ public class ColumnDecisionTreeTrainer<D extends ContinuousRegionInfo> implement if (d > curDepth) { curDepth = d; - X.println(">>> Depth: " + curDepth); - X.println(">>> Cache size: " + prjsCache.size(CachePeekMode.PRIMARY)); + if (log.isDebugEnabled()) { + log.debug("Depth: " + curDepth); + log.debug("Cache size: " + prjsCache.size(CachePeekMode.PRIMARY)); + } } before = System.currentTimeMillis(); @@ -415,16 +423,19 @@ public class ColumnDecisionTreeTrainer<D extends ContinuousRegionInfo> implement }, bestRegsKeys); - X.println(">>> Update of projs cache took " + (System.currentTimeMillis() - before)); + if (log.isDebugEnabled()) + log.debug("Update of projections cache time: " + (System.currentTimeMillis() - before)); before = System.currentTimeMillis(); updateSplitCache(ind, rc, featuresCnt, ig -> i -> input.affinityKey(i, ig), uuid); - X.println(">>> Update of split cache took " + (System.currentTimeMillis() - before)); + if (log.isDebugEnabled()) + log.debug("Update of split cache time: " + (System.currentTimeMillis() - before)); } else { - X.println(">>> Best feature index: " + bestFeatureIdx + ", best infoGain " + bestInfoGain); + if 
(log.isDebugEnabled()) + log.debug("Best split [bestFeatureIdx=" + bestFeatureIdx + ", bestInfoGain=" + bestInfoGain + "]"); break; } } @@ -541,15 +552,15 @@ public class ColumnDecisionTreeTrainer<D extends ContinuousRegionInfo> implement double[] values = ctx.values(fIdx, ign); double[] labels = ctx.labels(); - IgniteBiTuple<Integer, Double> max = toCompare.entrySet().stream(). + Optional<IgniteBiTuple<Integer, Double>> max = toCompare.entrySet().stream(). map(ent -> { SplitInfo bestSplit = ctx.featureProcessor(fIdx).findBestSplit(ent.getValue(), values, labels, ent.getKey()); return new IgniteBiTuple<>(ent.getKey(), bestSplit != null ? bestSplit.infoGain() : Double.NEGATIVE_INFINITY); }). - max(Comparator.comparingDouble(IgniteBiTuple::get2)). - get(); + max(Comparator.comparingDouble(IgniteBiTuple::get2)); - return Stream.of(new CacheEntryImpl<>(e.getKey(), max)); + return max.<Stream<Cache.Entry<SplitKey, IgniteBiTuple<Integer, Double>>>> + map(objects -> Stream.of(new CacheEntryImpl<>(e.getKey(), objects))).orElseGet(Stream::empty); }, () -> IntStream.range(0, featuresCnt).mapToObj(fIdx -> SplitCache.key(fIdx, affinity.apply(ignite).apply(fIdx), trainingUUID)).collect(Collectors.toSet()) ); http://git-wip-us.apache.org/repos/asf/ignite/blob/a29fe352/modules/ml/src/main/java/org/apache/ignite/ml/util/MnistUtils.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/util/MnistUtils.java b/modules/ml/src/main/java/org/apache/ignite/ml/util/MnistUtils.java index d69781e..a3f1d21 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/util/MnistUtils.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/util/MnistUtils.java @@ -25,6 +25,7 @@ import java.util.Collections; import java.util.List; import java.util.Random; import java.util.stream.Stream; +import org.apache.ignite.IgniteException; import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector; /** @@ -40,14 +41,14 
@@ public class MnistUtils { * @param rnd Random numbers generatror. * @param cnt Count of samples to read. * @return Stream of MNIST samples. - * @throws IOException + * @throws IOException In case of exception. */ public static Stream<DenseLocalOnHeapVector> mnist(String imagesPath, String labelsPath, Random rnd, int cnt) throws IOException { FileInputStream isImages = new FileInputStream(imagesPath); FileInputStream isLabels = new FileInputStream(labelsPath); - int magic = read4Bytes(isImages); // Skip magic number. + read4Bytes(isImages); // Skip magic number. int numOfImages = read4Bytes(isImages); int imgHeight = read4Bytes(isImages); int imgWidth = read4Bytes(isImages); @@ -57,10 +58,6 @@ public class MnistUtils { int numOfPixels = imgHeight * imgWidth; - System.out.println("Magic: " + magic); - System.out.println("Num of images: " + numOfImages); - System.out.println("Num of pixels: " + numOfPixels); - double[][] vecs = new double[numOfImages][numOfPixels + 1]; for (int imgNum = 0; imgNum < numOfImages; imgNum++) { @@ -88,7 +85,7 @@ public class MnistUtils { * @param outPath Path to output path. * @param rnd Random numbers generator. * @param cnt Count of samples to read. - * @throws IOException + * @throws IOException In case of exception. */ public static void asLIBSVM(String imagesPath, String labelsPath, String outPath, Random rnd, int cnt) throws IOException { @@ -109,7 +106,7 @@ public class MnistUtils { } catch (IOException e) { - e.printStackTrace(); + throw new IgniteException("Error while converting to LIBSVM."); } }); } @@ -119,9 +116,9 @@ public class MnistUtils { * Utility method for reading 4 bytes from input stream. * * @param is Input stream. - * @throws IOException + * @throws IOException In case of exception. 
*/ private static int read4Bytes(FileInputStream is) throws IOException { return (is.read() << 24) | (is.read() << 16) | (is.read() << 8) | (is.read()); } -} +} \ No newline at end of file http://git-wip-us.apache.org/repos/asf/ignite/blob/a29fe352/modules/ml/src/main/java/org/apache/ignite/ml/util/Utils.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/util/Utils.java b/modules/ml/src/main/java/org/apache/ignite/ml/util/Utils.java index bb779e3..847b1f1 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/util/Utils.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/util/Utils.java @@ -22,6 +22,7 @@ import java.io.ByteArrayOutputStream; import java.io.IOException; import java.io.ObjectInputStream; import java.io.ObjectOutputStream; +import org.apache.ignite.IgniteException; /** * Class with various utility methods. @@ -34,8 +35,9 @@ public class Utils { * @param <T> Class of original object; * @return Deep copy of original object. 
*/ + @SuppressWarnings({"unchecked"}) public static <T> T copy(T orig) { - Object obj = null; + Object obj; try { ByteArrayOutputStream baos = new ByteArrayOutputStream(); @@ -50,7 +52,7 @@ public class Utils { obj = in.readObject(); } catch (IOException | ClassNotFoundException e) { - e.printStackTrace(); + throw new IgniteException("Couldn't copy the object."); } return (T)obj; http://git-wip-us.apache.org/repos/asf/ignite/blob/a29fe352/modules/ml/src/test/java/org/apache/ignite/ml/trees/ColumnDecisionTreeTrainerTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/trees/ColumnDecisionTreeTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/trees/ColumnDecisionTreeTrainerTest.java index 2b03b47..929ded9 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/trees/ColumnDecisionTreeTrainerTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/trees/ColumnDecisionTreeTrainerTest.java @@ -26,6 +26,7 @@ import java.util.Random; import java.util.stream.Collectors; import java.util.stream.DoubleStream; import org.apache.ignite.internal.util.IgniteUtils; +import org.apache.ignite.internal.util.typedef.X; import org.apache.ignite.lang.IgniteBiTuple; import org.apache.ignite.ml.math.StorageConstants; import org.apache.ignite.ml.math.Tracer; @@ -183,7 +184,7 @@ public class ColumnDecisionTreeTrainerTest extends BaseDecisionTreeTest { byRegion.keySet().forEach(k -> { LabeledVectorDouble sp = byRegion.get(k).get(0); Tracer.showAscii(sp.vector()); - System.out.println("Act: " + sp.label() + " " + " pred: " + mdl.predict(sp.vector())); + X.println("Actual and predicted vectors [act=" + sp.label() + " " + ", pred=" + mdl.predict(sp.vector()) + "]"); assert mdl.predict(sp.vector()) == sp.doubleLabel(); }); } http://git-wip-us.apache.org/repos/asf/ignite/blob/a29fe352/modules/ml/src/test/java/org/apache/ignite/ml/trees/performance/ColumnDecisionTreeTrainerBenchmark.java 
---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/trees/performance/ColumnDecisionTreeTrainerBenchmark.java b/modules/ml/src/test/java/org/apache/ignite/ml/trees/performance/ColumnDecisionTreeTrainerBenchmark.java index 4e7cc24..7ca5d38 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/trees/performance/ColumnDecisionTreeTrainerBenchmark.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/trees/performance/ColumnDecisionTreeTrainerBenchmark.java @@ -45,6 +45,7 @@ import org.apache.ignite.configuration.CacheConfiguration; import org.apache.ignite.configuration.IgniteConfiguration; import org.apache.ignite.internal.processors.cache.GridCacheProcessor; import org.apache.ignite.internal.util.IgniteUtils; +import org.apache.ignite.internal.util.typedef.X; import org.apache.ignite.lang.IgniteBiTuple; import org.apache.ignite.ml.Model; import org.apache.ignite.ml.estimators.Estimators; @@ -163,14 +164,14 @@ public class ColumnDecisionTreeTrainerBenchmark extends BaseDecisionTreeTest { ColumnDecisionTreeTrainer<GiniSplitCalculator.GiniData> trainer = new ColumnDecisionTreeTrainer<>(10, ContinuousSplitCalculators.GINI.apply(ignite), RegionCalculators.GINI, RegionCalculators.MOST_COMMON, ignite); - System.out.println(">>> Training started"); + X.println("Training started."); long before = System.currentTimeMillis(); DecisionTreeModel mdl = trainer.train(new BiIndexedCacheColumnDecisionTreeTrainerInput(cache, new HashMap<>(), ptsCnt, featCnt)); - System.out.println(">>> Training finished in " + (System.currentTimeMillis() - before)); + X.println("Training finished in " + (System.currentTimeMillis() - before)); IgniteTriFunction<Model<Vector, Double>, Stream<IgniteBiTuple<Vector, Double>>, Function<Double, Double>, Double> mse = Estimators.errorsPercentage(); Double accuracy = mse.apply(mdl, testMnistStream.map(v -> new IgniteBiTuple<>(v.viewPart(0, featCnt), v.getX(featCnt))), 
Function.identity()); - System.out.println(">>> Errs percentage: " + accuracy); + X.println("Errors percentage: " + accuracy); Assert.assertEquals(0, SplitCache.getOrCreate(ignite).size()); Assert.assertEquals(0, FeaturesCache.getOrCreate(ignite).size()); @@ -204,14 +205,14 @@ public class ColumnDecisionTreeTrainerBenchmark extends BaseDecisionTreeTest { ColumnDecisionTreeTrainer<GiniSplitCalculator.GiniData> trainer = new ColumnDecisionTreeTrainer<>(10, ContinuousSplitCalculators.GINI.apply(ignite), RegionCalculators.GINI, RegionCalculators.MOST_COMMON, ignite); - System.out.println(">>> Training started"); + X.println("Training started"); long before = System.currentTimeMillis(); DecisionTreeModel mdl = trainer.train(new MatrixColumnDecisionTreeTrainerInput(m, new HashMap<>())); - System.out.println(">>> Training finished in " + (System.currentTimeMillis() - before)); + X.println("Training finished in " + (System.currentTimeMillis() - before)); IgniteTriFunction<Model<Vector, Double>, Stream<IgniteBiTuple<Vector, Double>>, Function<Double, Double>, Double> mse = Estimators.errorsPercentage(); Double accuracy = mse.apply(mdl, testMnistStream.map(v -> new IgniteBiTuple<>(v.viewPart(0, featCnt), v.getX(featCnt))), Function.identity()); - System.out.println(">>> Errs percentage: " + accuracy); + X.println("Errors percentage: " + accuracy); Assert.assertEquals(0, SplitCache.getOrCreate(ignite).size()); Assert.assertEquals(0, FeaturesCache.getOrCreate(ignite).size()); @@ -252,10 +253,10 @@ public class ColumnDecisionTreeTrainerBenchmark extends BaseDecisionTreeTest { SparseDistributedMatrixStorage sto = (SparseDistributedMatrixStorage)m.getStorage(); long before = System.currentTimeMillis(); - System.out.println(">>> Batch loading started..."); + X.println("Batch loading started..."); loadVectorsIntoSparseDistributedMatrixCache(sto.cache().getName(), sto.getUUID(), gen. 
points(ptsPerReg, (i, rn) -> i).map(IgniteBiTuple::get2).iterator(), featCnt + 1); - System.out.println(">>> Batch loading took " + (System.currentTimeMillis() - before) + " ms."); + X.println("Batch loading took " + (System.currentTimeMillis() - before) + " ms."); for (IgniteBiTuple<Integer, DenseLocalOnHeapVector> bt : lst) { byRegion.putIfAbsent(bt.get1(), new LinkedList<>()); @@ -268,12 +269,12 @@ public class ColumnDecisionTreeTrainerBenchmark extends BaseDecisionTreeTest { before = System.currentTimeMillis(); DecisionTreeModel mdl = trainer.train(new MatrixColumnDecisionTreeTrainerInput(m, catsInfo)); - System.out.println(">>> Took time(ms): " + (System.currentTimeMillis() - before)); + X.println("Training took: " + (System.currentTimeMillis() - before) + " ms."); byRegion.keySet().forEach(k -> { LabeledVectorDouble sp = byRegion.get(k).get(0); Tracer.showAscii(sp.vector()); - System.out.println("Prediction: " + mdl.predict(sp.vector()) + "label: " + sp.doubleLabel()); + X.println("Predicted value and label [pred=" + mdl.predict(sp.vector()) + ", label=" + sp.doubleLabel() + "]"); assert mdl.predict(sp.vector()) == sp.doubleLabel(); }); } @@ -307,16 +308,16 @@ public class ColumnDecisionTreeTrainerBenchmark extends BaseDecisionTreeTest { ColumnDecisionTreeTrainer<VarianceSplitCalculator.VarianceData> trainer = new ColumnDecisionTreeTrainer<>(10, ContinuousSplitCalculators.VARIANCE, RegionCalculators.VARIANCE, regCalc, ignite); - System.out.println(">>> Training started"); + X.println("Training started."); long before = System.currentTimeMillis(); DecisionTreeModel mdl = trainer.train(new MatrixColumnDecisionTreeTrainerInput(m, new HashMap<>())); - System.out.println(">>> Training finished in " + (System.currentTimeMillis() - before)); + X.println("Training finished in: " + (System.currentTimeMillis() - before) + " ms."); Vector[] testVectors = vecsFromRanges(ranges, featCnt, defRng, new Random(123L), 20, f1); IgniteTriFunction<Model<Vector, Double>, 
Stream<IgniteBiTuple<Vector, Double>>, Function<Double, Double>, Double> mse = Estimators.MSE(); Double accuracy = mse.apply(mdl, Arrays.stream(testVectors).map(v -> new IgniteBiTuple<>(v.viewPart(0, featCnt), v.getX(featCnt))), Function.identity()); - System.out.println(">>> MSE: " + accuracy); + X.println("MSE: " + accuracy); } /** @@ -358,7 +359,7 @@ public class ColumnDecisionTreeTrainerBenchmark extends BaseDecisionTreeTest { for (int i = 0; i < vectorSize; i++) batch.get(i).put(sampleIdx, next.getX(i)); - System.out.println(sampleIdx); + X.println("Sample index: " + sampleIdx); if (sampleIdx % batchSize == 0) { batch.keySet().forEach(fi -> streamer.addData(new SparseMatrixKey(fi, uuid, fi), batch.get(fi))); IntStream.range(0, vectorSize).forEach(i -> batch.put(i, new HashMap<>())); @@ -396,7 +397,7 @@ public class ColumnDecisionTreeTrainerBenchmark extends BaseDecisionTreeTest { sampleIdx++; if (sampleIdx % 1000 == 0) - System.out.println(">>> Loaded " + sampleIdx + " vectors."); + System.out.println("Loaded: " + sampleIdx + " vectors."); } } }
