You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@ignite.apache.org by ch...@apache.org on 2019/01/25 13:07:54 UTC
[ignite] branch master updated: IGNITE-11003: [ML] Add parser for
Spark Random forest classifier
This is an automated email from the ASF dual-hosted git repository.
chief pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/ignite.git
The following commit(s) were added to refs/heads/master by this push:
new d60025d IGNITE-11003: [ML] Add parser for Spark Random forest classifier
d60025d is described below
commit d60025d88c2840470f3a6af39884cc112a7205a0
Author: zaleslaw <za...@gmail.com>
AuthorDate: Fri Jan 25 16:07:32 2019 +0300
IGNITE-11003: [ML] Add parser for Spark Random forest classifier
This closes #5924
---
.../modelparser/RandomForestFromSparkExample.java | 85 +++++++++++++++++++++
.../models/spark/serialized/rf/data/._SUCCESS.crc | Bin 0 -> 8 bytes
...-411c-8811-c3205434f5fc-c000.snappy.parquet.crc | Bin 0 -> 1256 bytes
.../models/spark/serialized/rf/data/_SUCCESS | 0
...bc1b-411c-8811-c3205434f5fc-c000.snappy.parquet | Bin 0 -> 159358 bytes
.../spark/serialized/rf/metadata/._SUCCESS.crc | Bin 0 -> 8 bytes
.../spark/serialized/rf/metadata/.part-00000.crc | Bin 0 -> 16 bytes
.../models/spark/serialized/rf/metadata/_SUCCESS | 0
.../models/spark/serialized/rf/metadata/part-00000 | 1 +
.../serialized/rf/treesMetadata/._SUCCESS.crc | Bin 0 -> 8 bytes
...-4f5a-b7fc-043c9ab56e1d-c000.snappy.parquet.crc | Bin 0 -> 148 bytes
.../spark/serialized/rf/treesMetadata/_SUCCESS | 0
...3d4b-4f5a-b7fc-043c9ab56e1d-c000.snappy.parquet | Bin 0 -> 17636 bytes
.../ml/sparkmodelparser/SparkModelParser.java | 70 ++++++++++++++++-
.../ml/sparkmodelparser/SupportedSparkModels.java | 5 +-
15 files changed, 157 insertions(+), 4 deletions(-)
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/inference/spark/modelparser/RandomForestFromSparkExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/inference/spark/modelparser/RandomForestFromSparkExample.java
new file mode 100644
index 0000000..07f2512
--- /dev/null
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/inference/spark/modelparser/RandomForestFromSparkExample.java
@@ -0,0 +1,85 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.examples.ml.inference.spark.modelparser;
+
+import java.io.FileNotFoundException;
+import org.apache.ignite.Ignite;
+import org.apache.ignite.IgniteCache;
+import org.apache.ignite.Ignition;
+import org.apache.ignite.examples.ml.tutorial.TitanicUtils;
+import org.apache.ignite.ml.composition.ModelsComposition;
+import org.apache.ignite.ml.math.functions.IgniteBiFunction;
+import org.apache.ignite.ml.math.primitives.vector.Vector;
+import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
+import org.apache.ignite.ml.selection.scoring.evaluator.BinaryClassificationEvaluator;
+import org.apache.ignite.ml.selection.scoring.metric.Accuracy;
+import org.apache.ignite.ml.sparkmodelparser.SparkModelParser;
+import org.apache.ignite.ml.sparkmodelparser.SupportedSparkModels;
+
+/**
+ * Run Random Forest model loaded from snappy.parquet file.
+ * The snappy.parquet file was generated by Spark MLLib model.write.overwrite().save(..) operator.
+ * <p>
+ * You can change the test data used in this example and re-run it to explore this algorithm further.</p>
+ */
+public class RandomForestFromSparkExample {
+ /** Path to Spark Random Forest model. */
+ public static final String SPARK_MDL_PATH = "examples/src/main/resources/models/spark/serialized/rf/data" +
+ "/part-00000-290bdb9d-bc1b-411c-8811-c3205434f5fc-c000.snappy.parquet";
+
+ /** Run example. */
+ public static void main(String[] args) throws FileNotFoundException {
+ System.out.println();
+ System.out.println(">>> Random Forest model loaded from Spark through serialization over partitioned dataset usage example started.");
+ // Start ignite grid.
+ try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
+ System.out.println(">>> Ignite grid started.");
+
+ IgniteCache<Integer, Object[]> dataCache = TitanicUtils.readPassengers(ignite);
+
+ IgniteBiFunction<Integer, Object[], Vector> featureExtractor = (k, v) -> {
+ double[] data = new double[] {(double)v[0], (double)v[5], (double)v[6]};
+ data[0] = Double.isNaN(data[0]) ? 0 : data[0];
+ data[1] = Double.isNaN(data[1]) ? 0 : data[1];
+ data[2] = Double.isNaN(data[2]) ? 0 : data[2];
+
+ return VectorUtils.of(data);
+ };
+
+ IgniteBiFunction<Integer, Object[], Double> lbExtractor = (k, v) -> (double)v[1];
+
+ ModelsComposition mdl = (ModelsComposition)SparkModelParser.parse(
+ SPARK_MDL_PATH,
+ SupportedSparkModels.RANDOM_FOREST
+ );
+
+ System.out.println(">>> Random Forest model: " + mdl.toString(true));
+
+ double accuracy = BinaryClassificationEvaluator.evaluate(
+ dataCache,
+ mdl,
+ featureExtractor,
+ lbExtractor,
+ new Accuracy<>()
+ );
+
+ System.out.println("\n>>> Accuracy " + accuracy);
+ System.out.println("\n>>> Test Error " + (1 - accuracy));
+ }
+ }
+}
diff --git a/examples/src/main/resources/models/spark/serialized/rf/data/._SUCCESS.crc b/examples/src/main/resources/models/spark/serialized/rf/data/._SUCCESS.crc
new file mode 100644
index 0000000..3b7b044
Binary files /dev/null and b/examples/src/main/resources/models/spark/serialized/rf/data/._SUCCESS.crc differ
diff --git a/examples/src/main/resources/models/spark/serialized/rf/data/.part-00000-290bdb9d-bc1b-411c-8811-c3205434f5fc-c000.snappy.parquet.crc b/examples/src/main/resources/models/spark/serialized/rf/data/.part-00000-290bdb9d-bc1b-411c-8811-c3205434f5fc-c000.snappy.parquet.crc
new file mode 100644
index 0000000..c72746b
Binary files /dev/null and b/examples/src/main/resources/models/spark/serialized/rf/data/.part-00000-290bdb9d-bc1b-411c-8811-c3205434f5fc-c000.snappy.parquet.crc differ
diff --git a/examples/src/main/resources/models/spark/serialized/rf/data/_SUCCESS b/examples/src/main/resources/models/spark/serialized/rf/data/_SUCCESS
new file mode 100644
index 0000000..e69de29
diff --git a/examples/src/main/resources/models/spark/serialized/rf/data/part-00000-290bdb9d-bc1b-411c-8811-c3205434f5fc-c000.snappy.parquet b/examples/src/main/resources/models/spark/serialized/rf/data/part-00000-290bdb9d-bc1b-411c-8811-c3205434f5fc-c000.snappy.parquet
new file mode 100644
index 0000000..4c100cc
Binary files /dev/null and b/examples/src/main/resources/models/spark/serialized/rf/data/part-00000-290bdb9d-bc1b-411c-8811-c3205434f5fc-c000.snappy.parquet differ
diff --git a/examples/src/main/resources/models/spark/serialized/rf/metadata/._SUCCESS.crc b/examples/src/main/resources/models/spark/serialized/rf/metadata/._SUCCESS.crc
new file mode 100644
index 0000000..3b7b044
Binary files /dev/null and b/examples/src/main/resources/models/spark/serialized/rf/metadata/._SUCCESS.crc differ
diff --git a/examples/src/main/resources/models/spark/serialized/rf/metadata/.part-00000.crc b/examples/src/main/resources/models/spark/serialized/rf/metadata/.part-00000.crc
new file mode 100644
index 0000000..d05128d
Binary files /dev/null and b/examples/src/main/resources/models/spark/serialized/rf/metadata/.part-00000.crc differ
diff --git a/examples/src/main/resources/models/spark/serialized/rf/metadata/_SUCCESS b/examples/src/main/resources/models/spark/serialized/rf/metadata/_SUCCESS
new file mode 100644
index 0000000..e69de29
diff --git a/examples/src/main/resources/models/spark/serialized/rf/metadata/part-00000 b/examples/src/main/resources/models/spark/serialized/rf/metadata/part-00000
new file mode 100644
index 0000000..1a47ce3
--- /dev/null
+++ b/examples/src/main/resources/models/spark/serialized/rf/metadata/part-00000
@@ -0,0 +1 @@
+{"class":"org.apache.spark.ml.classification.RandomForestClassificationModel","timestamp":1548169635203,"sparkVersion":"2.2.0","uid":"rfc_4627f663b8c3","paramMap":{"featureSubsetStrategy":"auto","maxMemoryInMB":256,"impurity":"gini","numTrees":200,"probabilityCol":"probability","maxDepth":10,"labelCol":"survived","maxBins":32,"subsamplingRate":1.0,"rawPredictionCol":"rawPrediction","checkpointInterval":10,"featuresCol":"features","minInstancesPerNode":1,"predictionCol":"prediction","seed [...]
diff --git a/examples/src/main/resources/models/spark/serialized/rf/treesMetadata/._SUCCESS.crc b/examples/src/main/resources/models/spark/serialized/rf/treesMetadata/._SUCCESS.crc
new file mode 100644
index 0000000..3b7b044
Binary files /dev/null and b/examples/src/main/resources/models/spark/serialized/rf/treesMetadata/._SUCCESS.crc differ
diff --git a/examples/src/main/resources/models/spark/serialized/rf/treesMetadata/.part-00000-86dba495-3d4b-4f5a-b7fc-043c9ab56e1d-c000.snappy.parquet.crc b/examples/src/main/resources/models/spark/serialized/rf/treesMetadata/.part-00000-86dba495-3d4b-4f5a-b7fc-043c9ab56e1d-c000.snappy.parquet.crc
new file mode 100644
index 0000000..70712f6
Binary files /dev/null and b/examples/src/main/resources/models/spark/serialized/rf/treesMetadata/.part-00000-86dba495-3d4b-4f5a-b7fc-043c9ab56e1d-c000.snappy.parquet.crc differ
diff --git a/examples/src/main/resources/models/spark/serialized/rf/treesMetadata/_SUCCESS b/examples/src/main/resources/models/spark/serialized/rf/treesMetadata/_SUCCESS
new file mode 100644
index 0000000..e69de29
diff --git a/examples/src/main/resources/models/spark/serialized/rf/treesMetadata/part-00000-86dba495-3d4b-4f5a-b7fc-043c9ab56e1d-c000.snappy.parquet b/examples/src/main/resources/models/spark/serialized/rf/treesMetadata/part-00000-86dba495-3d4b-4f5a-b7fc-043c9ab56e1d-c000.snappy.parquet
new file mode 100644
index 0000000..cb27b95
Binary files /dev/null and b/examples/src/main/resources/models/spark/serialized/rf/treesMetadata/part-00000-86dba495-3d4b-4f5a-b7fc-043c9ab56e1d-c000.snappy.parquet differ
diff --git a/modules/ml/spark-model-parser/src/main/java/org/apache/ignite/ml/sparkmodelparser/SparkModelParser.java b/modules/ml/spark-model-parser/src/main/java/org/apache/ignite/ml/sparkmodelparser/SparkModelParser.java
index 8156810..e329233 100644
--- a/modules/ml/spark-model-parser/src/main/java/org/apache/ignite/ml/sparkmodelparser/SparkModelParser.java
+++ b/modules/ml/spark-model-parser/src/main/java/org/apache/ignite/ml/sparkmodelparser/SparkModelParser.java
@@ -19,12 +19,17 @@ package org.apache.ignite.ml.sparkmodelparser;
import java.io.File;
import java.io.IOException;
+import java.util.ArrayList;
+import java.util.List;
import java.util.Map;
import java.util.NavigableMap;
import java.util.TreeMap;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path;
import org.apache.ignite.internal.util.IgniteUtils;
+import org.apache.ignite.ml.IgniteModel;
+import org.apache.ignite.ml.composition.ModelsComposition;
+import org.apache.ignite.ml.composition.predictionsaggregator.OnMajorityPredictionsAggregator;
import org.apache.ignite.ml.inference.Model;
import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.math.primitives.vector.impl.DenseVector;
@@ -72,12 +77,62 @@ public class SparkModelParser {
return loadLinearSVMModel(ignitePathToMdl);
case DECISION_TREE:
return loadDecisionTreeModel(ignitePathToMdl);
+ case RANDOM_FOREST:
+ return loadRandomForestModel(ignitePathToMdl);
default:
throw new UnsupportedSparkModelException(ignitePathToMdl);
}
}
/**
+ * Load Random Forest model.
+ *
+ * @param pathToMdl Path to model.
+ */
+    /**
+     * Load Random Forest model from a Spark-serialized snappy.parquet file.
+     *
+     * Each Parquet row holds a (treeID, nodeData) pair. Nodes are grouped per tree,
+     * each tree is rebuilt via {@code buildDecisionTreeModel}, and the ensemble is
+     * wrapped into a {@link ModelsComposition} with majority-vote aggregation.
+     *
+     * @param pathToMdl Path to model.
+     * @return Parsed ensemble model, or {@code null} if the file could not be read
+     *      (same error-handling convention as the other load* methods in this class).
+     */
+    private static Model loadRandomForestModel(String pathToMdl) {
+        try (ParquetFileReader r = ParquetFileReader.open(HadoopInputFile.fromPath(new Path(pathToMdl), new Configuration()))) {
+            PageReadStore pages;
+
+            final MessageType schema = r.getFooter().getFileMetaData().getSchema();
+            final MessageColumnIO colIO = new ColumnIOFactory().getColumnIO(schema);
+
+            // Outer TreeMap orders trees by id; inner TreeMap orders nodes by node id
+            // (buildDecisionTreeModel takes the first entry as the tree root).
+            final Map<Integer, TreeMap<Integer, NodeData>> nodesByTreeId = new TreeMap<>();
+
+            while (null != (pages = r.readNextRowGroup())) {
+                final long rows = pages.getRowCount();
+                final RecordReader recordReader = colIO.getRecordReader(pages, new GroupRecordConverter(schema));
+
+                for (int i = 0; i < rows; i++) {
+                    final SimpleGroup g = (SimpleGroup)recordReader.read();
+                    final int treeID = g.getInteger(0, 0);
+                    final SimpleGroup nodeDataGroup = (SimpleGroup)g.getGroup(1, 0);
+
+                    final NodeData nodeData = extractNodeDataFromParquetRow(nodeDataGroup);
+
+                    // computeIfAbsent replaces the containsKey/get/put branching with one atomic idiom.
+                    nodesByTreeId.computeIfAbsent(treeID, id -> new TreeMap<>()).put(nodeData.id, nodeData);
+                }
+            }
+
+            final List<IgniteModel<Vector, Double>> models = new ArrayList<>();
+            nodesByTreeId.forEach((key, nodes) -> models.add(buildDecisionTreeModel(nodes)));
+
+            return new ModelsComposition(models, new OnMajorityPredictionsAggregator());
+        }
+        catch (IOException e) {
+            // NOTE(review): swallowing the exception and returning null mirrors the sibling
+            // load* methods, but callers will NPE on the null model — consider rethrowing.
+            System.out.println("Error reading parquet file.");
+            e.printStackTrace();
+        }
+        return null;
+    }
+
+ /**
* Load Decision Tree model.
*
* @param pathToMdl Path to model.
@@ -85,12 +140,15 @@ public class SparkModelParser {
private static Model loadDecisionTreeModel(String pathToMdl) {
try (ParquetFileReader r = ParquetFileReader.open(HadoopInputFile.fromPath(new Path(pathToMdl), new Configuration()))) {
PageReadStore pages;
+
final MessageType schema = r.getFooter().getFileMetaData().getSchema();
final MessageColumnIO colIO = new ColumnIOFactory().getColumnIO(schema);
final Map<Integer, NodeData> nodes = new TreeMap<>();
+
while (null != (pages = r.readNextRowGroup())) {
final long rows = pages.getRowCount();
final RecordReader recordReader = colIO.getRecordReader(pages, new GroupRecordConverter(schema));
+
for (int i = 0; i < rows; i++) {
final SimpleGroup g = (SimpleGroup)recordReader.read();
NodeData nodeData = extractNodeDataFromParquetRow(g);
@@ -111,7 +169,7 @@ public class SparkModelParser {
*
* @param nodes The sorted map of nodes.
*/
- private static Model buildDecisionTreeModel(Map<Integer, NodeData> nodes) {
+ private static DecisionTreeNode buildDecisionTreeModel(Map<Integer, NodeData> nodes) {
DecisionTreeNode mdl = null;
if (!nodes.isEmpty()) {
NodeData rootNodeData = (NodeData)((NavigableMap)nodes).firstEntry().getValue();
@@ -143,6 +201,7 @@ public class SparkModelParser {
*/
@NotNull private static SparkModelParser.NodeData extractNodeDataFromParquetRow(SimpleGroup g) {
NodeData nodeData = new NodeData();
+
nodeData.id = g.getInteger(0, 0);
nodeData.prediction = g.getDouble(1, 0);
nodeData.leftChildId = g.getInteger(5, 0);
@@ -195,6 +254,7 @@ public class SparkModelParser {
try (ParquetFileReader r = ParquetFileReader.open(HadoopInputFile.fromPath(new Path(pathToMdl), new Configuration()))) {
PageReadStore pages;
+
final MessageType schema = r.getFooter().getFileMetaData().getSchema();
final MessageColumnIO colIO = new ColumnIOFactory().getColumnIO(schema);
@@ -227,6 +287,7 @@ public class SparkModelParser {
try (ParquetFileReader r = ParquetFileReader.open(HadoopInputFile.fromPath(new Path(pathToMdl), new Configuration()))) {
PageReadStore pages;
+
final MessageType schema = r.getFooter().getFileMetaData().getSchema();
final MessageColumnIO colIO = new ColumnIOFactory().getColumnIO(schema);
@@ -260,6 +321,7 @@ public class SparkModelParser {
try (ParquetFileReader r = ParquetFileReader.open(HadoopInputFile.fromPath(new Path(pathToMdl), new Configuration()))) {
PageReadStore pages;
+
final MessageType schema = r.getFooter().getFileMetaData().getSchema();
final MessageColumnIO colIO = new ColumnIOFactory().getColumnIO(schema);
@@ -278,9 +340,7 @@ public class SparkModelParser {
System.out.println("Error reading parquet file.");
e.printStackTrace();
}
-
return new LogisticRegressionModel(coefficients, interceptor);
-
}
/**
@@ -350,10 +410,13 @@ public class SparkModelParser {
*/
private static double readInterceptor(SimpleGroup g) {
double interceptor;
+
final SimpleGroup interceptVector = (SimpleGroup)g.getGroup(2, 0);
final SimpleGroup interceptVectorVal = (SimpleGroup)interceptVector.getGroup(3, 0);
final SimpleGroup interceptVectorValElement = (SimpleGroup)interceptVectorVal.getGroup(0, 0);
+
interceptor = interceptVectorValElement.getDouble(0, 0);
+
return interceptor;
}
@@ -401,6 +464,7 @@ public class SparkModelParser {
/** Is leaf node. */
boolean isLeafNode;
+ /** {@inheritDoc} */
@Override public String toString() {
return "NodeData{" +
"id=" + id +
diff --git a/modules/ml/spark-model-parser/src/main/java/org/apache/ignite/ml/sparkmodelparser/SupportedSparkModels.java b/modules/ml/spark-model-parser/src/main/java/org/apache/ignite/ml/sparkmodelparser/SupportedSparkModels.java
index 2064a0c..0c32d8e 100644
--- a/modules/ml/spark-model-parser/src/main/java/org/apache/ignite/ml/sparkmodelparser/SupportedSparkModels.java
+++ b/modules/ml/spark-model-parser/src/main/java/org/apache/ignite/ml/sparkmodelparser/SupportedSparkModels.java
@@ -33,5 +33,8 @@ public enum SupportedSparkModels {
DECISION_TREE,
/** Support Vector Machine . */
- LINEAR_SVM
+ LINEAR_SVM,
+
+ /** Random forest. */
+ RANDOM_FOREST
}