You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@ignite.apache.org by ch...@apache.org on 2019/01/29 10:28:22 UTC

[ignite] branch master updated: IGNITE-11072: [ML] Prepare an example of model inference in SQL

This is an automated email from the ASF dual-hosted git repository.

chief pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/ignite.git


The following commit(s) were added to refs/heads/master by this push:
     new 6698216  IGNITE-11072: [ML] Prepare an example of model inference in SQL
6698216 is described below

commit 66982167517e37dad5804d83c2665ac68047278c
Author: Anton Dmitriev <dm...@gmail.com>
AuthorDate: Tue Jan 29 13:28:04 2019 +0300

    IGNITE-11072: [ML] Prepare an example of model inference in SQL
    
    This closes #5941
---
 ...eeClassificationTrainerSQLInferenceExample.java | 274 +++++++++++++++++++++
 1 file changed, 274 insertions(+)

diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/sql/DecisionTreeClassificationTrainerSQLInferenceExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/sql/DecisionTreeClassificationTrainerSQLInferenceExample.java
new file mode 100644
index 0000000..b7ae1de
--- /dev/null
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/sql/DecisionTreeClassificationTrainerSQLInferenceExample.java
@@ -0,0 +1,274 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.examples.ml.sql;
+
+import java.io.ByteArrayInputStream;
+import java.io.ByteArrayOutputStream;
+import java.io.IOException;
+import java.io.ObjectInputStream;
+import java.io.ObjectOutputStream;
+import java.io.Serializable;
+import java.util.List;
+import org.apache.ignite.Ignite;
+import org.apache.ignite.IgniteCache;
+import org.apache.ignite.Ignition;
+import org.apache.ignite.binary.BinaryObject;
+import org.apache.ignite.cache.query.QueryCursor;
+import org.apache.ignite.cache.query.SqlFieldsQuery;
+import org.apache.ignite.cache.query.annotations.QuerySqlFunction;
+import org.apache.ignite.configuration.CacheConfiguration;
+import org.apache.ignite.internal.util.IgniteUtils;
+import org.apache.ignite.ml.IgniteModel;
+import org.apache.ignite.ml.dataset.impl.cache.CacheBasedDatasetBuilder;
+import org.apache.ignite.ml.inference.Model;
+import org.apache.ignite.ml.inference.ModelDescriptor;
+import org.apache.ignite.ml.inference.ModelSignature;
+import org.apache.ignite.ml.inference.builder.SingleModelBuilder;
+import org.apache.ignite.ml.inference.parser.IgniteModelParser;
+import org.apache.ignite.ml.inference.reader.ModelStorageModelReader;
+import org.apache.ignite.ml.inference.storage.descriptor.ModelDescriptorStorage;
+import org.apache.ignite.ml.inference.storage.descriptor.ModelDescriptorStorageFactory;
+import org.apache.ignite.ml.inference.storage.model.ModelStorage;
+import org.apache.ignite.ml.inference.storage.model.ModelStorageFactory;
+import org.apache.ignite.ml.math.primitives.vector.Vector;
+import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
+import org.apache.ignite.ml.tree.DecisionTreeClassificationTrainer;
+import org.apache.ignite.ml.tree.DecisionTreeNode;
+
+/**
+ * Example of using distributed {@link DecisionTreeClassificationTrainer} on a data stored in SQL table and inference
+ * made as SQL select query.
+ */
+public class DecisionTreeClassificationTrainerSQLInferenceExample {
+    /** Dummy cache name. */
+    private static final String DUMMY_CACHE_NAME = "dummy_cache";
+
+    /** Training data. */
+    private static final String TRAIN_DATA_RES = "examples/src/main/resources/datasets/titanik_train.csv";
+
+    /** Test data. */
+    private static final String TEST_DATA_RES = "examples/src/main/resources/datasets/titanik_test.csv";
+
+    /** Run example. */
+    public static void main(String[] args) throws IOException {
+        System.out.println(">>> Decision tree classification trainer example started.");
+
+        // Start ignite grid.
+        try (Ignite ignite = Ignition.start("examples/config/example-ignite-ml.xml")) {
+            System.out.println(">>> Ignite grid started.");
+
+            // Dummy cache is required to perform SQL queries.
+            CacheConfiguration<?, ?> cacheCfg = new CacheConfiguration<>(DUMMY_CACHE_NAME)
+                .setSqlSchema("PUBLIC")
+                .setSqlFunctionClasses(SQLFunctions.class);
+
+            IgniteCache<?, ?> cache = ignite.createCache(cacheCfg);
+
+            System.out.println(">>> Creating table with training data...");
+            cache.query(new SqlFieldsQuery("create table titanik_train (\n" +
+                "    passengerid int primary key,\n" +
+                "    survived int,\n" +
+                "    pclass int,\n" +
+                "    name varchar(255),\n" +
+                "    sex varchar(255),\n" +
+                "    age float,\n" +
+                "    sibsp int,\n" +
+                "    parch int,\n" +
+                "    ticket varchar(255),\n" +
+                "    fare float,\n" +
+                "    cabin varchar(255),\n" +
+                "    embarked varchar(255)\n" +
+                ") with \"template=partitioned\";")).getAll();
+
+            System.out.println(">>> Filling training data...");
+            cache.query(new SqlFieldsQuery("insert into titanik_train select * from csvread('" +
+                IgniteUtils.resolveIgnitePath(TRAIN_DATA_RES).getAbsolutePath() + "')")).getAll();
+
+            System.out.println(">>> Creating table with test data...");
+            cache.query(new SqlFieldsQuery("create table titanik_test (\n" +
+                "    passengerid int primary key,\n" +
+                "    pclass int,\n" +
+                "    name varchar(255),\n" +
+                "    sex varchar(255),\n" +
+                "    age float,\n" +
+                "    sibsp int,\n" +
+                "    parch int,\n" +
+                "    ticket varchar(255),\n" +
+                "    fare float,\n" +
+                "    cabin varchar(255),\n" +
+                "    embarked varchar(255)\n" +
+                ") with \"template=partitioned\";")).getAll();
+
+            System.out.println(">>> Filling training data...");
+            cache.query(new SqlFieldsQuery("insert into titanik_test select * from csvread('" +
+                IgniteUtils.resolveIgnitePath(TEST_DATA_RES).getAbsolutePath() + "')")).getAll();
+
+            System.out.println(">>> Prepare trainer...");
+            DecisionTreeClassificationTrainer trainer = new DecisionTreeClassificationTrainer(4, 0);
+
+            System.out.println(">>> Perform training...");
+            IgniteCache<Integer, BinaryObject> titanicTrainCache = ignite.cache("SQL_PUBLIC_TITANIK_TRAIN");
+            DecisionTreeNode mdl = trainer.fit(
+                // We have to specify ".withKeepBinary(true)" because SQL caches contains only binary objects and this
+                // information has to be passed into the trainer.
+                new CacheBasedDatasetBuilder<>(ignite, titanicTrainCache).withKeepBinary(true),
+                (k, v) -> VectorUtils.of(
+                    // We have to handle null values here to avoid NpE during unboxing.
+                    replaceNull(v.<Integer>field("pclass")),
+                    "male".equals(v.<String>field("sex")) ? 1 : 0,
+                    replaceNull(v.<Double>field("age")),
+                    replaceNull(v.<Integer>field("sibsp")),
+                    replaceNull(v.<Integer>field("parch")),
+                    replaceNull(v.<Double>field("fare"))
+                ),
+                (k, v) -> replaceNull(v.<Integer>field("survived"))
+            );
+
+            System.out.println(">>> Saving model...");
+
+            // Model storage is used to store raw serialized model.
+            System.out.println("Saving model into model storage...");
+            byte[] serializedMdl = serialize((IgniteModel<byte[], byte[]>)i -> {
+                // Here we need to wrap model so that it accepts and returns byte array.
+                try {
+                    Vector input = deserialize(i);
+                    return serialize(mdl.predict(input));
+                }
+                catch (IOException | ClassNotFoundException e) {
+                    throw new RuntimeException(e);
+                }
+            });
+
+            ModelStorage storage = new ModelStorageFactory().getModelStorage(ignite);
+            storage.mkdirs("/");
+            storage.putFile("/my_model", serializedMdl);
+
+            // Model descriptor storage that is used to store model metadata.
+            System.out.println("Saving model descriptor into model descriptor storage...");
+            ModelDescriptor desc = new ModelDescriptor(
+                "MyModel",
+                "My Cool Model",
+                new ModelSignature("", "", ""),
+                new ModelStorageModelReader("/my_model"),
+                new IgniteModelParser<>()
+            );
+            ModelDescriptorStorage descStorage = new ModelDescriptorStorageFactory().getModelDescriptorStorage(ignite);
+            descStorage.put("my_model", desc);
+
+            // Making inference using saved model.
+            System.out.println("Inference...");
+            try (QueryCursor<List<?>> cursor = cache.query(new SqlFieldsQuery("select " +
+                "survived as truth, " +
+                "predict('my_model', pclass, case sex when 'male' then 1 else 0 end, age, sibsp, parch, fare) as prediction " +
+                "from titanik_train"))) {
+                // Print inference result.
+                System.out.println("| Truth | Prediction |");
+                System.out.println("|--------------------|");
+                for (List<?> row : cursor)
+                    System.out.println("|     " + row.get(0) + " |        " + row.get(1) + " |");
+            }
+        }
+    }
+
+    /**
+     * Replaces NULL values by 0.
+     *
+     * @param obj Input value.
+     * @param <T> Type of value.
+     * @return Input value of 0 if value is null.
+     */
+    private static <T extends Number> double replaceNull(T obj) {
+        if (obj == null)
+            return 0;
+
+        return obj.doubleValue();
+    }
+
+    /**
+     * SQL functions that should be defined and passed into cache configuration to extend list of functions available
+     * in SQL interface.
+     */
+    public static class SQLFunctions {
+        /**
+         * Makes prediction using specified model name to extract model from model storage and specified input values
+         * as input object for prediction.
+         *
+         * @param mdl Pretrained model.
+         * @param x Input values.
+         * @return Prediction.
+         */
+        @QuerySqlFunction
+        public static double predict(String mdl, Double... x) {
+            // Pretrained models work with vector of doubles so we need to replace null by 0 (or any other double).
+            for (int i = 0; i < x.length; i++)
+                if (x[i] == null)
+                    x[i] = 0.0;
+
+            Ignite ignite = Ignition.ignite();
+
+            ModelDescriptorStorage descStorage = new ModelDescriptorStorageFactory().getModelDescriptorStorage(ignite);
+            ModelDescriptor desc = descStorage.get(mdl);
+
+            Model<byte[], byte[]> infMdl = new SingleModelBuilder().build(desc.getReader(), desc.getParser());
+
+            Vector input = VectorUtils.of(x);
+
+            try {
+                return deserialize(infMdl.predict(serialize(input)));
+            }
+            catch (IOException | ClassNotFoundException e) {
+                throw new RuntimeException(e);
+            }
+        }
+    }
+
+    /**
+     * Serialized the specified object.
+     *
+     * @param o Object to be serialized.
+     * @return Serialized object as byte array.
+     * @throws IOException In case of exception.
+     */
+    private static <T extends Serializable> byte[] serialize(T o) throws IOException {
+        try (ByteArrayOutputStream baos = new ByteArrayOutputStream();
+             ObjectOutputStream oos = new ObjectOutputStream(baos)) {
+            oos.writeObject(o);
+            oos.flush();
+
+            return baos.toByteArray();
+        }
+    }
+
+    /**
+     * Deserialized object represented as a byte array.
+     *
+     * @param o Serialized object.
+     * @param <T> Type of serialized object.
+     * @return Deserialized object.
+     * @throws IOException In case of exception.
+     * @throws ClassNotFoundException In case of exception.
+     */
+    @SuppressWarnings("unchecked")
+    private static <T extends Serializable> T deserialize(byte[] o) throws IOException, ClassNotFoundException {
+        try (ByteArrayInputStream bais = new ByteArrayInputStream(o);
+             ObjectInputStream ois = new ObjectInputStream(bais)) {
+
+            return (T)ois.readObject();
+        }
+    }
+}