IGNITE-7590: Fixed tree example. This closes #3459
Project: http://git-wip-us.apache.org/repos/asf/ignite/repo Commit: http://git-wip-us.apache.org/repos/asf/ignite/commit/a9d40a70 Tree: http://git-wip-us.apache.org/repos/asf/ignite/tree/a9d40a70 Diff: http://git-wip-us.apache.org/repos/asf/ignite/diff/a9d40a70 Branch: refs/heads/ignite-7485-2 Commit: a9d40a708bffeb2c3762e5cfbc00f9c33b02cd4d Parents: 2e43749 Author: artemmalykh <amal...@gridgain.com> Authored: Thu Feb 1 12:43:02 2018 +0300 Committer: Yury Babak <yba...@gridgain.com> Committed: Thu Feb 1 12:43:02 2018 +0300 ---------------------------------------------------------------------- .../examples/ml/MLExamplesCommonArgs.java | 31 ++ .../examples/ml/trees/DecisionTreesExample.java | 354 +++++++++++++++++++ .../ignite/examples/ml/trees/MNISTExample.java | 261 -------------- .../testsuites/IgniteExamplesMLTestSuite.java | 5 +- 4 files changed, 388 insertions(+), 263 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/ignite/blob/a9d40a70/examples/src/main/java/org/apache/ignite/examples/ml/MLExamplesCommonArgs.java ---------------------------------------------------------------------- diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/MLExamplesCommonArgs.java b/examples/src/main/java/org/apache/ignite/examples/ml/MLExamplesCommonArgs.java new file mode 100644 index 0000000..701894b --- /dev/null +++ b/examples/src/main/java/org/apache/ignite/examples/ml/MLExamplesCommonArgs.java @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.examples.ml; + +/** + * Some common arguments for examples in ML module. + */ +public class MLExamplesCommonArgs { + /** + * Unattended argument. + */ + public static String UNATTENDED = "unattended"; + + /** Empty args for ML examples. */ + public static final String[] EMPTY_ARGS_ML = new String[] {"--" + UNATTENDED}; +} http://git-wip-us.apache.org/repos/asf/ignite/blob/a9d40a70/examples/src/main/java/org/apache/ignite/examples/ml/trees/DecisionTreesExample.java ---------------------------------------------------------------------- diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/trees/DecisionTreesExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/trees/DecisionTreesExample.java new file mode 100644 index 0000000..3860e8e --- /dev/null +++ b/examples/src/main/java/org/apache/ignite/examples/ml/trees/DecisionTreesExample.java @@ -0,0 +1,354 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.examples.ml.trees; + +import java.io.FileInputStream; +import java.io.FileOutputStream; +import java.io.IOException; +import java.net.URL; +import java.nio.channels.Channels; +import java.nio.channels.ReadableByteChannel; +import java.util.Collection; +import java.util.HashMap; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Random; +import java.util.Scanner; +import java.util.function.Function; +import java.util.stream.Collectors; +import java.util.stream.Stream; +import java.util.zip.GZIPInputStream; +import org.apache.commons.cli.BasicParser; +import org.apache.commons.cli.CommandLine; +import org.apache.commons.cli.CommandLineParser; +import org.apache.commons.cli.Option; +import org.apache.commons.cli.OptionBuilder; +import org.apache.commons.cli.Options; +import org.apache.commons.cli.ParseException; +import org.apache.ignite.Ignite; +import org.apache.ignite.IgniteCache; +import org.apache.ignite.IgniteDataStreamer; +import org.apache.ignite.Ignition; +import org.apache.ignite.cache.CacheWriteSynchronizationMode; +import org.apache.ignite.configuration.CacheConfiguration; +import org.apache.ignite.examples.ExampleNodeStartup; +import org.apache.ignite.examples.ml.MLExamplesCommonArgs; +import org.apache.ignite.internal.util.IgniteUtils; +import org.apache.ignite.lang.IgniteBiTuple; +import org.apache.ignite.ml.Model; +import org.apache.ignite.ml.estimators.Estimators; +import org.apache.ignite.ml.math.Vector; +import 
org.apache.ignite.ml.math.functions.IgniteTriFunction; +import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector; +import org.apache.ignite.ml.trees.models.DecisionTreeModel; +import org.apache.ignite.ml.trees.trainers.columnbased.BiIndex; +import org.apache.ignite.ml.trees.trainers.columnbased.BiIndexedCacheColumnDecisionTreeTrainerInput; +import org.apache.ignite.ml.trees.trainers.columnbased.ColumnDecisionTreeTrainer; +import org.apache.ignite.ml.trees.trainers.columnbased.contsplitcalcs.ContinuousSplitCalculators; +import org.apache.ignite.ml.trees.trainers.columnbased.contsplitcalcs.GiniSplitCalculator; +import org.apache.ignite.ml.trees.trainers.columnbased.regcalcs.RegionCalculators; +import org.apache.ignite.ml.util.MnistUtils; +import org.jetbrains.annotations.NotNull; + +/** + * <p> + * Example of usage of decision trees algorithm for MNIST dataset + * (it can be found here: http://yann.lecun.com/exdb/mnist/). </p> + * <p> + * Remote nodes should always be started with special configuration file which + * enables P2P class loading: {@code 'ignite.{sh|bat} examples/config/example-ignite.xml'}.</p> + * <p> + * Alternatively you can run {@link ExampleNodeStartup} in another JVM which will start node + * with {@code examples/config/example-ignite.xml} configuration.</p> + * <p> + * It is recommended to start at least one node prior to launching this example if you intend + * to run it with default memory settings.</p> + * <p> + * This example should be run with program arguments, for example + * -cfg examples/config/example-ignite.xml.</p> + * <p> + * -cfg specifies path to a config path.</p> + */ +public class DecisionTreesExample { + /** Name of parameter specifying path of Ignite config. */ + private static final String CONFIG = "cfg"; + + /** Default config path. */ + private static final String DEFAULT_CONFIG = "examples/config/example-ignite.xml"; + + /** + * Folder in which MNIST dataset is expected. 
+ */ + private static String MNIST_DIR = "examples/src/main/resources/"; + + /** + * Key for MNIST training images. + */ + private static String MNIST_TRAIN_IMAGES = "train_images"; + + /** + * Key for MNIST training labels. + */ + private static String MNIST_TRAIN_LABELS = "train_labels"; + + /** + * Key for MNIST test images. + */ + private static String MNIST_TEST_IMAGES = "test_images"; + + /** + * Key for MNIST test labels. + */ + private static String MNIST_TEST_LABELS = "test_labels"; + + /** + * Launches example. + * + * @param args Program arguments. + */ + public static void main(String[] args) throws IOException { + System.out.println(">>> Decision trees example started."); + + String igniteCfgPath; + + CommandLineParser parser = new BasicParser(); + + String trainingImagesPath; + String trainingLabelsPath; + + String testImagesPath; + String testLabelsPath; + + Map<String, String> mnistPaths = new HashMap<>(); + + mnistPaths.put(MNIST_TRAIN_IMAGES, "train-images-idx3-ubyte"); + mnistPaths.put(MNIST_TRAIN_LABELS, "train-labels-idx1-ubyte"); + mnistPaths.put(MNIST_TEST_IMAGES, "t10k-images-idx3-ubyte"); + mnistPaths.put(MNIST_TEST_LABELS, "t10k-labels-idx1-ubyte"); + + try { + // Parse the command line arguments. 
+ CommandLine line = parser.parse(buildOptions(), args); + + if (line.hasOption(MLExamplesCommonArgs.UNATTENDED)) { + System.out.println(">>> Skipped example execution because 'unattended' mode is used."); + System.out.println(">>> Decision trees example finished."); + return; + } + + igniteCfgPath = line.getOptionValue(CONFIG, DEFAULT_CONFIG); + } + catch (ParseException e) { + e.printStackTrace(); + return; + } + + if (!getMNIST(mnistPaths.values())) { + System.out.println(">>> You should have MNIST dataset in " + MNIST_DIR + " to run this example."); + return; + } + + trainingImagesPath = Objects.requireNonNull(IgniteUtils.resolveIgnitePath(MNIST_DIR + "/" + + mnistPaths.get(MNIST_TRAIN_IMAGES))).getPath(); + trainingLabelsPath = Objects.requireNonNull(IgniteUtils.resolveIgnitePath(MNIST_DIR + "/" + + mnistPaths.get(MNIST_TRAIN_LABELS))).getPath(); + testImagesPath = Objects.requireNonNull(IgniteUtils.resolveIgnitePath(MNIST_DIR + "/" + + mnistPaths.get(MNIST_TEST_IMAGES))).getPath(); + testLabelsPath = Objects.requireNonNull(IgniteUtils.resolveIgnitePath(MNIST_DIR + "/" + + mnistPaths.get(MNIST_TEST_LABELS))).getPath(); + + try (Ignite ignite = Ignition.start(igniteCfgPath)) { + IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName()); + + int ptsCnt = 60000; + int featCnt = 28 * 28; + + Stream<DenseLocalOnHeapVector> trainingMnistStream = MnistUtils.mnist(trainingImagesPath, trainingLabelsPath, + new Random(123L), ptsCnt); + + Stream<DenseLocalOnHeapVector> testMnistStream = MnistUtils.mnist(testImagesPath, testLabelsPath, + new Random(123L), 10_000); + + IgniteCache<BiIndex, Double> cache = createBiIndexedCache(ignite); + + loadVectorsIntoBiIndexedCache(cache.getName(), trainingMnistStream.iterator(), featCnt + 1, ignite); + + ColumnDecisionTreeTrainer<GiniSplitCalculator.GiniData> trainer = new ColumnDecisionTreeTrainer<>(10, + ContinuousSplitCalculators.GINI.apply(ignite), + RegionCalculators.GINI, + RegionCalculators.MOST_COMMON, + 
ignite); + + System.out.println(">>> Training started"); + long before = System.currentTimeMillis(); + DecisionTreeModel mdl = trainer.train(new BiIndexedCacheColumnDecisionTreeTrainerInput(cache, new HashMap<>(), ptsCnt, featCnt)); + System.out.println(">>> Training finished in " + (System.currentTimeMillis() - before)); + + IgniteTriFunction<Model<Vector, Double>, Stream<IgniteBiTuple<Vector, Double>>, Function<Double, Double>, Double> mse = + Estimators.errorsPercentage(); + + Double accuracy = mse.apply(mdl, testMnistStream.map(v -> + new IgniteBiTuple<>(v.viewPart(0, featCnt), v.getX(featCnt))), Function.identity()); + + System.out.println(">>> Errs percentage: " + accuracy); + } + catch (IOException e) { + e.printStackTrace(); + } + + System.out.println(">>> Decision trees example finished."); + } + + /** + * Get MNIST dataset. Value of predicate 'MNIST dataset is present in expected folder' is returned. + * + * @param mnistFileNames File names of MNIST dataset. + * @return Value of predicate 'MNIST dataset is present in expected folder'. + * @throws IOException In case of file system errors. + */ + private static boolean getMNIST(Collection<String> mnistFileNames) throws IOException { + List<String> missing = mnistFileNames.stream(). + filter(f -> IgniteUtils.resolveIgnitePath(MNIST_DIR + "/" + f) == null). + collect(Collectors.toList()); + + if (!missing.isEmpty()) { + System.out.println(">>> You have not fully downloaded MNIST dataset in directory " + MNIST_DIR + + ", do you want it to be downloaded? 
[y]/n"); + Scanner s = new Scanner(System.in); + String str = s.nextLine(); + + if (!str.isEmpty() && !str.toLowerCase().equals("y")) + return false; + } + + for (String s : missing) { + String f = s + ".gz"; + System.out.println(">>> Downloading " + f + "..."); + URL website = new URL("http://yann.lecun.com/exdb/mnist/" + f); + ReadableByteChannel rbc = Channels.newChannel(website.openStream()); + FileOutputStream fos = new FileOutputStream(MNIST_DIR + "/" + f); + fos.getChannel().transferFrom(rbc, 0, Long.MAX_VALUE); + System.out.println(">>> Done."); + + System.out.println(">>> Unzipping " + f + "..."); + unzip(MNIST_DIR + "/" + f, MNIST_DIR + "/" + s); + + System.out.println(">>> Deleting gzip " + f + ", status: " + + Objects.requireNonNull(IgniteUtils.resolveIgnitePath(MNIST_DIR + "/" + f)).delete()); + + System.out.println(">>> Done."); + } + + return true; + } + + /** + * Unzip file located in {@code input} to {@code output}. + * + * @param input Input file path. + * @param output Output file path. + * @throws IOException In case of file system errors. + */ + private static void unzip(String input, String output) throws IOException { + byte[] buf = new byte[1024]; + + try (GZIPInputStream gis = new GZIPInputStream(new FileInputStream(input)); + FileOutputStream out = new FileOutputStream(output)) { + int sz; + while ((sz = gis.read(buf)) > 0) + out.write(buf, 0, sz); + } + } + + /** + * Build cli options. 
+ */ + @NotNull private static Options buildOptions() { + Options options = new Options(); + + Option cfgOpt = OptionBuilder + .withArgName(CONFIG) + .withLongOpt(CONFIG) + .hasArg() + .withDescription("Path to the config.") + .isRequired(false).create(); + + Option unattended = OptionBuilder + .withArgName(MLExamplesCommonArgs.UNATTENDED) + .withLongOpt(MLExamplesCommonArgs.UNATTENDED) + .withDescription("Is example run unattended.") + .isRequired(false).create(); + + options.addOption(cfgOpt); + options.addOption(unattended); + + return options; + } + + /** + * Creates cache where data for training is stored. + * + * @param ignite Ignite instance. + * @return cache where data for training is stored. + */ + private static IgniteCache<BiIndex, Double> createBiIndexedCache(Ignite ignite) { + CacheConfiguration<BiIndex, Double> cfg = new CacheConfiguration<>(); + + // Write to primary. + cfg.setWriteSynchronizationMode(CacheWriteSynchronizationMode.PRIMARY_SYNC); + + // No copying of values. + cfg.setCopyOnRead(false); + + cfg.setName("TMP_BI_INDEXED_CACHE"); + + return ignite.getOrCreateCache(cfg); + } + + /** + * Loads vectors into cache. + * + * @param cacheName Name of cache. + * @param vectorsIter Iterator over vectors to load. + * @param vectorSize Size of vector. + * @param ignite Ignite instance. + */ + private static void loadVectorsIntoBiIndexedCache(String cacheName, Iterator<? 
extends Vector> vectorsIter, + int vectorSize, Ignite ignite) { + try (IgniteDataStreamer<BiIndex, Double> streamer = + ignite.dataStreamer(cacheName)) { + int sampleIdx = 0; + + streamer.perNodeBufferSize(10000); + + while (vectorsIter.hasNext()) { + org.apache.ignite.ml.math.Vector next = vectorsIter.next(); + + for (int i = 0; i < vectorSize; i++) + streamer.addData(new BiIndex(sampleIdx, i), next.getX(i)); + + sampleIdx++; + + if (sampleIdx % 1000 == 0) + System.out.println(">>> Loaded " + sampleIdx + " vectors."); + } + } + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/a9d40a70/examples/src/main/java/org/apache/ignite/examples/ml/trees/MNISTExample.java ---------------------------------------------------------------------- diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/trees/MNISTExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/trees/MNISTExample.java deleted file mode 100644 index 6ff121e..0000000 --- a/examples/src/main/java/org/apache/ignite/examples/ml/trees/MNISTExample.java +++ /dev/null @@ -1,261 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.ignite.examples.ml.trees; - -import java.io.IOException; -import java.util.HashMap; -import java.util.Iterator; -import java.util.Random; -import java.util.function.Function; -import java.util.stream.Stream; -import org.apache.commons.cli.BasicParser; -import org.apache.commons.cli.CommandLine; -import org.apache.commons.cli.CommandLineParser; -import org.apache.commons.cli.Option; -import org.apache.commons.cli.OptionBuilder; -import org.apache.commons.cli.Options; -import org.apache.commons.cli.ParseException; -import org.apache.ignite.Ignite; -import org.apache.ignite.IgniteCache; -import org.apache.ignite.IgniteDataStreamer; -import org.apache.ignite.Ignition; -import org.apache.ignite.cache.CacheAtomicityMode; -import org.apache.ignite.cache.CacheMode; -import org.apache.ignite.cache.CacheWriteSynchronizationMode; -import org.apache.ignite.configuration.CacheConfiguration; -import org.apache.ignite.examples.ExampleNodeStartup; -import org.apache.ignite.internal.util.IgniteUtils; -import org.apache.ignite.lang.IgniteBiTuple; -import org.apache.ignite.ml.Model; -import org.apache.ignite.ml.estimators.Estimators; -import org.apache.ignite.ml.math.Vector; -import org.apache.ignite.ml.math.functions.IgniteTriFunction; -import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector; -import org.apache.ignite.ml.trees.models.DecisionTreeModel; -import org.apache.ignite.ml.trees.trainers.columnbased.BiIndex; -import org.apache.ignite.ml.trees.trainers.columnbased.BiIndexedCacheColumnDecisionTreeTrainerInput; -import org.apache.ignite.ml.trees.trainers.columnbased.ColumnDecisionTreeTrainer; -import org.apache.ignite.ml.trees.trainers.columnbased.contsplitcalcs.ContinuousSplitCalculators; -import org.apache.ignite.ml.trees.trainers.columnbased.contsplitcalcs.GiniSplitCalculator; -import org.apache.ignite.ml.trees.trainers.columnbased.regcalcs.RegionCalculators; -import org.apache.ignite.ml.util.MnistUtils; -import 
org.jetbrains.annotations.NotNull; - -/** - * <p> - * Example of usage of decision trees algorithm for MNIST dataset - * (it can be found here: http://yann.lecun.com/exdb/mnist/). </p> - * <p> - * Remote nodes should always be started with special configuration file which - * enables P2P class loading: {@code 'ignite.{sh|bat} examples/config/example-ignite.xml'}.</p> - * <p> - * Alternatively you can run {@link ExampleNodeStartup} in another JVM which will start node - * with {@code examples/config/example-ignite.xml} configuration.</p> - * <p> - * It is recommended to start at least one node prior to launching this example if you intend - * to run it with default memory settings.</p> - * <p> - * This example should with program arguments, for example - * -ts_i /path/to/train-images-idx3-ubyte - * -ts_l /path/to/train-labels-idx1-ubyte - * -tss_i /path/to/t10k-images-idx3-ubyte - * -tss_l /path/to/t10k-labels-idx1-ubyte - * -cfg examples/config/example-ignite.xml.</p> - * <p> - * -ts_i specifies path to training set images of MNIST; - * -ts_l specifies path to training set labels of MNIST; - * -tss_i specifies path to test set images of MNIST; - * -tss_l specifies path to test set labels of MNIST; - * -cfg specifies path to a config path.</p> - */ -public class MNISTExample { - /** Name of parameter specifying path to training set images. */ - private static final String MNIST_TRAINING_IMAGES_PATH = "ts_i"; - - /** Name of parameter specifying path to training set labels. */ - private static final String MNIST_TRAINING_LABELS_PATH = "ts_l"; - - /** Name of parameter specifying path to test set images. */ - private static final String MNIST_TEST_IMAGES_PATH = "tss_i"; - - /** Name of parameter specifying path to test set labels. */ - private static final String MNIST_TEST_LABELS_PATH = "tss_l"; - - /** Name of parameter specifying path of Ignite config. */ - private static final String CONFIG = "cfg"; - - /** Default config path. 
*/ - private static final String DEFAULT_CONFIG = "examples/config/example-ignite.xml"; - - /** - * Launches example. - * - * @param args Program arguments. - */ - public static void main(String[] args) { - String igniteCfgPath; - - CommandLineParser parser = new BasicParser(); - - String trainingImagesPath; - String trainingLabelsPath; - - String testImagesPath; - String testLabelsPath; - - try { - // Parse the command line arguments. - CommandLine line = parser.parse(buildOptions(), args); - - trainingImagesPath = line.getOptionValue(MNIST_TRAINING_IMAGES_PATH); - trainingLabelsPath = line.getOptionValue(MNIST_TRAINING_LABELS_PATH); - testImagesPath = line.getOptionValue(MNIST_TEST_IMAGES_PATH); - testLabelsPath = line.getOptionValue(MNIST_TEST_LABELS_PATH); - igniteCfgPath = line.getOptionValue(CONFIG, DEFAULT_CONFIG); - } - catch (ParseException e) { - e.printStackTrace(); - return; - } - - try (Ignite ignite = Ignition.start(igniteCfgPath)) { - IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName()); - - int ptsCnt = 60000; - int featCnt = 28 * 28; - - Stream<DenseLocalOnHeapVector> trainingMnistStream = MnistUtils.mnist(trainingImagesPath, trainingLabelsPath, new Random(123L), ptsCnt); - Stream<DenseLocalOnHeapVector> testMnistStream = MnistUtils.mnist(testImagesPath, testLabelsPath, new Random(123L), 10_000); - - IgniteCache<BiIndex, Double> cache = createBiIndexedCache(ignite); - - loadVectorsIntoBiIndexedCache(cache.getName(), trainingMnistStream.iterator(), featCnt + 1, ignite); - - ColumnDecisionTreeTrainer<GiniSplitCalculator.GiniData> trainer = - new ColumnDecisionTreeTrainer<>(10, ContinuousSplitCalculators.GINI.apply(ignite), RegionCalculators.GINI, RegionCalculators.MOST_COMMON, ignite); - - System.out.println(">>> Training started"); - long before = System.currentTimeMillis(); - DecisionTreeModel mdl = trainer.train(new BiIndexedCacheColumnDecisionTreeTrainerInput(cache, new HashMap<>(), ptsCnt, featCnt)); - 
System.out.println(">>> Training finished in " + (System.currentTimeMillis() - before)); - - IgniteTriFunction<Model<Vector, Double>, Stream<IgniteBiTuple<Vector, Double>>, Function<Double, Double>, Double> mse = Estimators.errorsPercentage(); - Double accuracy = mse.apply(mdl, testMnistStream.map(v -> new IgniteBiTuple<>(v.viewPart(0, featCnt), v.getX(featCnt))), Function.identity()); - System.out.println(">>> Errs percentage: " + accuracy); - } - catch (IOException e) { - e.printStackTrace(); - } - } - - /** - * Build cli options. - */ - @NotNull private static Options buildOptions() { - Options options = new Options(); - - Option trsImagesPathOpt = OptionBuilder.withArgName(MNIST_TRAINING_IMAGES_PATH).withLongOpt(MNIST_TRAINING_IMAGES_PATH).hasArg() - .withDescription("Path to the MNIST training set.") - .isRequired(true).create(); - - Option trsLabelsPathOpt = OptionBuilder.withArgName(MNIST_TRAINING_LABELS_PATH).withLongOpt(MNIST_TRAINING_LABELS_PATH).hasArg() - .withDescription("Path to the MNIST training set.") - .isRequired(true).create(); - - Option tssImagesPathOpt = OptionBuilder.withArgName(MNIST_TEST_IMAGES_PATH).withLongOpt(MNIST_TEST_IMAGES_PATH).hasArg() - .withDescription("Path to the MNIST test set.") - .isRequired(true).create(); - - Option tssLabelsPathOpt = OptionBuilder.withArgName(MNIST_TEST_LABELS_PATH).withLongOpt(MNIST_TEST_LABELS_PATH).hasArg() - .withDescription("Path to the MNIST test set.") - .isRequired(true).create(); - - Option configOpt = OptionBuilder.withArgName(CONFIG).withLongOpt(CONFIG).hasArg() - .withDescription("Path to the config.") - .isRequired(false).create(); - - options.addOption(trsImagesPathOpt); - options.addOption(trsLabelsPathOpt); - options.addOption(tssImagesPathOpt); - options.addOption(tssLabelsPathOpt); - options.addOption(configOpt); - - return options; - } - - /** - * Creates cache where data for training is stored. - * - * @param ignite Ignite instance. - * @return cache where data for training is stored. 
- */ - private static IgniteCache<BiIndex, Double> createBiIndexedCache(Ignite ignite) { - CacheConfiguration<BiIndex, Double> cfg = new CacheConfiguration<>(); - - // Write to primary. - cfg.setWriteSynchronizationMode(CacheWriteSynchronizationMode.PRIMARY_SYNC); - - // Atomic transactions only. - cfg.setAtomicityMode(CacheAtomicityMode.ATOMIC); - - // No eviction. - cfg.setEvictionPolicy(null); - - // No copying of values. - cfg.setCopyOnRead(false); - - // Cache is partitioned. - cfg.setCacheMode(CacheMode.PARTITIONED); - - cfg.setBackups(0); - - cfg.setName("TMP_BI_INDEXED_CACHE"); - - return ignite.getOrCreateCache(cfg); - } - - /** - * Loads vectors into cache. - * - * @param cacheName Name of cache. - * @param vectorsIterator Iterator over vectors to load. - * @param vectorSize Size of vector. - * @param ignite Ignite instance. - */ - private static void loadVectorsIntoBiIndexedCache(String cacheName, Iterator<? extends Vector> vectorsIterator, - int vectorSize, Ignite ignite) { - try (IgniteDataStreamer<BiIndex, Double> streamer = - ignite.dataStreamer(cacheName)) { - int sampleIdx = 0; - - streamer.perNodeBufferSize(10000); - - while (vectorsIterator.hasNext()) { - org.apache.ignite.ml.math.Vector next = vectorsIterator.next(); - - for (int i = 0; i < vectorSize; i++) - streamer.addData(new BiIndex(sampleIdx, i), next.getX(i)); - - sampleIdx++; - - if (sampleIdx % 1000 == 0) - System.out.println("Loaded " + sampleIdx + " vectors."); - } - } - } -} http://git-wip-us.apache.org/repos/asf/ignite/blob/a9d40a70/examples/src/test/java/org/apache/ignite/testsuites/IgniteExamplesMLTestSuite.java ---------------------------------------------------------------------- diff --git a/examples/src/test/java/org/apache/ignite/testsuites/IgniteExamplesMLTestSuite.java b/examples/src/test/java/org/apache/ignite/testsuites/IgniteExamplesMLTestSuite.java index d2f40e6..df85f1a 100644 --- a/examples/src/test/java/org/apache/ignite/testsuites/IgniteExamplesMLTestSuite.java +++ 
b/examples/src/test/java/org/apache/ignite/testsuites/IgniteExamplesMLTestSuite.java @@ -30,6 +30,7 @@ import javassist.CtClass; import javassist.CtNewMethod; import javassist.NotFoundException; import junit.framework.TestSuite; +import org.apache.ignite.examples.ml.MLExamplesCommonArgs; import org.apache.ignite.testframework.GridTestUtils; import org.apache.ignite.testframework.junits.common.GridAbstractExamplesTest; @@ -85,8 +86,8 @@ public class IgniteExamplesMLTestSuite extends TestSuite { cl.addMethod(CtNewMethod.make("public void testExample() { " + exampleCls.getCanonicalName() + ".main(" - + GridAbstractExamplesTest.class.getName() - + ".EMPTY_ARGS); }", cl)); + + MLExamplesCommonArgs.class.getName() + + ".EMPTY_ARGS_ML); }", cl)); return cl.toClass(); }