You are viewing a plain text version of this content. The canonical link for it is here.
Posted to reviews@spark.apache.org by jomach <gi...@git.apache.org> on 2017/10/13 19:46:49 UTC

[GitHub] spark pull request #7842: [SPARK-8542][MLlib]PMML export for Decision Trees

Github user jomach commented on a diff in the pull request:

    https://github.com/apache/spark/pull/7842#discussion_r144642031
  
    --- Diff: mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLTreeModelUtils.scala ---
    @@ -0,0 +1,261 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one or more
    + * contributor license agreements.  See the NOTICE file distributed with
    + * this work for additional information regarding copyright ownership.
    + * The ASF licenses this file to You under the Apache License, Version 2.0
    + * (the "License"); you may not use this file except in compliance with
    + * the License.  You may obtain a copy of the License at
    + *
    + *    http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +
    +package org.apache.spark.mllib.pmml.export
    +
    +import scala.collection.mutable
    +import scala.collection.JavaConverters._
    +
    +import org.dmg.pmml.{Node => PMMLNode, Value => PMMLValue, _}
    +
    +import org.apache.spark.mllib.tree.configuration.{Algo, FeatureType}
    +import org.apache.spark.mllib.tree.configuration.Algo._
    +import org.apache.spark.mllib.tree.model.{DecisionTreeModel, Node}
    +
    +private[mllib] object PMMLTreeModelUtils {
    +
    +  val FieldNamePrefix = "field_"
    +
    +  /**
    +   * Converts an MLlib [[DecisionTreeModel]] into a PMML [[TreeModel]].
    +   *
    +   * @param dtModel   the trained MLlib decision tree (classification or regression)
    +   * @param modelName value written to the PMML model's `modelName` attribute
    +   * @return the PMML tree model (with its root node and mining schema attached) together with
    +   *         the data fields it references; the output ("class" or "target") field is last
    +   */
    +  def toPMMLTree(dtModel: DecisionTreeModel, modelName: String): (TreeModel, List[DataField]) = {
    +
    +    val miningFunctionType = dtModel.algo match {
    +      case Algo.Classification => MiningFunctionType.CLASSIFICATION
    +      case Algo.Regression => MiningFunctionType.REGRESSION
    +    }
    +
    +    val treeModel = new TreeModel()
    +      .setModelName(modelName)
    +      .setFunctionName(miningFunctionType)
    +      .setSplitCharacteristic(TreeModel.SplitCharacteristic.BINARY_SPLIT)
    +
    +    val (rootNode, inputMiningFields, inputDataFields, classes) =
    +      buildStub(dtModel.topNode, dtModel.algo)
    +
    +    // For completeness, also describe the model's output column: a categorical "class" field
    +    // holding the classes collected by buildStub for classification, and a continuous
    +    // "target" field for regression.
    +    val (outputMiningField, outputDataField): (MiningField, DataField) = dtModel.algo match {
    +      case Algo.Classification =>
    +        val classField = FieldName.create("class")
    +        val miningField = new MiningField()
    +          .setName(classField)
    +          .setUsageType(FieldUsageType.PREDICTED)
    +        val dataField = new DataField()
    +          .setName(classField)
    +          .setOpType(OpType.CATEGORICAL)
    +          .addValues(classes: _*)
    +          .setDataType(DataType.DOUBLE)
    +        (miningField, dataField)
    +
    +      case Algo.Regression =>
    +        val targetField = FieldName.create("target")
    +        val miningField = new MiningField()
    +          .setName(targetField)
    +          .setUsageType(FieldUsageType.TARGET)
    +        (miningField, new DataField(targetField, OpType.CONTINUOUS, DataType.DOUBLE))
    +    }
    +
    +    val miningSchema = new MiningSchema()
    +      .addMiningFields((inputMiningFields :+ outputMiningField): _*)
    +
    +    treeModel.setNode(rootNode).setMiningSchema(miningSchema)
    +
    +    (treeModel, inputDataFields :+ outputDataField)
    +  }
    +
    +  /** Build a pmml tree stub given the root mllib node. */
    +  private def buildStub(rootDTNode: Node, algo: Algo):
    +    (PMMLNode, List[MiningField], List[DataField], List[PMMLValue]) = {
    +
    +    val miningFields = mutable.MutableList[MiningField]()
    +    val dataFields = mutable.HashMap[String, DataField]()
    +    val classes = mutable.MutableList[Double]()
    +
    +    /**
    +     * Recursively converts `rootNode` and its subtree into PMML nodes.
    +     *
    +     * Side effects: the first time a split feature is seen it is stored in `dataFields` and an
    +     * ACTIVE entry is appended to `miningFields`; a later occurrence of a non-continuous
    +     * feature has its values merged via `appendCategories`. For classification, every leaf's
    +     * prediction is appended to `classes`.
    +     *
    +     * @param rootNode  MLlib node to convert
    +     * @param predicate predicate attached to the resulting PMML node
    +     * @return the PMML node carrying this node's id, score and predicate, plus its children
    +     */
    +    def buildStubInternal(rootNode: Node, predicate: Predicate): PMMLNode = {
    +
    +      val rootPMMLNode = new PMMLNode()
    +        .setId(rootNode.id.toString)
    +        .setScore(rootNode.predict.predict.toString)
    +        .setPredicate(predicate)
    +
    +      // Without a split, any children are reached under an always-true predicate.
    +      val (leftPredicate, rightPredicate): (Predicate, Predicate) = rootNode.split match {
    +        case Some(split) =>
    +          val fieldName = FieldName.create(FieldNamePrefix + split.feature)
    +          // NOTE(review): assumes getDataField yields Some for every split feature.
    +          val dataField = getDataField(rootNode, fieldName).get
    +          val fieldKey = dataField.getName.getValue
    +
    +          dataFields.get(fieldKey) match {
    +            case None =>
    +              dataFields.put(fieldKey, dataField)
    +              miningFields += new MiningField()
    +                .setName(dataField.getName)
    +                .setUsageType(FieldUsageType.ACTIVE)
    +            case Some(existing) if dataField.getOpType != OpType.CONTINUOUS =>
    +              appendCategories(existing, dataField.getValues.asScala.toList)
    +            case Some(_) =>
    +              // continuous feature already registered: nothing to merge
    +          }
    +
    +          (getPredicate(rootNode, Some(dataField.getName), true),
    +            getPredicate(rootNode, Some(dataField.getName), false))
    +
    +        case None =>
    +          (new True(), new True())
    +      }
    +
    +      // Left subtree is converted (and its fields registered) before the right one.
    +      rootNode.leftNode.foreach { left =>
    +        rootPMMLNode.addNodes(buildStubInternal(left, leftPredicate))
    +      }
    +      rootNode.rightNode.foreach { right =>
    +        rootPMMLNode.addNodes(buildStubInternal(right, rightPredicate))
    +      }
    +
    +      if (rootNode.isLeaf && algo == Algo.Classification) {
    +        classes += rootNode.predict.predict
    +      }
    +
    +      rootPMMLNode
    +    }
    +
    +    val pmmlTreeRootNode = buildStubInternal(rootDTNode, new True())
    +
    --- End diff --
    
    Remove this blank line.


---

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org
For additional commands, e-mail: reviews-help@spark.apache.org