You are viewing a plain text version of this content. The canonical link for it is here.
Posted to reviews@spark.apache.org by jomach <gi...@git.apache.org> on 2017/10/13 19:46:49 UTC
[GitHub] spark pull request #7842: [SPARK-8542][MLlib]PMML export for Decision Trees
Github user jomach commented on a diff in the pull request:
https://github.com/apache/spark/pull/7842#discussion_r144642031
--- Diff: mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLTreeModelUtils.scala ---
@@ -0,0 +1,261 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.pmml.export
+
+import scala.collection.mutable
+import scala.collection.JavaConverters._
+
+import org.dmg.pmml.{Node => PMMLNode, Value => PMMLValue, _}
+
+import org.apache.spark.mllib.tree.configuration.{Algo, FeatureType}
+import org.apache.spark.mllib.tree.configuration.Algo._
+import org.apache.spark.mllib.tree.model.{DecisionTreeModel, Node}
+
+private[mllib] object PMMLTreeModelUtils {
+
+  /** Prefix prepended to a split's `feature` value to form that feature's PMML field name. */
+  val FieldNamePrefix: String = "field_"
+
+  /**
+   * Converts an MLlib [[DecisionTreeModel]] into a PMML [[TreeModel]].
+   *
+   * @param dtModel   the MLlib decision tree to export
+   * @param modelName value stored as the PMML TreeModel's model name
+   * @return the PMML tree model and the data fields it references: the fields collected
+   *         from the tree by `buildStub`, followed by one output field ("class" for
+   *         classification, "target" for regression)
+   */
+  def toPMMLTree(dtModel: DecisionTreeModel, modelName: String): (TreeModel, List[DataField]) = {
+    val miningFunctionType = dtModel.algo match {
+      case Algo.Classification => MiningFunctionType.CLASSIFICATION
+      case Algo.Regression => MiningFunctionType.REGRESSION
+    }
+
+    val treeModel = new TreeModel()
+      .setModelName(modelName)
+      .setFunctionName(miningFunctionType)
+      .setSplitCharacteristic(TreeModel.SplitCharacteristic.BINARY_SPLIT)
+
+    val (rootNode, inputMiningFields, inputDataFields, classes) =
+      buildStub(dtModel.topNode, dtModel.algo)
+
+    // Describe the model output for completeness: the predicted classes for classification,
+    // a continuous target field for regression. Built as an expression so the field lists
+    // returned by buildStub are extended, never reassigned.
+    val (outputMiningField, outputDataField) = dtModel.algo match {
+      case Algo.Classification =>
+        val classField = FieldName.create("class")
+        // NOTE(review): assumes buildStub de-duplicates `classes`; otherwise the same class
+        // value is listed once per leaf that predicts it. Verify against buildStub's tail.
+        val dataField = new DataField()
+          .setName(classField)
+          .setOpType(OpType.CATEGORICAL)
+          .addValues(classes: _*)
+          .setDataType(DataType.DOUBLE)
+        // NOTE(review): PREDICTED here but TARGET for regression below; confirm which
+        // usage type the targeted PMML version expects and make the two consistent.
+        val miningField = new MiningField()
+          .setName(classField)
+          .setUsageType(FieldUsageType.PREDICTED)
+        (miningField, dataField)
+
+      case Algo.Regression =>
+        val targetField = FieldName.create("target")
+        val dataField = new DataField(targetField, OpType.CONTINUOUS, DataType.DOUBLE)
+        val miningField = new MiningField()
+          .setName(targetField)
+          .setUsageType(FieldUsageType.TARGET)
+        (miningField, dataField)
+    }
+
+    val miningFields = inputMiningFields :+ outputMiningField
+    val miningSchema = new MiningSchema().addMiningFields(miningFields: _*)
+
+    treeModel.setNode(rootNode).setMiningSchema(miningSchema)
+
+    (treeModel, inputDataFields :+ outputDataField)
+  }
+
+ /** Build a pmml tree stub given the root mllib node. */
+ private def buildStub(rootDTNode: Node, algo: Algo):
+ (PMMLNode, List[MiningField], List[DataField], List[PMMLValue]) = {
+
+ val miningFields = mutable.MutableList[MiningField]()
+ val dataFields = mutable.HashMap[String, DataField]()
+ val classes = mutable.MutableList[Double]()
+
+ def buildStubInternal(rootNode: Node, predicate: Predicate): PMMLNode = {
+
+ // get rootPMML node for the MLLib node
+ val rootPMMLNode = new PMMLNode()
+ .setId(rootNode.id.toString)
+ .setScore(rootNode.predict.predict.toString)
+ .setPredicate(predicate)
+
+ var leftPredicate: Predicate = new True()
+ var rightPredicate: Predicate = new True()
+
+ if (rootNode.split.isDefined) {
+ val fieldName = FieldName.create(FieldNamePrefix + rootNode.split.get.feature)
+ val dataField = getDataField(rootNode, fieldName).get
+
+ if (dataFields.get(dataField.getName.getValue).isEmpty) {
+ dataFields.put(dataField.getName.getValue, dataField)
+ miningFields += new MiningField()
+ .setName(dataField.getName)
+ .setUsageType(FieldUsageType.ACTIVE)
+
+ } else if (dataField.getOpType != OpType.CONTINUOUS) {
+ appendCategories(
+ dataFields.get(dataField.getName.getValue).get,
+ dataField.getValues.asScala.toList)
+ }
+
+ leftPredicate = getPredicate(rootNode, Some(dataField.getName), true)
+ rightPredicate = getPredicate(rootNode, Some(dataField.getName), false)
+ }
+ // if left node exist, add the node
+ if (rootNode.leftNode.isDefined) {
+ val leftNode = buildStubInternal(rootNode.leftNode.get, leftPredicate)
+ rootPMMLNode.addNodes(leftNode)
+ }
+ // if right node exist, add the node
+ if (rootNode.rightNode.isDefined) {
+ val rightNode = buildStubInternal(rootNode.rightNode.get, rightPredicate)
+ rootPMMLNode.addNodes(rightNode)
+ }
+
+ // add to the list of classes
+ if (rootNode.isLeaf && (algo == Algo.Classification)) {
+ classes += rootNode.predict.predict
+ }
+
+ rootPMMLNode
+ }
+
+ val pmmlTreeRootNode = buildStubInternal(rootDTNode, new True())
+
--- End diff --
Remove this blank line.
---
---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org
For additional commands, e-mail: reviews-help@spark.apache.org