You are viewing a plain text version of this content. The canonical link for it is here.
Posted to issues@ignite.apache.org by "Alexey Zinoviev (Jira)" <ji...@apache.org> on 2019/10/28 15:55:00 UTC

[jira] [Updated] (IGNITE-12331) [ML] ML Preprocessing doesn't work on SQL Tables

     [ https://issues.apache.org/jira/browse/IGNITE-12331?page=com.atlassian.jira.plugin.system.issuetabpanels:all-tabpanel ]

Alexey Zinoviev updated IGNITE-12331:
-------------------------------------
    Affects Version/s: 3.0

> [ML] ML Preprocessing doesn't work on SQL Tables
> ------------------------------------------------
>
>                 Key: IGNITE-12331
>                 URL: https://issues.apache.org/jira/browse/IGNITE-12331
>             Project: Ignite
>          Issue Type: Bug
>          Components: ml
>    Affects Versions: 3.0
>            Reporter: Alexey Zinoviev
>            Assignee: Alexey Zinoviev
>            Priority: Major
>
> {code:java}
> /*
>  * Licensed to the Apache Software Foundation (ASF) under one or more
>  * contributor license agreements.  See the NOTICE file distributed with
>  * this work for additional information regarding copyright ownership.
>  * The ASF licenses this file to You under the Apache License, Version 2.0
>  * (the "License"); you may not use this file except in compliance with
>  * the License.  You may obtain a copy of the License at
>  *
>  *      http://www.apache.org/licenses/LICENSE-2.0
>  *
>  * Unless required by applicable law or agreed to in writing, software
>  * distributed under the License is distributed on an "AS IS" BASIS,
>  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
>  * See the License for the specific language governing permissions and
>  * limitations under the License.
>  */
> package org.apache.ignite.examples.ml.tutorial.sql;
> import java.util.List;
> import org.apache.ignite.Ignite;
> import org.apache.ignite.IgniteCache;
> import org.apache.ignite.Ignition;
> import org.apache.ignite.cache.query.QueryCursor;
> import org.apache.ignite.cache.query.SqlFieldsQuery;
> import org.apache.ignite.configuration.CacheConfiguration;
> import org.apache.ignite.internal.util.IgniteUtils;
> import org.apache.ignite.ml.dataset.feature.extractor.Vectorizer;
> import org.apache.ignite.ml.dataset.feature.extractor.impl.BinaryObjectVectorizer;
> import org.apache.ignite.ml.math.primitives.vector.Vector;
> import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
> import org.apache.ignite.ml.preprocessing.Preprocessor;
> import org.apache.ignite.ml.preprocessing.minmaxscaling.MinMaxScalerTrainer;
> import org.apache.ignite.ml.preprocessing.normalization.NormalizationTrainer;
> import org.apache.ignite.ml.sql.SqlDatasetBuilder;
> import org.apache.ignite.ml.tree.DecisionTreeClassificationTrainer;
> import org.apache.ignite.ml.tree.DecisionTreeNode;
> /**
>  * Example of using distributed {@link DecisionTreeClassificationTrainer} together with a preprocessing pipeline
>  * ({@link MinMaxScalerTrainer} followed by {@link NormalizationTrainer}) on data stored in SQL tables.
>  * <p>
>  * The example creates the {@code titanic_train} and {@code titanic_test} tables, fills them from CSV files,
>  * trains a decision tree on the training table and prints a prediction for every row of the test table.
>  * Both tables and the dummy cache are removed when the example finishes.
>  */
> public class PreprocessingAndTrainingSQLTableExample {
>     /**
>      * Dummy cache name.
>      */
>     private static final String DUMMY_CACHE_NAME = "dummy_cache";
>     /**
>      * Training data.
>      */
>     private static final String TRAIN_DATA_RES = "examples/src/main/resources/datasets/titanic_train.csv";
>     /**
>      * Test data.
>      */
>     private static final String TEST_DATA_RES = "examples/src/main/resources/datasets/titanic_test.csv";
>     /**
>      * Name of the cache that backs the {@code titanic_train} table. A table created through DDL without an
>      * explicit cache name gets the cache name {@code SQL_<SCHEMA>_<TABLE>}.
>      */
>     private static final String TRAIN_CACHE_NAME = "SQL_PUBLIC_TITANIC_TRAIN";
>     /**
>      * Run example.
>      *
>      * @param args Command line arguments, not used.
>      */
>     public static void main(String[] args) {
>         System.out.println(">>> Decision tree classification trainer example started.");
>         // Start ignite grid.
>         try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
>             System.out.println(">>> Ignite grid started.");
>             // Dummy cache is required to perform SQL queries.
>             CacheConfiguration<?, ?> cacheCfg = new CacheConfiguration<>(DUMMY_CACHE_NAME)
>                 .setSqlSchema("PUBLIC");
>             IgniteCache<?, ?> cache = null;
>             try {
>                 cache = ignite.getOrCreateCache(cacheCfg);
>                 System.out.println(">>> Creating table with training data...");
>                 cache.query(new SqlFieldsQuery("create table titanic_train (\n" +
>                     "    passengerid int primary key,\n" +
>                     "    survived int,\n" +
>                     "    pclass int,\n" +
>                     "    name varchar(255),\n" +
>                     "    sex varchar(255),\n" +
>                     "    age float,\n" +
>                     "    sibsp int,\n" +
>                     "    parch int,\n" +
>                     "    ticket varchar(255),\n" +
>                     "    fare float,\n" +
>                     "    cabin varchar(255),\n" +
>                     "    embarked varchar(255)\n" +
>                     ") with \"template=partitioned\";")).getAll();
>                 System.out.println(">>> Filling training data...");
>                 cache.query(new SqlFieldsQuery("insert into titanic_train select * from csvread('" +
>                     IgniteUtils.resolveIgnitePath(TRAIN_DATA_RES).getAbsolutePath() + "')")).getAll();
>                 System.out.println(">>> Creating table with test data...");
>                 cache.query(new SqlFieldsQuery("create table titanic_test (\n" +
>                     "    passengerid int primary key,\n" +
>                     "    pclass int,\n" +
>                     "    name varchar(255),\n" +
>                     "    sex varchar(255),\n" +
>                     "    age float,\n" +
>                     "    sibsp int,\n" +
>                     "    parch int,\n" +
>                     "    ticket varchar(255),\n" +
>                     "    fare float,\n" +
>                     "    cabin varchar(255),\n" +
>                     "    embarked varchar(255)\n" +
>                     ") with \"template=partitioned\";")).getAll();
>                 System.out.println(">>> Filling test data...");
>                 cache.query(new SqlFieldsQuery("insert into titanic_test select * from csvread('" +
>                     IgniteUtils.resolveIgnitePath(TEST_DATA_RES).getAbsolutePath() + "')")).getAll();
>                 System.out.println(">>> Prepare trainer...");
>                 DecisionTreeClassificationTrainer trainer = new DecisionTreeClassificationTrainer(4, 0);
>                 System.out.println(">>> Perform training...");
>                 // Numeric columns are used as is, "sex" is mapped to 1.0 for "male" and 0.0 otherwise,
>                 // "survived" is the label.
>                 Vectorizer vectorizer = new BinaryObjectVectorizer<>("pclass", "age", "sibsp", "parch", "fare")
>                     .withFeature("sex", BinaryObjectVectorizer.Mapping.create().map("male", 1.0).defaultValue(0.0))
>                     .labeled("survived");
>                 // NOTE(review): both preprocessors are fitted on the dummy cache, while the model below is
>                 // trained on the cache of the SQL table - confirm whether they should be fitted through
>                 // SqlDatasetBuilder as well.
>                 Preprocessor minMaxScalerPreprocessor = new MinMaxScalerTrainer()
>                     .fit(
>                         ignite,
>                         cache,
>                         vectorizer
>                     );
>                 Preprocessor normalizationPreprocessor = new NormalizationTrainer()
>                     .withP(1)
>                     .fit(
>                         ignite,
>                         cache,
>                         minMaxScalerPreprocessor
>                     );
>                 DecisionTreeNode mdl = trainer.fit(
>                     new SqlDatasetBuilder(ignite, TRAIN_CACHE_NAME),
>                     normalizationPreprocessor
>                 );
>                 System.out.println(">>> Perform inference...");
>                 try (QueryCursor<List<?>> cursor = cache.query(new SqlFieldsQuery("select " +
>                     "pclass, " +
>                     "sex, " +
>                     "age, " +
>                     "sibsp, " +
>                     "parch, " +
>                     "fare from titanic_test"))) {
>                     for (List<?> passenger : cursor) {
>                         // NOTE(review): raw (not preprocessed) column values are passed to the model here -
>                         // confirm this is intended.
>                         Vector input = VectorUtils.of(new Double[] {
>                             asDouble(passenger.get(0)),
>                             "male".equals(passenger.get(1)) ? 1.0 : 0.0,
>                             asDouble(passenger.get(2)),
>                             asDouble(passenger.get(3)),
>                             asDouble(passenger.get(4)),
>                             asDouble(passenger.get(5))
>                         });
>                         double prediction = mdl.predict(input);
>                         System.out.printf("Passenger %s will %s.\n", passenger, prediction == 0 ? "die" : "survive");
>                     }
>                 }
>                 System.out.println(">>> Example completed.");
>             }
>             finally {
>                 // The cache is still null if its creation failed; skip the cleanup in that case so that
>                 // the original exception is not masked by a NullPointerException.
>                 if (cache != null) {
>                     // IF EXISTS keeps the cleanup going when the example failed before a table was created.
>                     cache.query(new SqlFieldsQuery("DROP TABLE IF EXISTS titanic_train"));
>                     cache.query(new SqlFieldsQuery("DROP TABLE IF EXISTS titanic_test"));
>                     cache.destroy();
>                 }
>             }
>         }
>         finally {
>             System.out.flush();
>         }
>     }
>     /**
>      * Converts specified number into double.
>      *
>      * @param obj Number, may be {@code null}.
>      * @return Double value of the number or {@code null} if the specified object is {@code null}.
>      * @throws IllegalArgumentException If the specified object is neither {@code null} nor a {@link Number}.
>      */
>     private static Double asDouble(Object obj) {
>         if (obj == null)
>             return null;
>         if (obj instanceof Number) {
>             Number num = (Number)obj;
>             return num.doubleValue();
>         }
>         throw new IllegalArgumentException("Object is expected to be a number [obj=" + obj + "]");
>     }
> }
> {code}



--
This message was sent by Atlassian Jira
(v8.3.4#803005)