[ https://issues.apache.org/jira/browse/IGNITE-12331?page=com.atlassian.jira.plugin.system.issuetabpanels:all-tabpanel ]
Alexey Zinoviev updated IGNITE-12331: ------------------------------------- Affects Version/s: (was: 3.0) 2.8 > [ML] ML Preprocessing doesn't work on SQL Tables > ------------------------------------------------ > > Key: IGNITE-12331 > URL: https://issues.apache.org/jira/browse/IGNITE-12331 > Project: Ignite > Issue Type: Bug > Components: ml > Affects Versions: 2.8 > Reporter: Alexey Zinoviev > Assignee: Alexey Zinoviev > Priority: Major > Fix For: 3.0 > > > {code:java} > /* > * Licensed to the Apache Software Foundation (ASF) under one or more > * contributor license agreements. See the NOTICE file distributed with > * this work for additional information regarding copyright ownership. > * The ASF licenses this file to You under the Apache License, Version 2.0 > * (the "License"); you may not use this file except in compliance with > * the License. You may obtain a copy of the License at > * > * http://www.apache.org/licenses/LICENSE-2.0 > * > * Unless required by applicable law or agreed to in writing, software > * distributed under the License is distributed on an "AS IS" BASIS, > * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. > * See the License for the specific language governing permissions and > * limitations under the License. 
> */ > package org.apache.ignite.examples.ml.tutorial.sql; > import java.util.List; > import org.apache.ignite.Ignite; > import org.apache.ignite.IgniteCache; > import org.apache.ignite.Ignition; > import org.apache.ignite.cache.query.QueryCursor; > import org.apache.ignite.cache.query.SqlFieldsQuery; > import org.apache.ignite.configuration.CacheConfiguration; > import org.apache.ignite.internal.util.IgniteUtils; > import org.apache.ignite.ml.dataset.feature.extractor.Vectorizer; > import > org.apache.ignite.ml.dataset.feature.extractor.impl.BinaryObjectVectorizer; > import org.apache.ignite.ml.math.primitives.vector.Vector; > import org.apache.ignite.ml.math.primitives.vector.VectorUtils; > import org.apache.ignite.ml.preprocessing.Preprocessor; > import org.apache.ignite.ml.preprocessing.minmaxscaling.MinMaxScalerTrainer; > import org.apache.ignite.ml.preprocessing.normalization.NormalizationTrainer; > import org.apache.ignite.ml.sql.SqlDatasetBuilder; > import org.apache.ignite.ml.tree.DecisionTreeClassificationTrainer; > import org.apache.ignite.ml.tree.DecisionTreeNode; > /** > * Example of using distributed {@link DecisionTreeClassificationTrainer} on > a data stored in SQL table. > */ > public class PreprocessingAndTrainingSQLTableExample { /* NOTE(review): reproducer attached to IGNITE-12331. main() creates the titanic_train and titanic_test SQL tables, fills them from CSV via csvread, fits min-max scaling and normalization preprocessors, trains a decision tree through SqlDatasetBuilder, prints a prediction for every titanic_test row, and finally drops both tables and destroys the dummy cache. */ > /** > * Dummy cache name. > */ > private static final String DUMMY_CACHE_NAME = "dummy_cache"; > /** > * Training data. > */ > private static final String TRAIN_DATA_RES = > "examples/src/main/resources/datasets/titanic_train.csv"; > /** > * Test data. > */ > private static final String TEST_DATA_RES = > "examples/src/main/resources/datasets/titanic_test.csv"; > /** > * Run example. > */ > public static void main(String[] args) { > System.out.println(">>> Decision tree classification trainer example > started."); > // Start ignite grid. > try (Ignite ignite = > Ignition.start("examples/config/example-ignite.xml")) { > System.out.println(">>> Ignite grid started."); > // Dummy cache is required to perform SQL queries. 
> CacheConfiguration<?, ?> cacheCfg = new > CacheConfiguration<>(DUMMY_CACHE_NAME) > .setSqlSchema("PUBLIC"); > IgniteCache<?, ?> cache = null; > try { > cache = ignite.getOrCreateCache(cacheCfg); /* NOTE(review): cache is the dummy cache (DUMMY_CACHE_NAME). The tables below are created and filled through SQL only; nothing here puts rows into this cache directly, so presumably it holds no Titanic rows itself. */ > System.out.println(">>> Creating table with training > data..."); > cache.query(new SqlFieldsQuery("create table titanic_train > (\n" + > " passengerid int primary key,\n" + > " survived int,\n" + > " pclass int,\n" + > " name varchar(255),\n" + > " sex varchar(255),\n" + > " age float,\n" + > " sibsp int,\n" + > " parch int,\n" + > " ticket varchar(255),\n" + > " fare float,\n" + > " cabin varchar(255),\n" + > " embarked varchar(255)\n" + > ") with \"template=partitioned\";")).getAll(); > System.out.println(">>> Filling training data..."); > cache.query(new SqlFieldsQuery("insert into titanic_train > select * from csvread('" + > > IgniteUtils.resolveIgnitePath(TRAIN_DATA_RES).getAbsolutePath() + > "')")).getAll(); > System.out.println(">>> Creating table with test data..."); > cache.query(new SqlFieldsQuery("create table titanic_test > (\n" + > " passengerid int primary key,\n" + > " pclass int,\n" + > " name varchar(255),\n" + > " sex varchar(255),\n" + > " age float,\n" + > " sibsp int,\n" + > " parch int,\n" + > " ticket varchar(255),\n" + > " fare float,\n" + > " cabin varchar(255),\n" + > " embarked varchar(255)\n" + > ") with \"template=partitioned\";")).getAll(); /* NOTE(review): the next message says training data, but the insert after it loads TEST_DATA_RES into titanic_test - likely a copy-paste slip. */ > System.out.println(">>> Filling training data..."); > cache.query(new SqlFieldsQuery("insert into titanic_test > select * from csvread('" + > > IgniteUtils.resolveIgnitePath(TEST_DATA_RES).getAbsolutePath() + > "')")).getAll(); > System.out.println(">>> Prepare trainer..."); > DecisionTreeClassificationTrainer trainer = new > DecisionTreeClassificationTrainer(4, 0); > System.out.println(">>> Perform training..."); > Vectorizer vectorizer = new > BinaryObjectVectorizer<>("pclass", "age", "sibsp", "parch", "fare") > .withFeature("sex", > BinaryObjectVectorizer.Mapping.create().map("male", 
1.0).defaultValue(0.0)) > .labeled("survived"); /* NOTE(review): both preprocessing trainers below are fitted on cache (the dummy cache), not on the titanic_train table, while the decision tree further down reads the table through SqlDatasetBuilder. Fitting on a cache that presumably holds no rows looks like the failure IGNITE-12331 reports; Vectorizer and Preprocessor are also used as raw types here. */ > Preprocessor minMaxScalerPreprocessor = new > MinMaxScalerTrainer() > .fit( > ignite, > cache, > vectorizer > ); > Preprocessor normalizationPreprocessor = new > NormalizationTrainer() > .withP(1) > .fit( > ignite, > cache, > minMaxScalerPreprocessor > ); /* NOTE(review): the cache name passed to SqlDatasetBuilder below is SQL_PUBLIC_TITANIK_TRAIN, which does not match the table created above (titanic_train). Presumably SQL_PUBLIC_TITANIC_TRAIN was intended - confirm against the cache name Ignite generates for that table. */ > DecisionTreeNode mdl = trainer.fit( > new SqlDatasetBuilder(ignite, "SQL_PUBLIC_TITANIK_TRAIN"), > normalizationPreprocessor > ); > System.out.println(">>> Perform inference..."); > try (QueryCursor<List<?>> cursor = cache.query(new > SqlFieldsQuery("select " + > "pclass, " + > "sex, " + > "age, " + > "sibsp, " + > "parch, " + > "fare from titanic_test"))) { > for (List<?> passenger : cursor) { /* NOTE(review): the vector below is built from raw column values and passed straight to mdl.predict, without the min-max scaling and normalization the model was trained on. Also confirm that this feature order (pclass, sex, age, sibsp, parch, fare) matches the vectorizer, which declares sex separately via withFeature after the other five columns. */ > Vector input = VectorUtils.of(new Double[] { > asDouble(passenger.get(0)), > "male".equals(passenger.get(1)) ? 1.0 : 0.0, > asDouble(passenger.get(2)), > asDouble(passenger.get(3)), > asDouble(passenger.get(4)), > asDouble(passenger.get(5)) > }); > double prediction = mdl.predict(input); > System.out.printf("Passenger %s will %s.\n", > passenger, prediction == 0 ? "die" : "survive"); > } > } > System.out.println(">>> Example completed."); > } > finally { /* NOTE(review): if getOrCreateCache throws, cache is still null here, and the NullPointerException from cache.query replaces the original failure. The DROP TABLE statements presumably also fail when a CREATE TABLE above did not succeed; consider a null check and DROP TABLE IF EXISTS. */ > cache.query(new SqlFieldsQuery("DROP TABLE titanic_train")); > cache.query(new SqlFieldsQuery("DROP TABLE titanic_test")); > cache.destroy(); > } > } > finally { > System.out.flush(); > } > } > /** > * Converts specified number into double. > * > * @param obj Number. > * @param <T> Type of number. > * @return Double. > */ /* NOTE(review): the type parameter T is never used because the argument is typed Object, so the @param <T> tag above documents nothing. Behavior: returns null for null, the double value for any Number, and throws IllegalArgumentException otherwise. */ > private static <T extends Number> Double asDouble(Object obj) { > if (obj == null) > return null; > if (obj instanceof Number) { > Number num = (Number)obj; > return num.doubleValue(); > } > throw new IllegalArgumentException("Object is expected to be a number > [obj=" + obj + "]"); > } > } > {code} -- This message was sent by Atlassian Jira (v8.3.4#803005)