You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@ignite.apache.org by za...@apache.org on 2020/06/26 13:39:52 UTC
[ignite] branch master updated: IGNITE-12903 Fixed ML + SQL examples (#7965)
This is an automated email from the ASF dual-hosted git repository.
zaleslaw pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/ignite.git
The following commit(s) were added to refs/heads/master by this push:
new 40377b1 IGNITE-12903 Fixed ML + SQL examples (#7965)
40377b1 is described below
commit 40377b109053d2e576ace7c612bf006aff9ef76d
Author: Alexey Zinoviev <za...@gmail.com>
AuthorDate: Fri Jun 26 16:39:29 2020 +0300
IGNITE-12903 Fixed ML + SQL examples (#7965)
* [IGNITE-12903] Fixed ML + SQL examples
* [IGNITE-12903] Fixed ML + SQL examples
---
...eeClassificationTrainerSQLInferenceExample.java | 36 ++-----
...onTreeClassificationTrainerSQLTableExample.java | 109 ++++++++++++++-------
.../selection/scoring/evaluator/package-info.java | 4 +-
3 files changed, 83 insertions(+), 66 deletions(-)
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/sql/DecisionTreeClassificationTrainerSQLInferenceExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/sql/DecisionTreeClassificationTrainerSQLInferenceExample.java
index ab2a00c..543e211 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/sql/DecisionTreeClassificationTrainerSQLInferenceExample.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/sql/DecisionTreeClassificationTrainerSQLInferenceExample.java
@@ -17,18 +17,14 @@
package org.apache.ignite.examples.ml.sql;
-import java.util.HashSet;
+import java.io.IOException;
import java.util.List;
import org.apache.ignite.Ignite;
import org.apache.ignite.IgniteCache;
-import org.apache.ignite.IgniteCheckedException;
import org.apache.ignite.Ignition;
import org.apache.ignite.cache.query.QueryCursor;
import org.apache.ignite.cache.query.SqlFieldsQuery;
import org.apache.ignite.configuration.CacheConfiguration;
-import org.apache.ignite.internal.IgniteEx;
-import org.apache.ignite.internal.processors.query.h2.IgniteH2Indexing;
-import org.apache.ignite.internal.util.IgniteUtils;
import org.apache.ignite.ml.dataset.feature.extractor.impl.BinaryObjectVectorizer;
import org.apache.ignite.ml.inference.IgniteModelStorageUtil;
import org.apache.ignite.ml.sql.SQLFunctions;
@@ -36,6 +32,8 @@ import org.apache.ignite.ml.sql.SqlDatasetBuilder;
import org.apache.ignite.ml.tree.DecisionTreeClassificationTrainer;
import org.apache.ignite.ml.tree.DecisionTreeNode;
+import static org.apache.ignite.examples.ml.sql.DecisionTreeClassificationTrainerSQLTableExample.loadTitanicDatasets;
+
/**
* Example of using distributed {@link DecisionTreeClassificationTrainer} on a data stored in SQL table and inference
* made as SQL select query.
@@ -47,30 +45,15 @@ public class DecisionTreeClassificationTrainerSQLInferenceExample {
private static final String DUMMY_CACHE_NAME = "dummy_cache";
/**
- * Training data.
- */
- private static final String TRAIN_DATA_RES = "examples/src/main/resources/datasets/titanic_train.csv";
-
- /**
- * Test data.
- */
- private static final String TEST_DATA_RES = "examples/src/main/resources/datasets/titanic_test.csv";
-
- /**
* Run example.
*/
- public static void main(String[] args) throws IgniteCheckedException {
+ public static void main(String[] args) throws IOException {
System.out.println(">>> Decision tree classification trainer example started.");
// Start ignite grid.
try (Ignite ignite = Ignition.start("examples/config/example-ignite-ml.xml")) {
System.out.println(">>> Ignite grid started.");
- // Use internal API to enable SQL functions disabled by default (the function CSVREAD is used below)
- // TODO: IGNITE-12903
- ((IgniteH2Indexing)((IgniteEx)ignite).context().query().getIndexing())
- .distributedConfiguration().disabledFunctions(new HashSet<>());
-
// Dummy cache is required to perform SQL queries.
CacheConfiguration<?, ?> cacheCfg = new CacheConfiguration<>(DUMMY_CACHE_NAME)
.setSqlSchema("PUBLIC")
@@ -83,8 +66,8 @@ public class DecisionTreeClassificationTrainerSQLInferenceExample {
System.out.println(">>> Creating table with training data...");
cache.query(new SqlFieldsQuery("create table titanic_train (\n" +
" passengerid int primary key,\n" +
- " survived int,\n" +
" pclass int,\n" +
+ " survived int,\n" +
" name varchar(255),\n" +
" sex varchar(255),\n" +
" age float,\n" +
@@ -96,14 +79,11 @@ public class DecisionTreeClassificationTrainerSQLInferenceExample {
" embarked varchar(255)\n" +
") with \"template=partitioned\";")).getAll();
- System.out.println(">>> Filling training data...");
- cache.query(new SqlFieldsQuery("insert into titanic_train select * from csvread('" +
- IgniteUtils.resolveIgnitePath(TRAIN_DATA_RES).getAbsolutePath() + "')")).getAll();
-
System.out.println(">>> Creating table with test data...");
cache.query(new SqlFieldsQuery("create table titanic_test (\n" +
" passengerid int primary key,\n" +
" pclass int,\n" +
+ " survived int,\n" +
" name varchar(255),\n" +
" sex varchar(255),\n" +
" age float,\n" +
@@ -115,9 +95,7 @@ public class DecisionTreeClassificationTrainerSQLInferenceExample {
" embarked varchar(255)\n" +
") with \"template=partitioned\";")).getAll();
- System.out.println(">>> Filling training data...");
- cache.query(new SqlFieldsQuery("insert into titanic_test select * from csvread('" +
- IgniteUtils.resolveIgnitePath(TEST_DATA_RES).getAbsolutePath() + "')")).getAll();
+ loadTitanicDatasets(ignite, cache);
System.out.println(">>> Prepare trainer...");
DecisionTreeClassificationTrainer trainer = new DecisionTreeClassificationTrainer(4, 0);
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/sql/DecisionTreeClassificationTrainerSQLTableExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/sql/DecisionTreeClassificationTrainerSQLTableExample.java
index 5fe123c..083608e 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/sql/DecisionTreeClassificationTrainerSQLTableExample.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/sql/DecisionTreeClassificationTrainerSQLTableExample.java
@@ -17,8 +17,9 @@
package org.apache.ignite.examples.ml.sql;
-import java.util.HashSet;
+import java.io.IOException;
import java.util.List;
+
import org.apache.ignite.Ignite;
import org.apache.ignite.IgniteCache;
import org.apache.ignite.IgniteCheckedException;
@@ -26,9 +27,8 @@ import org.apache.ignite.Ignition;
import org.apache.ignite.cache.query.QueryCursor;
import org.apache.ignite.cache.query.SqlFieldsQuery;
import org.apache.ignite.configuration.CacheConfiguration;
-import org.apache.ignite.internal.IgniteEx;
-import org.apache.ignite.internal.processors.query.h2.IgniteH2Indexing;
-import org.apache.ignite.internal.util.IgniteUtils;
+import org.apache.ignite.examples.ml.util.MLSandboxDatasets;
+import org.apache.ignite.examples.ml.util.SandboxMLCache;
import org.apache.ignite.ml.dataset.feature.extractor.impl.BinaryObjectVectorizer;
import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
@@ -46,30 +46,15 @@ public class DecisionTreeClassificationTrainerSQLTableExample {
private static final String DUMMY_CACHE_NAME = "dummy_cache";
/**
- * Training data.
- */
- private static final String TRAIN_DATA_RES = "examples/src/main/resources/datasets/titanic_train.csv";
-
- /**
- * Test data.
- */
- private static final String TEST_DATA_RES = "examples/src/main/resources/datasets/titanic_test.csv";
-
- /**
* Run example.
*/
- public static void main(String[] args) throws IgniteCheckedException {
+ public static void main(String[] args) throws IgniteCheckedException, IOException {
System.out.println(">>> Decision tree classification trainer example started.");
// Start ignite grid.
try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
System.out.println(">>> Ignite grid started.");
- // Use internal API to enable SQL functions disabled by default (the function CSVREAD is used below)
- // TODO: IGNITE-12903
- ((IgniteH2Indexing)((IgniteEx)ignite).context().query().getIndexing())
- .distributedConfiguration().disabledFunctions(new HashSet<>());
-
// Dummy cache is required to perform SQL queries.
CacheConfiguration<?, ?> cacheCfg = new CacheConfiguration<>(DUMMY_CACHE_NAME)
.setSqlSchema("PUBLIC");
@@ -81,8 +66,8 @@ public class DecisionTreeClassificationTrainerSQLTableExample {
System.out.println(">>> Creating table with training data...");
cache.query(new SqlFieldsQuery("create table titanic_train (\n" +
" passengerid int primary key,\n" +
- " survived int,\n" +
" pclass int,\n" +
+ " survived int,\n" +
" name varchar(255),\n" +
" sex varchar(255),\n" +
" age float,\n" +
@@ -94,14 +79,11 @@ public class DecisionTreeClassificationTrainerSQLTableExample {
" embarked varchar(255)\n" +
") with \"template=partitioned\";")).getAll();
- System.out.println(">>> Filling training data...");
- cache.query(new SqlFieldsQuery("insert into titanic_train select * from csvread('" +
- IgniteUtils.resolveIgnitePath(TRAIN_DATA_RES).getAbsolutePath() + "')")).getAll();
-
System.out.println(">>> Creating table with test data...");
cache.query(new SqlFieldsQuery("create table titanic_test (\n" +
" passengerid int primary key,\n" +
" pclass int,\n" +
+ " survived int,\n" +
" name varchar(255),\n" +
" sex varchar(255),\n" +
" age float,\n" +
@@ -113,9 +95,7 @@ public class DecisionTreeClassificationTrainerSQLTableExample {
" embarked varchar(255)\n" +
") with \"template=partitioned\";")).getAll();
- System.out.println(">>> Filling training data...");
- cache.query(new SqlFieldsQuery("insert into titanic_test select * from csvread('" +
- IgniteUtils.resolveIgnitePath(TEST_DATA_RES).getAbsolutePath() + "')")).getAll();
+ loadTitanicDatasets(ignite, cache);
System.out.println(">>> Prepare trainer...");
DecisionTreeClassificationTrainer trainer = new DecisionTreeClassificationTrainer(4, 0);
@@ -128,6 +108,8 @@ public class DecisionTreeClassificationTrainerSQLTableExample {
.labeled("survived")
);
+ System.out.println("Tree is here: " + mdl.toString(true));
+
System.out.println(">>> Perform inference...");
try (QueryCursor<List<?>> cursor = cache.query(new SqlFieldsQuery("select " +
"pclass, " +
@@ -137,13 +119,13 @@ public class DecisionTreeClassificationTrainerSQLTableExample {
"parch, " +
"fare from titanic_test"))) {
for (List<?> passenger : cursor) {
- Vector input = VectorUtils.of(new Double[] {
+ Vector input = VectorUtils.of(new Double[]{
asDouble(passenger.get(0)),
"male".equals(passenger.get(1)) ? 1.0 : 0.0,
asDouble(passenger.get(2)),
asDouble(passenger.get(3)),
asDouble(passenger.get(4)),
- asDouble(passenger.get(5))
+ asDouble(passenger.get(5)),
});
double prediction = mdl.predict(input);
@@ -153,14 +135,12 @@ public class DecisionTreeClassificationTrainerSQLTableExample {
}
System.out.println(">>> Example completed.");
- }
- finally {
+ } finally {
cache.query(new SqlFieldsQuery("DROP TABLE titanic_train"));
cache.query(new SqlFieldsQuery("DROP TABLE titanic_test"));
cache.destroy();
}
- }
- finally {
+ } finally {
System.out.flush();
}
}
@@ -177,11 +157,70 @@ public class DecisionTreeClassificationTrainerSQLTableExample {
return null;
if (obj instanceof Number) {
- Number num = (Number)obj;
+ Number num = (Number) obj;
return num.doubleValue();
}
throw new IllegalArgumentException("Object is expected to be a number [obj=" + obj + "]");
}
+
+    /**
+     * Loads Titanic rows: up to the first 1000 into {@code titanic_train}, the rest into {@code titanic_test}.
+     *
+     * @param ignite Ignite instance.
+     * @param cache Cache used to run the insert queries; both tables must already exist.
+     * @throws IOException If the dataset cannot be loaded.
+     */
+    static void loadTitanicDatasets(Ignite ignite, IgniteCache<?, ?> cache) throws IOException {
+        List<String> titanicDatasetRows = new SandboxMLCache(ignite).loadDataset(MLSandboxDatasets.TITANIC);
+        int trainSize = Math.min(1000, titanicDatasetRows.size());
+        List<String> train = titanicDatasetRows.subList(0, trainSize);
+        List<String> test = titanicDatasetRows.subList(trainSize, titanicDatasetRows.size());
+        insertToCache(cache, train, "titanic_train");
+        insertToCache(cache, test, "titanic_test");
+    }
+
+    /** Inserts semicolon-separated Titanic rows into {@code tableName}, assigning passenger ids 0, 1, 2, ... */
+    private static void insertToCache(IgniteCache<?, ?> cache, List<String> rows, String tableName) {
+        SqlFieldsQuery insert = new SqlFieldsQuery("insert into " + tableName + " " +
+            "(passengerid, pclass, survived, name, sex, age, sibsp, parch, ticket, fare, cabin, embarked) " +
+            "values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)");
+
+        int seq = 0;
+        for (String s : rows) {
+            String[] line = s.split(";", -1); // Limit -1 keeps trailing empty fields (e.g. empty cabin/embarked).
+            int pclass = parseInteger(line[0]);
+            int survived = parseInteger(line[1]);
+            String name = line[2];
+            String sex = line[3];
+            double age = parseDouble(line[4]);
+            int sibsp = parseInteger(line[5]);
+            int parch = parseInteger(line[6]);
+            String ticket = line[7];
+            double fare = parseDouble(line[8]);
+            String cabin = line[9];
+            String embarked = line[10];
+            insert.setArgs(seq++, pclass, survived, name, sex, age, sibsp, parch, ticket, fare, cabin, embarked);
+            cache.query(insert).getAll();
+        }
+    }
+
+    /** Parses an integer field, ignoring surrounding whitespace; empty or malformed values yield {@code 0}. */
+    private static Integer parseInteger(String value) {
+        try {
+            return value == null ? 0 : Integer.valueOf(value.trim()); // trim() matches Double.valueOf behaviour.
+        } catch (NumberFormatException ignored) {
+            return 0;
+        }
+    }
+
+    /** Parses a floating-point field; text that is not a valid number (such as an empty field) yields {@code 0.0}. */
+    private static Double parseDouble(String value) {
+        double parsed = 0.0;
+        try {
+            parsed = Double.parseDouble(value);
+        } catch (NumberFormatException ignored) { /* Empty or malformed value: keep the 0.0 default. */ }
+        return parsed;
+    }
}
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/evaluator/package-info.java b/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/evaluator/package-info.java
index c5cdf08..f74a607 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/evaluator/package-info.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/evaluator/package-info.java
@@ -16,7 +16,7 @@
*/
/**
- * <!-- Package description. --> Package for model evaluator classes.
+ * <!-- Package description. -->
+ * Package for model evaluator classes.
*/
-
package org.apache.ignite.ml.selection.scoring.evaluator;