Alexey Zinoviev created IGNITE-12331: ----------------------------------------
Summary: [ML] ML Preprocessing doesn't work on SQL Tables Key: IGNITE-12331 URL: https://issues.apache.org/jira/browse/IGNITE-12331 Project: Ignite Issue Type: Bug Components: ml Reporter: Alexey Zinoviev Assignee: Alexey Zinoviev {code:java} /* * Licensed to the Apache Software Foundation (ASF) under one or more * contributor license agreements. See the NOTICE file distributed with * this work for additional information regarding copyright ownership. * The ASF licenses this file to You under the Apache License, Version 2.0 * (the "License"); you may not use this file except in compliance with * the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.apache.ignite.examples.ml.tutorial.sql; import java.util.List; import org.apache.ignite.Ignite; import org.apache.ignite.IgniteCache; import org.apache.ignite.Ignition; import org.apache.ignite.cache.query.QueryCursor; import org.apache.ignite.cache.query.SqlFieldsQuery; import org.apache.ignite.configuration.CacheConfiguration; import org.apache.ignite.internal.util.IgniteUtils; import org.apache.ignite.ml.dataset.feature.extractor.Vectorizer; import org.apache.ignite.ml.dataset.feature.extractor.impl.BinaryObjectVectorizer; import org.apache.ignite.ml.math.primitives.vector.Vector; import org.apache.ignite.ml.math.primitives.vector.VectorUtils; import org.apache.ignite.ml.preprocessing.Preprocessor; import org.apache.ignite.ml.preprocessing.minmaxscaling.MinMaxScalerTrainer; import org.apache.ignite.ml.preprocessing.normalization.NormalizationTrainer; import org.apache.ignite.ml.sql.SqlDatasetBuilder; import org.apache.ignite.ml.tree.DecisionTreeClassificationTrainer; import org.apache.ignite.ml.tree.DecisionTreeNode; /** * Example of using distributed {@link DecisionTreeClassificationTrainer} on a data stored in SQL table. */ public class PreprocessingAndTrainingSQLTableExample { /** * Dummy cache name. */ private static final String DUMMY_CACHE_NAME = "dummy_cache"; /** * Training data. */ private static final String TRAIN_DATA_RES = "examples/src/main/resources/datasets/titanic_train.csv"; /** * Test data. */ private static final String TEST_DATA_RES = "examples/src/main/resources/datasets/titanic_test.csv"; /** * Run example. */ public static void main(String[] args) { System.out.println(">>> Decision tree classification trainer example started."); // Start ignite grid. try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) { System.out.println(">>> Ignite grid started."); // Dummy cache is required to perform SQL queries. CacheConfiguration<?, ?> cacheCfg = new CacheConfiguration<>(DUMMY_CACHE_NAME) .setSqlSchema("PUBLIC"); IgniteCache<?, ?> cache = null; try { cache = ignite.getOrCreateCache(cacheCfg); System.out.println(">>> Creating table with training data..."); cache.query(new SqlFieldsQuery("create table titanic_train (\n" + " passengerid int primary key,\n" + " survived int,\n" + " pclass int,\n" + " name varchar(255),\n" + " sex varchar(255),\n" + " age float,\n" + " sibsp int,\n" + " parch int,\n" + " ticket varchar(255),\n" + " fare float,\n" + " cabin varchar(255),\n" + " embarked varchar(255)\n" + ") with \"template=partitioned\";")).getAll(); System.out.println(">>> Filling training data..."); cache.query(new SqlFieldsQuery("insert into titanic_train select * from csvread('" + IgniteUtils.resolveIgnitePath(TRAIN_DATA_RES).getAbsolutePath() + "')")).getAll(); System.out.println(">>> Creating table with test data..."); cache.query(new SqlFieldsQuery("create table titanic_test (\n" + " passengerid int primary key,\n" + " pclass int,\n" + " name varchar(255),\n" + " sex varchar(255),\n" + " age float,\n" + " sibsp int,\n" + " parch int,\n" + " ticket varchar(255),\n" + " fare float,\n" + " cabin varchar(255),\n" + " embarked varchar(255)\n" + ") with \"template=partitioned\";")).getAll(); System.out.println(">>> Filling training data..."); cache.query(new SqlFieldsQuery("insert into titanic_test select * from csvread('" + IgniteUtils.resolveIgnitePath(TEST_DATA_RES).getAbsolutePath() + "')")).getAll(); System.out.println(">>> Prepare trainer..."); DecisionTreeClassificationTrainer trainer = new DecisionTreeClassificationTrainer(4, 0); System.out.println(">>> Perform training..."); Vectorizer vectorizer = new BinaryObjectVectorizer<>("pclass", "age", "sibsp", "parch", "fare") .withFeature("sex", BinaryObjectVectorizer.Mapping.create().map("male", 1.0).defaultValue(0.0)) .labeled("survived"); Preprocessor minMaxScalerPreprocessor = new MinMaxScalerTrainer() .fit( ignite, cache, vectorizer ); Preprocessor normalizationPreprocessor = new NormalizationTrainer() .withP(1) .fit( ignite, cache, minMaxScalerPreprocessor ); DecisionTreeNode mdl = trainer.fit( new SqlDatasetBuilder(ignite, "SQL_PUBLIC_TITANIK_TRAIN"), normalizationPreprocessor ); System.out.println(">>> Perform inference..."); try (QueryCursor<List<?>> cursor = cache.query(new SqlFieldsQuery("select " + "pclass, " + "sex, " + "age, " + "sibsp, " + "parch, " + "fare from titanic_test"))) { for (List<?> passenger : cursor) { Vector input = VectorUtils.of(new Double[] { asDouble(passenger.get(0)), "male".equals(passenger.get(1)) ? 1.0 : 0.0, asDouble(passenger.get(2)), asDouble(passenger.get(3)), asDouble(passenger.get(4)), asDouble(passenger.get(5)) }); double prediction = mdl.predict(input); System.out.printf("Passenger %s will %s.\n", passenger, prediction == 0 ? "die" : "survive"); } } System.out.println(">>> Example completed."); } finally { cache.query(new SqlFieldsQuery("DROP TABLE titanic_train")); cache.query(new SqlFieldsQuery("DROP TABLE titanic_test")); cache.destroy(); } } finally { System.out.flush(); } } /** * Converts specified number into double. * * @param obj Number. * @param <T> Type of number. * @return Double. */ private static <T extends Number> Double asDouble(Object obj) { if (obj == null) return null; if (obj instanceof Number) { Number num = (Number)obj; return num.doubleValue(); } throw new IllegalArgumentException("Object is expected to be a number [obj=" + obj + "]"); } } {code} -- This message was sent by Atlassian Jira (v8.3.4#803005)