IGNITE-10289: [ML] Import models from XGBoost This closes #5533
Project: http://git-wip-us.apache.org/repos/asf/ignite/repo Commit: http://git-wip-us.apache.org/repos/asf/ignite/commit/4ae29fca Tree: http://git-wip-us.apache.org/repos/asf/ignite/tree/4ae29fca Diff: http://git-wip-us.apache.org/repos/asf/ignite/diff/4ae29fca Branch: refs/heads/ignite-10639 Commit: 4ae29fca74e814321c8448e00e9d3fdcd54733aa Parents: ece5869 Author: Anton Dmitriev <[email protected]> Authored: Tue Dec 18 18:31:26 2018 +0300 Committer: Yury Babak <[email protected]> Committed: Tue Dec 18 18:31:26 2018 +0300 ---------------------------------------------------------------------- examples/pom.xml | 6 + .../TensorFlowDistributedInferenceExample.java | 2 +- .../TensorFlowLocalInferenceExample.java | 2 +- .../TensorFlowThreadedInferenceExample.java | 2 +- .../ml/xgboost/XGBoostModelParserExample.java | 99 ++ .../examples/ml/xgboost/package-info.java | 22 + .../resources/datasets/agaricus-test-data.txt | 1611 ++++++++++++++++++ .../datasets/agaricus-test-expected-results.txt | 1611 ++++++++++++++++++ .../resources/ml/mnist_tf_model/saved_model.pb | Bin 37185 -> 0 bytes .../variables/variables.data-00000-of-00001 | Bin 13098544 -> 0 bytes .../ml/mnist_tf_model/variables/variables.index | Bin 410 -> 0 bytes .../models/mnist_tf_model/saved_model.pb | Bin 0 -> 37185 bytes .../variables/variables.data-00000-of-00001 | Bin 0 -> 13098544 bytes .../mnist_tf_model/variables/variables.index | Bin 0 -> 410 bytes .../resources/models/xgboost/agaricus-model.txt | 714 ++++++++ .../IgniteDistributedInfModelBuilder.java | 7 +- modules/ml/xgboost-model-parser/pom.xml | 55 + .../ignite/ml/xgboost/MapBasedXGObject.java | 61 + .../apache/ignite/ml/xgboost/XGLeafNode.java | 38 + .../org/apache/ignite/ml/xgboost/XGModel.java | 53 + .../org/apache/ignite/ml/xgboost/XGNode.java | 29 + .../org/apache/ignite/ml/xgboost/XGObject.java | 33 + .../apache/ignite/ml/xgboost/XGSplitNode.java | 74 + .../apache/ignite/ml/xgboost/package-info.java | 22 + .../xgboost/parser/XGBoostModelBaseVisitor.java | 78 + .../ml/xgboost/parser/XGBoostModelLexer.java | 210 +++ .../ml/xgboost/parser/XGBoostModelListener.java | 98 ++ .../ml/xgboost/parser/XGBoostModelParser.java | 966 +++++++++++ .../ml/xgboost/parser/XGBoostModelVisitor.java | 71 + .../ignite/ml/xgboost/parser/XGModelParser.java | 87 + .../ignite/ml/xgboost/parser/package-info.java | 22 + .../xgboost/parser/visitor/XGModelVisitor.java | 43 + .../xgboost/parser/visitor/XGTreeVisitor.java | 82 + .../ml/xgboost/parser/visitor/package-info.java | 22 + .../ml/xgboost/IgniteMLXGBoostTestSuite.java | 31 + .../xgboost/parser/XGBoostModelParserTest.java | 85 + .../test/resources/datasets/agaricus-model.txt | 714 ++++++++ .../resources/datasets/agaricus-test-data.txt | 1611 ++++++++++++++++++ .../datasets/agaricus-test-expected-results.txt | 1611 ++++++++++++++++++ parent/pom.xml | 2 +- pom.xml | 1 + 41 files changed, 10168 insertions(+), 7 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/ignite/blob/4ae29fca/examples/pom.xml ---------------------------------------------------------------------- diff --git a/examples/pom.xml b/examples/pom.xml index c6b0a5f..44dab95 100644 --- a/examples/pom.xml +++ b/examples/pom.xml @@ -84,6 +84,12 @@ </dependency> <dependency> + <groupId>org.apache.ignite</groupId> + <artifactId>ignite-ml-xgboost-model-parser</artifactId> + <version>${project.version}</version> + </dependency> + + <dependency> <groupId>commons-cli</groupId> <artifactId>commons-cli</artifactId> <version>1.2</version> http://git-wip-us.apache.org/repos/asf/ignite/blob/4ae29fca/examples/src/main/java/org/apache/ignite/examples/ml/inference/TensorFlowDistributedInferenceExample.java ---------------------------------------------------------------------- diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/inference/TensorFlowDistributedInferenceExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/inference/TensorFlowDistributedInferenceExample.java index ec8cac6..48e8df1 100644 --- a/examples/src/main/java/org/apache/ignite/examples/ml/inference/TensorFlowDistributedInferenceExample.java +++ b/examples/src/main/java/org/apache/ignite/examples/ml/inference/TensorFlowDistributedInferenceExample.java @@ -44,7 +44,7 @@ import org.tensorflow.Tensor; */ public class TensorFlowDistributedInferenceExample { /** Path to the directory with saved TensorFlow model. */ - private static final String MODEL_PATH = "examples/src/main/resources/ml/mnist_tf_model"; + private static final String MODEL_PATH = "examples/src/main/resources/models/mnist_tf_model"; /** Path to the MNIST images data. */ private static final String MNIST_IMG_PATH = "examples/src/main/resources/datasets/t10k-images-idx3-ubyte"; http://git-wip-us.apache.org/repos/asf/ignite/blob/4ae29fca/examples/src/main/java/org/apache/ignite/examples/ml/inference/TensorFlowLocalInferenceExample.java ---------------------------------------------------------------------- diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/inference/TensorFlowLocalInferenceExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/inference/TensorFlowLocalInferenceExample.java index 0e79856..c907778 100644 --- a/examples/src/main/java/org/apache/ignite/examples/ml/inference/TensorFlowLocalInferenceExample.java +++ b/examples/src/main/java/org/apache/ignite/examples/ml/inference/TensorFlowLocalInferenceExample.java @@ -38,7 +38,7 @@ import org.tensorflow.Tensor; */ public class TensorFlowLocalInferenceExample { /** Path to the directory with saved TensorFlow model. */ - private static final String MODEL_PATH = "examples/src/main/resources/ml/mnist_tf_model"; + private static final String MODEL_PATH = "examples/src/main/resources/models/mnist_tf_model"; /** Path to the MNIST images data. */ private static final String MNIST_IMG_PATH = "examples/src/main/resources/datasets/t10k-images-idx3-ubyte"; http://git-wip-us.apache.org/repos/asf/ignite/blob/4ae29fca/examples/src/main/java/org/apache/ignite/examples/ml/inference/TensorFlowThreadedInferenceExample.java ---------------------------------------------------------------------- diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/inference/TensorFlowThreadedInferenceExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/inference/TensorFlowThreadedInferenceExample.java index 002e5ae..93dadea 100644 --- a/examples/src/main/java/org/apache/ignite/examples/ml/inference/TensorFlowThreadedInferenceExample.java +++ b/examples/src/main/java/org/apache/ignite/examples/ml/inference/TensorFlowThreadedInferenceExample.java @@ -42,7 +42,7 @@ import org.tensorflow.Tensor; */ public class TensorFlowThreadedInferenceExample { /** Path to the directory with saved TensorFlow model. */ - private static final String MODEL_PATH = "examples/src/main/resources/ml/mnist_tf_model"; + private static final String MODEL_PATH = "examples/src/main/resources/models/mnist_tf_model"; /** Path to the MNIST images data. */ private static final String MNIST_IMG_PATH = "examples/src/main/resources/datasets/t10k-images-idx3-ubyte"; http://git-wip-us.apache.org/repos/asf/ignite/blob/4ae29fca/examples/src/main/java/org/apache/ignite/examples/ml/xgboost/XGBoostModelParserExample.java ---------------------------------------------------------------------- diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/xgboost/XGBoostModelParserExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/xgboost/XGBoostModelParserExample.java new file mode 100644 index 0000000..40f10d8 --- /dev/null +++ b/examples/src/main/java/org/apache/ignite/examples/ml/xgboost/XGBoostModelParserExample.java @@ -0,0 +1,99 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.examples.ml.xgboost; + +import java.io.File; +import java.io.FileNotFoundException; +import java.util.Scanner; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.Future; +import org.apache.ignite.Ignite; +import org.apache.ignite.Ignition; +import org.apache.ignite.internal.util.IgniteUtils; +import org.apache.ignite.ml.inference.InfModel; +import org.apache.ignite.ml.inference.builder.AsyncInfModelBuilder; +import org.apache.ignite.ml.inference.builder.IgniteDistributedInfModelBuilder; +import org.apache.ignite.ml.inference.reader.FileSystemInfModelReader; +import org.apache.ignite.ml.inference.reader.InfModelReader; +import org.apache.ignite.ml.xgboost.MapBasedXGObject; +import org.apache.ignite.ml.xgboost.XGObject; +import org.apache.ignite.ml.xgboost.parser.XGModelParser; + +/** + * This example demonstrates how to import XGBoost model and use imported model for distributed inference in Apache + * Ignite. + */ +public class XGBoostModelParserExample { + /** Test model resource name. */ + private static final String TEST_MODEL_RES = "examples/src/main/resources/models/xgboost/agaricus-model.txt"; + + /** Test data. */ + private static final String TEST_DATA_RES = "examples/src/main/resources/datasets/agaricus-test-data.txt"; + + /** Test expected results. */ + private static final String TEST_ER_RES = "examples/src/main/resources/datasets/agaricus-test-expected-results.txt"; + + /** Parser. */ + private static final XGModelParser parser = new XGModelParser(); + + /** Run example. */ + public static void main(String... args) throws ExecutionException, InterruptedException, FileNotFoundException { + try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) { + File mdlRsrc = IgniteUtils.resolveIgnitePath(TEST_MODEL_RES); + if (mdlRsrc == null) + throw new IllegalArgumentException("File not found [resource_path=" + TEST_MODEL_RES + "]"); + + InfModelReader reader = new FileSystemInfModelReader(mdlRsrc.getPath()); + + AsyncInfModelBuilder mdlBuilder = new IgniteDistributedInfModelBuilder(ignite, 4, 4); + + File testData = IgniteUtils.resolveIgnitePath(TEST_DATA_RES); + if (testData == null) + throw new IllegalArgumentException("File not found [resource_path=" + TEST_DATA_RES + "]"); + + File testExpRes = IgniteUtils.resolveIgnitePath(TEST_ER_RES); + if (testExpRes == null) + throw new IllegalArgumentException("File not found [resource_path=" + TEST_ER_RES + "]"); + + try (InfModel<XGObject, Future<Double>> mdl = mdlBuilder.build(reader, parser); + Scanner testDataScanner = new Scanner(testData); + Scanner testExpResultsScanner = new Scanner(testExpRes)) { + + while (testDataScanner.hasNextLine()) { + String testDataStr = testDataScanner.nextLine(); + String testExpResultsStr = testExpResultsScanner.nextLine(); + + MapBasedXGObject testObj = new MapBasedXGObject(); + + for (String keyValueString : testDataStr.split(" ")) { + String[] keyVal = keyValueString.split(":"); + + if (keyVal.length == 2) + testObj.put("f" + keyVal[0], Double.parseDouble(keyVal[1])); + } + + double prediction = mdl.predict(testObj).get(); + + double expPrediction = Double.parseDouble(testExpResultsStr); + + System.out.println("Expected: " + expPrediction + ", prediction: " + prediction); + } + } + } + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/4ae29fca/examples/src/main/java/org/apache/ignite/examples/ml/xgboost/package-info.java ---------------------------------------------------------------------- diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/xgboost/package-info.java b/examples/src/main/java/org/apache/ignite/examples/ml/xgboost/package-info.java new file mode 100644 index 0000000..3f14631 --- /dev/null +++ b/examples/src/main/java/org/apache/ignite/examples/ml/xgboost/package-info.java @@ -0,0 +1,22 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * <!-- Package description. --> + * XGBoost model inference examples. + */ +package org.apache.ignite.examples.ml.xgboost; \ No newline at end of file
