You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@ignite.apache.org by za...@apache.org on 2020/06/26 13:39:52 UTC

[ignite] branch master updated: IGNITE-12903 Fixed ML + SQL examples (#7965)

This is an automated email from the ASF dual-hosted git repository.

zaleslaw pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/ignite.git


The following commit(s) were added to refs/heads/master by this push:
     new 40377b1  IGNITE-12903 Fixed ML + SQL examples (#7965)
40377b1 is described below

commit 40377b109053d2e576ace7c612bf006aff9ef76d
Author: Alexey Zinoviev <za...@gmail.com>
AuthorDate: Fri Jun 26 16:39:29 2020 +0300

    IGNITE-12903 Fixed ML + SQL examples (#7965)
    
    * [IGNITE-12903] Fixed ML + SQL examples
    
    * [IGNITE-12903] Fixed ML + SQL examples
---
 ...eeClassificationTrainerSQLInferenceExample.java |  36 ++-----
 ...onTreeClassificationTrainerSQLTableExample.java | 109 ++++++++++++++-------
 .../selection/scoring/evaluator/package-info.java  |   4 +-
 3 files changed, 83 insertions(+), 66 deletions(-)

diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/sql/DecisionTreeClassificationTrainerSQLInferenceExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/sql/DecisionTreeClassificationTrainerSQLInferenceExample.java
index ab2a00c..543e211 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/sql/DecisionTreeClassificationTrainerSQLInferenceExample.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/sql/DecisionTreeClassificationTrainerSQLInferenceExample.java
@@ -17,18 +17,14 @@
 
 package org.apache.ignite.examples.ml.sql;
 
-import java.util.HashSet;
+import java.io.IOException;
 import java.util.List;
 import org.apache.ignite.Ignite;
 import org.apache.ignite.IgniteCache;
-import org.apache.ignite.IgniteCheckedException;
 import org.apache.ignite.Ignition;
 import org.apache.ignite.cache.query.QueryCursor;
 import org.apache.ignite.cache.query.SqlFieldsQuery;
 import org.apache.ignite.configuration.CacheConfiguration;
-import org.apache.ignite.internal.IgniteEx;
-import org.apache.ignite.internal.processors.query.h2.IgniteH2Indexing;
-import org.apache.ignite.internal.util.IgniteUtils;
 import org.apache.ignite.ml.dataset.feature.extractor.impl.BinaryObjectVectorizer;
 import org.apache.ignite.ml.inference.IgniteModelStorageUtil;
 import org.apache.ignite.ml.sql.SQLFunctions;
@@ -36,6 +32,8 @@ import org.apache.ignite.ml.sql.SqlDatasetBuilder;
 import org.apache.ignite.ml.tree.DecisionTreeClassificationTrainer;
 import org.apache.ignite.ml.tree.DecisionTreeNode;
 
+import static org.apache.ignite.examples.ml.sql.DecisionTreeClassificationTrainerSQLTableExample.loadTitanicDatasets;
+
 /**
  * Example of using distributed {@link DecisionTreeClassificationTrainer} on a data stored in SQL table and inference
  * made as SQL select query.
@@ -47,30 +45,15 @@ public class DecisionTreeClassificationTrainerSQLInferenceExample {
     private static final String DUMMY_CACHE_NAME = "dummy_cache";
 
     /**
-     * Training data.
-     */
-    private static final String TRAIN_DATA_RES = "examples/src/main/resources/datasets/titanic_train.csv";
-
-    /**
-     * Test data.
-     */
-    private static final String TEST_DATA_RES = "examples/src/main/resources/datasets/titanic_test.csv";
-
-    /**
      * Run example.
      */
-    public static void main(String[] args) throws IgniteCheckedException {
+    public static void main(String[] args) throws IOException {
         System.out.println(">>> Decision tree classification trainer example started.");
 
         // Start ignite grid.
         try (Ignite ignite = Ignition.start("examples/config/example-ignite-ml.xml")) {
             System.out.println(">>> Ignite grid started.");
 
-            // Use internal API to enable SQL functions disabled by default (the function CSVREAD is used below)
-            // TODO: IGNITE-12903
-            ((IgniteH2Indexing)((IgniteEx)ignite).context().query().getIndexing())
-                .distributedConfiguration().disabledFunctions(new HashSet<>());
-
             // Dummy cache is required to perform SQL queries.
             CacheConfiguration<?, ?> cacheCfg = new CacheConfiguration<>(DUMMY_CACHE_NAME)
                 .setSqlSchema("PUBLIC")
@@ -83,8 +66,8 @@ public class DecisionTreeClassificationTrainerSQLInferenceExample {
                 System.out.println(">>> Creating table with training data...");
                 cache.query(new SqlFieldsQuery("create table titanic_train (\n" +
                     "    passengerid int primary key,\n" +
-                    "    survived int,\n" +
                     "    pclass int,\n" +
+                    "    survived int,\n" +
                     "    name varchar(255),\n" +
                     "    sex varchar(255),\n" +
                     "    age float,\n" +
@@ -96,14 +79,11 @@ public class DecisionTreeClassificationTrainerSQLInferenceExample {
                     "    embarked varchar(255)\n" +
                     ") with \"template=partitioned\";")).getAll();
 
-                System.out.println(">>> Filling training data...");
-                cache.query(new SqlFieldsQuery("insert into titanic_train select * from csvread('" +
-                    IgniteUtils.resolveIgnitePath(TRAIN_DATA_RES).getAbsolutePath() + "')")).getAll();
-
                 System.out.println(">>> Creating table with test data...");
                 cache.query(new SqlFieldsQuery("create table titanic_test (\n" +
                     "    passengerid int primary key,\n" +
                     "    pclass int,\n" +
+                    "    survived int,\n" +
                     "    name varchar(255),\n" +
                     "    sex varchar(255),\n" +
                     "    age float,\n" +
@@ -115,9 +95,7 @@ public class DecisionTreeClassificationTrainerSQLInferenceExample {
                     "    embarked varchar(255)\n" +
                     ") with \"template=partitioned\";")).getAll();
 
-                System.out.println(">>> Filling training data...");
-                cache.query(new SqlFieldsQuery("insert into titanic_test select * from csvread('" +
-                    IgniteUtils.resolveIgnitePath(TEST_DATA_RES).getAbsolutePath() + "')")).getAll();
+                loadTitanicDatasets(ignite, cache);
 
                 System.out.println(">>> Prepare trainer...");
                 DecisionTreeClassificationTrainer trainer = new DecisionTreeClassificationTrainer(4, 0);
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/sql/DecisionTreeClassificationTrainerSQLTableExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/sql/DecisionTreeClassificationTrainerSQLTableExample.java
index 5fe123c..083608e 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/sql/DecisionTreeClassificationTrainerSQLTableExample.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/sql/DecisionTreeClassificationTrainerSQLTableExample.java
@@ -17,8 +17,9 @@
 
 package org.apache.ignite.examples.ml.sql;
 
-import java.util.HashSet;
+import java.io.IOException;
 import java.util.List;
+
 import org.apache.ignite.Ignite;
 import org.apache.ignite.IgniteCache;
 import org.apache.ignite.IgniteCheckedException;
@@ -26,9 +27,8 @@ import org.apache.ignite.Ignition;
 import org.apache.ignite.cache.query.QueryCursor;
 import org.apache.ignite.cache.query.SqlFieldsQuery;
 import org.apache.ignite.configuration.CacheConfiguration;
-import org.apache.ignite.internal.IgniteEx;
-import org.apache.ignite.internal.processors.query.h2.IgniteH2Indexing;
-import org.apache.ignite.internal.util.IgniteUtils;
+import org.apache.ignite.examples.ml.util.MLSandboxDatasets;
+import org.apache.ignite.examples.ml.util.SandboxMLCache;
 import org.apache.ignite.ml.dataset.feature.extractor.impl.BinaryObjectVectorizer;
 import org.apache.ignite.ml.math.primitives.vector.Vector;
 import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
@@ -46,30 +46,15 @@ public class DecisionTreeClassificationTrainerSQLTableExample {
     private static final String DUMMY_CACHE_NAME = "dummy_cache";
 
     /**
-     * Training data.
-     */
-    private static final String TRAIN_DATA_RES = "examples/src/main/resources/datasets/titanic_train.csv";
-
-    /**
-     * Test data.
-     */
-    private static final String TEST_DATA_RES = "examples/src/main/resources/datasets/titanic_test.csv";
-
-    /**
      * Run example.
      */
-    public static void main(String[] args) throws IgniteCheckedException {
+    public static void main(String[] args) throws IgniteCheckedException, IOException {
         System.out.println(">>> Decision tree classification trainer example started.");
 
         // Start ignite grid.
         try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
             System.out.println(">>> Ignite grid started.");
 
-            // Use internal API to enable SQL functions disabled by default (the function CSVREAD is used below)
-            // TODO: IGNITE-12903
-            ((IgniteH2Indexing)((IgniteEx)ignite).context().query().getIndexing())
-                .distributedConfiguration().disabledFunctions(new HashSet<>());
-
             // Dummy cache is required to perform SQL queries.
             CacheConfiguration<?, ?> cacheCfg = new CacheConfiguration<>(DUMMY_CACHE_NAME)
                 .setSqlSchema("PUBLIC");
@@ -81,8 +66,8 @@ public class DecisionTreeClassificationTrainerSQLTableExample {
                 System.out.println(">>> Creating table with training data...");
                 cache.query(new SqlFieldsQuery("create table titanic_train (\n" +
                     "    passengerid int primary key,\n" +
-                    "    survived int,\n" +
                     "    pclass int,\n" +
+                    "    survived int,\n" +
                     "    name varchar(255),\n" +
                     "    sex varchar(255),\n" +
                     "    age float,\n" +
@@ -94,14 +79,11 @@ public class DecisionTreeClassificationTrainerSQLTableExample {
                     "    embarked varchar(255)\n" +
                     ") with \"template=partitioned\";")).getAll();
 
-                System.out.println(">>> Filling training data...");
-                cache.query(new SqlFieldsQuery("insert into titanic_train select * from csvread('" +
-                    IgniteUtils.resolveIgnitePath(TRAIN_DATA_RES).getAbsolutePath() + "')")).getAll();
-
                 System.out.println(">>> Creating table with test data...");
                 cache.query(new SqlFieldsQuery("create table titanic_test (\n" +
                     "    passengerid int primary key,\n" +
                     "    pclass int,\n" +
+                    "    survived int,\n" +
                     "    name varchar(255),\n" +
                     "    sex varchar(255),\n" +
                     "    age float,\n" +
@@ -113,9 +95,7 @@ public class DecisionTreeClassificationTrainerSQLTableExample {
                     "    embarked varchar(255)\n" +
                     ") with \"template=partitioned\";")).getAll();
 
-                System.out.println(">>> Filling training data...");
-                cache.query(new SqlFieldsQuery("insert into titanic_test select * from csvread('" +
-                    IgniteUtils.resolveIgnitePath(TEST_DATA_RES).getAbsolutePath() + "')")).getAll();
+                loadTitanicDatasets(ignite, cache);
 
                 System.out.println(">>> Prepare trainer...");
                 DecisionTreeClassificationTrainer trainer = new DecisionTreeClassificationTrainer(4, 0);
@@ -128,6 +108,8 @@ public class DecisionTreeClassificationTrainerSQLTableExample {
                         .labeled("survived")
                 );
 
+                System.out.println("Tree is here: " + mdl.toString(true));
+
                 System.out.println(">>> Perform inference...");
                 try (QueryCursor<List<?>> cursor = cache.query(new SqlFieldsQuery("select " +
                     "pclass, " +
@@ -137,13 +119,13 @@ public class DecisionTreeClassificationTrainerSQLTableExample {
                     "parch, " +
                     "fare from titanic_test"))) {
                     for (List<?> passenger : cursor) {
-                        Vector input = VectorUtils.of(new Double[] {
+                        Vector input = VectorUtils.of(new Double[]{
                             asDouble(passenger.get(0)),
                             "male".equals(passenger.get(1)) ? 1.0 : 0.0,
                             asDouble(passenger.get(2)),
                             asDouble(passenger.get(3)),
                             asDouble(passenger.get(4)),
-                            asDouble(passenger.get(5))
+                            asDouble(passenger.get(5)),
                         });
 
                         double prediction = mdl.predict(input);
@@ -153,14 +135,12 @@ public class DecisionTreeClassificationTrainerSQLTableExample {
                 }
 
                 System.out.println(">>> Example completed.");
-            }
-            finally {
+            } finally {
                 cache.query(new SqlFieldsQuery("DROP TABLE titanic_train"));
                 cache.query(new SqlFieldsQuery("DROP TABLE titanic_test"));
                 cache.destroy();
             }
-        }
-        finally {
+        } finally {
             System.out.flush();
         }
     }
@@ -177,11 +157,70 @@ public class DecisionTreeClassificationTrainerSQLTableExample {
             return null;
 
         if (obj instanceof Number) {
-            Number num = (Number)obj;
+            Number num = (Number) obj;
 
             return num.doubleValue();
         }
 
         throw new IllegalArgumentException("Object is expected to be a number [obj=" + obj + "]");
     }
+
+    /**
+     * Loads the Titanic dataset, inserting the first 1000 rows into titanic_train and the rest into titanic_test.
+     *
+     * @param ignite Ignite instance.
+     * @param cache Cache used to run the SQL inserts.
+     * @throws IOException If dataset not found.
+     */
+    static void loadTitanicDatasets(Ignite ignite, IgniteCache<?, ?> cache) throws IOException {
+        List<String> titanicDatasetRows = new SandboxMLCache(ignite).loadDataset(MLSandboxDatasets.TITANIC);
+        int trainSize = Math.min(1000, titanicDatasetRows.size()); // Clamp: subList throws on a short dataset.
+        List<String> train = titanicDatasetRows.subList(0, trainSize);
+        List<String> test = titanicDatasetRows.subList(trainSize, titanicDatasetRows.size());
+        insertToCache(cache, train, "titanic_train");
+        insertToCache(cache, test, "titanic_test");
+    }
+
+    /** Inserts semicolon-separated Titanic rows into {@code tableName}, using the row index as passengerid. */
+    private static void insertToCache(IgniteCache<?, ?> cache, List<String> rows, String tableName) {
+        SqlFieldsQuery insert = new SqlFieldsQuery("insert into " + tableName + " " +
+            "(passengerid, pclass, survived, name, sex, age, sibsp, parch, ticket, fare, cabin, embarked) " +
+            "values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)");
+
+        int seq = 0;
+        for (String s : rows) {
+            String[] line = s.split(";", -1); // -1 keeps trailing empty fields, so line[10] is present.
+            int pclass = parseInteger(line[0]);
+            int survived = parseInteger(line[1]);
+            String name = line[2];
+            String sex = line[3];
+            double age = parseDouble(line[4]);
+            double sibsp = parseInteger(line[5]);
+            double parch = parseInteger(line[6]);
+            String ticket = line[7];
+            double fare = parseDouble(line[8]);
+            String cabin = line[9];
+            String embarked = line[10];
+            insert.setArgs(seq++, pclass, survived, name, sex, age, sibsp, parch, ticket, fare, cabin, embarked);
+            cache.query(insert).getAll(); // getAll() consumes the update count and closes the cursor.
+        }
+    }
+
+    /** Parses {@code value} as an {@code int}, falling back to 0 for blank or malformed input. */
+    private static int parseInteger(String value) {
+        try {
+            return Integer.parseInt(value);
+        } catch (NumberFormatException ignored) {
+            return 0;
+        }
+    }
+
+    /** Parses {@code value} as a {@code double}, falling back to 0.0 for blank or malformed input. */
+    private static double parseDouble(String value) {
+        try {
+            return Double.parseDouble(value);
+        } catch (NumberFormatException ignored) {
+            return 0.0;
+        }
+    }
 }
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/evaluator/package-info.java b/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/evaluator/package-info.java
index c5cdf08..f74a607 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/evaluator/package-info.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/evaluator/package-info.java
@@ -16,7 +16,7 @@
  */
 
 /**
- * <!-- Package description. --> Package for model evaluator classes.
+ * <!-- Package description. -->
+ * Package for model evaluator classes.
  */
-
 package org.apache.ignite.ml.selection.scoring.evaluator;