Posted to user@spark.apache.org by AaronLee <yl...@wish.com.INVALID> on 2020/06/11 23:57:44 UTC

Spark ML: how to extract split points from a trained decision tree model

I am following the official Spark 2.4.3 tutorial
<https://spark.apache.org/docs/2.4.3/ml-classification-regression.html#decision-tree-classifier>
and trained a decision tree model. How can I extract the split points from the
trained model?

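import org.apache.spark.ml.classification.DecisionTreeClassifier

// Model
val dt = new DecisionTreeClassifier()
  .setLabelCol("indexedLabel")
  .setFeaturesCol("indexedFeatures")
  .setMaxBins(10)

// Train the model. (In the tutorial, fitting the Pipeline also runs the indexers;
// dt.fit on its own does not, so trainingData must already contain the indexed columns.)
val dtm = dt.fit(trainingData)

// Extract the bin split points:
// how to do it?   <- ?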



--
Sent from: http://apache-spark-user-list.1001560.n3.nabble.com/

---------------------------------------------------------------------
To unsubscribe e-mail: user-unsubscribe@spark.apache.org


Re: Spark ML: how to extract split points from a trained decision tree model

Posted by AaronLee <yl...@wish.com.INVALID>.
Instead of continuing to explore and debug, I switched to the sklearn decision
tree in the end... lol





Re: Spark ML: how to extract split points from a trained decision tree model

Posted by AaronLee <yl...@wish.com.INVALID>.
@srowen, you are totally right: the model was not trained correctly. But it
is weird, because the dataset I used actually has almost 50M rows. It has a
binary label with 20% positives, and 1 feature in the feature vector. I don't
understand why it did not train correctly.


```
scala> df2.count
res56: Long = 48174858

scala> df2.show
+--------------------+-----+
|            features|label|
+--------------------+-----+
|              [14.0]|  1.0|
|               [2.0]|  0.0|
|               [2.0]|  0.0|
|               [1.0]|  1.0|
|[0.9700000286102295]|  1.0|
|[1.9600000381469727]|  0.0|
|[0.9900000095367432]|  0.0|
|[11.739999771118164]|  1.0|
|               [1.0]|  0.0|
|[0.9800000190734863]|  0.0|
|               [5.0]|  0.0|
| [5.940000057220459]|  1.0|
|              [11.0]|  0.0|
|               [4.0]|  0.0|
|               [1.0]|  1.0|
|[1.9700000286102295]|  0.0|
| [6.989999771118164]|  0.0|
|[0.9700000286102295]|  0.0|
|[0.9700000286102295]|  0.0|
|[0.9900000095367432]|  0.0|
+--------------------+-----+
only showing top 20 rows


scala> df2.printSchema
root
 |-- features: vector (nullable = true)
 |-- label: double (nullable = true)

scala> val dt = new DecisionTreeClassifier().setLabelCol("label").setFeaturesCol("features").setMaxBins(10)
dt: org.apache.spark.ml.classification.DecisionTreeClassifier = dtc_2b6b6e170840

scala> val dtm = dt.fit(df2)
dtm: org.apache.spark.ml.classification.DecisionTreeClassificationModel = DecisionTreeClassificationModel (uid=dtc_2b6b6e170840) of depth 0 with 1 nodes

scala> val df3 = dtm.transform(df2)
df3: org.apache.spark.sql.DataFrame = [features: vector, label: double ... 3 more fields]

scala> df3.show(100,false)
+--------------------+-----+----------------------+----------------------------------------+----------+
|features            |label|rawPrediction         |probability                             |prediction|
+--------------------+-----+----------------------+----------------------------------------+----------+
|[14.0]              |1.0  |[3.872715E7,9447708.0]|[0.8038871645454565,0.19611283545454353]|0.0       |
|[2.0]               |0.0  |[3.872715E7,9447708.0]|[0.8038871645454565,0.19611283545454353]|0.0       |
|[2.0]               |0.0  |[3.872715E7,9447708.0]|[0.8038871645454565,0.19611283545454353]|0.0       |
|[1.0]               |1.0  |[3.872715E7,9447708.0]|[0.8038871645454565,0.19611283545454353]|0.0       |
|[0.9700000286102295]|1.0  |[3.872715E7,9447708.0]|[0.8038871645454565,0.19611283545454353]|0.0       |
|[1.9600000381469727]|0.0  |[3.872715E7,9447708.0]|[0.8038871645454565,0.19611283545454353]|0.0       |
|[0.9900000095367432]|0.0  |[3.872715E7,9447708.0]|[0.8038871645454565,0.19611283545454353]|0.0       |
....
```
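The numbers in this output match a tree that is a single leaf holding every training row: the two rawPrediction entries are the class counts at that leaf, they sum to df2.count (48174858), and probability is each count divided by that sum. The impurity of the root LeafNode (0.3153051824490453, shown later in this thread) is the Gini impurity of the same distribution, Gini being the default impurity for DecisionTreeClassifier. A worked check:

```latex
\begin{aligned}
n_0 + n_1 &= 38\,727\,150 + 9\,447\,708 = 48\,174\,858 \\
p_1 &= \frac{n_1}{n_0 + n_1} = \frac{9\,447\,708}{48\,174\,858} \approx 0.19611,
  \qquad p_0 = 1 - p_1 \approx 0.80389 \\
G &= 1 - p_0^2 - p_1^2 = 2\,p_0\,p_1 \approx 0.31531
\end{aligned}
```

Since p_0 > 0.5, the single leaf predicts 0.0 for every row, which is the constant prediction column above.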






Re: Spark ML: how to extract split points from a trained decision tree model

Posted by Sean Owen <sr...@gmail.com>.
Hm, the root is a leaf? That's possible, but it means there are no splits.
If it's a toy example, that could be the case.
This was just off the top of my head from looking at the code, so I could be
missing something, but a non-trivial tree should start with an InternalNode.
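One way to see whether this data supports a split that changes the prediction is to bucket the single feature by approximate quantiles and compute the positive rate in each bucket. If the rate stays below 0.5 in every bucket, then any split at one of these bucket boundaries leaves 0.0 as the majority class on both sides, the same prediction as the root leaf. A minimal diagnostic sketch, assuming Spark 2.4 and a DataFrame shaped like df2 above (a one-element "features" vector and a 0.0/1.0 "label"); BucketPositiveRate, positiveRateByBucket and the "feature0"/"bucket" column names are hypothetical, and the buckets only approximate the tree's own candidate thresholds:

```scala
import org.apache.spark.ml.feature.QuantileDiscretizer
import org.apache.spark.ml.linalg.Vector
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.{avg, col, count, udf}

object BucketPositiveRate {
  /**
   * Positive-label rate per approximate-quantile bucket of feature 0.
   * Assumes a one-element vector column "features" and a 0.0/1.0 column "label".
   */
  def positiveRateByBucket(df: DataFrame, numBuckets: Int = 10): DataFrame = {
    // Spark 2.4 has no built-in vector-to-array function, so read element 0 with a UDF.
    val firstElement = udf((v: Vector) => v(0))
    val withFeature = df.withColumn("feature0", firstElement(col("features")))

    // Bucket boundaries come from approximate quantiles; QuantileDiscretizer may
    // produce fewer buckets than requested when many rows share the same value.
    val bucketed = new QuantileDiscretizer()
      .setInputCol("feature0")
      .setOutputCol("bucket")
      .setNumBuckets(numBuckets)
      .fit(withFeature)
      .transform(withFeature)

    // Row count and mean label (the fraction of 1.0 labels) in each bucket.
    bucketed
      .groupBy("bucket")
      .agg(count("*").as("n"), avg("label").as("positiveRate"))
      .orderBy("bucket")
  }
}
```

Usage: BucketPositiveRate.positiveRateByBucket(df2).show().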

On Thu, Jun 11, 2020 at 11:01 PM AaronLee <yl...@wish.com.invalid> wrote:

> Splits are available via "InternalNode" ".split" attribute. But
> "dtm.rootNode"  belongs to "LeafNode".

Re: Spark ML: how to extract split points from a trained decision tree model

Posted by AaronLee <yl...@wish.com.INVALID>.
Thanks, srowen. I also checked
https://www.programcreek.com/scala/org.apache.spark.ml.tree.InternalNode.
Splits are available via the ".split" attribute of "InternalNode", but
"dtm.rootNode" is a "LeafNode".

```
scala> dtm.rootNode
res9: org.apache.spark.ml.tree.Node = LeafNode(prediction = 0.0, impurity = 0.3153051824490453)

scala> dftm.rootNode.
impurity   prediction

scala> dftm.rootNode.getClass.getSimpleName
res13: String = LeafNode

scala> import org.apache.spark.ml.tree.{InternalNode, LeafNode, Node}
import org.apache.spark.ml.tree.{InternalNode, LeafNode, Node}

scala> val intnode = dftm.rootNode.asInstanceOf[InternalNode]
java.lang.ClassCastException: org.apache.spark.ml.tree.LeafNode cannot be cast to org.apache.spark.ml.tree.InternalNode
  ... 51 elided

```
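The ClassCastException above comes from forcing a LeafNode into an InternalNode. Since Node's public subclasses are InternalNode and LeafNode, a pattern match handles both cases without an exception. A minimal sketch, assuming Spark 2.4; RootSplit and rootSplit are hypothetical names:

```scala
import org.apache.spark.ml.classification.DecisionTreeClassificationModel
import org.apache.spark.ml.tree.{InternalNode, LeafNode, Split}

object RootSplit {
  /** The split at the root, or None when the tree is a single leaf (depth 0). */
  def rootSplit(model: DecisionTreeClassificationModel): Option[Split] =
    model.rootNode match {
      case n: InternalNode => Some(n.split)
      case _: LeafNode     => None // nothing was split, so there is no split to read
    }
}
```

For the model in this thread (depth 0 with 1 node), RootSplit.rootSplit(dtm) would return None; dtm.depth and dtm.numNodes report the same fact directly.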





Re: Spark ML: how to extract split points from a trained decision tree model

Posted by Sean Owen <sr...@gmail.com>.
You should be able to look at dtm.rootNode and, treating it as an
InternalNode, get the .split from it.
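The same idea extends to the whole tree: walk rootNode recursively, read .split from every InternalNode, and stop at each LeafNode. Matching on the node type instead of casting also covers a root that is itself a leaf. A minimal sketch, assuming Spark 2.4's public org.apache.spark.ml.tree classes; TreeSplits, SplitInfo, Threshold and Categories are hypothetical names introduced here:

```scala
import org.apache.spark.ml.tree.{CategoricalSplit, ContinuousSplit, InternalNode, LeafNode, Node}

object TreeSplits {
  /** One learned split: the feature it tests, plus its threshold or left-going categories. */
  sealed trait SplitInfo { def featureIndex: Int }
  // Continuous split: rows with feature value <= threshold go to the left child.
  final case class Threshold(featureIndex: Int, threshold: Double) extends SplitInfo
  // Categorical split: rows whose category is in leftCategories go to the left child.
  final case class Categories(featureIndex: Int, leftCategories: Seq[Double]) extends SplitInfo

  /** Pre-order walk collecting the split of every internal node; leaves contribute nothing. */
  def collectSplits(node: Node): List[SplitInfo] = node match {
    case _: LeafNode => Nil
    case n: InternalNode =>
      val here = n.split match {
        case s: ContinuousSplit  => Threshold(s.featureIndex, s.threshold)
        case s: CategoricalSplit => Categories(s.featureIndex, s.leftCategories.toSeq)
      }
      here :: (collectSplits(n.leftChild) ++ collectSplits(n.rightChild))
  }

  /** Distinct learned thresholds per continuous feature index, sorted ascending. */
  def thresholdsByFeature(node: Node): Map[Int, Seq[Double]] =
    collectSplits(node)
      .collect { case Threshold(f, t) => (f, t) }
      .groupBy(_._1)
      .map { case (f, pairs) => f -> pairs.map(_._2).distinct.sorted }
}
```

TreeSplits.thresholdsByFeature(dtm.rootNode) returns the learned thresholds keyed by feature index, and an empty map when the root is a leaf; dtm.toDebugString prints the same splits as human-readable If/Else lines. These are only the splits the tree actually chose, not every candidate bin boundary considered during training.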

On Thu, Jun 11, 2020 at 7:02 PM AaronLee <yl...@wish.com.invalid> wrote:

> How to extract split points from the trained model?
>