You are viewing a plain text version of this content. The canonical link for it is here.
Posted to issues@hivemall.apache.org by GitBox <gi...@apache.org> on 2019/09/27 18:30:00 UTC

[GitHub] [incubator-hivemall] myui edited a comment on issue #199: [WIP][HIVEMALL-171] Tracing functionality for prediction of DecisionTrees

myui edited a comment on issue #199: [WIP][HIVEMALL-171] Tracing functionality for prediction of DecisionTrees
URL: https://github.com/apache/incubator-hivemall/pull/199#issuecomment-536013346
 
 
   ## Usage
   
   ```sql
   select decision_path();
   ```
   
   ```
   usage: decision_path(string modelId, string model, array<double|string>
          features [, const string options] [, optional array<string>
          featureNames=null, optional array<string> classNames=null]) -
          Returns a decision path for each prediction in array<string> [-c]
          [-no_leaf] [-no_sumarize] [-no_verbose]
    -c,--classification                    Predict as classification
                                           [default: not enabled]
    -no_leaf,--disable_leaf_output         Hide leaf value [default: not
                                           enabled]
    -no_sumarize,--disable_summarization   Do not summarize decision paths
    -no_verbose,--disable_verbose_output   Disable verbose output [default:
                                           verbose]
   ```
   
   ## Show decision paths for each Decision Tree
   
   ```sql
   SELECT
     t.passengerid,
     decision_path(m.model_id, m.model, t.features, '-classification') 
   FROM
     model_rf m
     LEFT OUTER JOIN -- CROSS JOIN
     test_rf t
   limit 3;
   ```
   
   |passengerid|path|
   |:-|:-|
   |892  | ["2 [0.0] = 0.0","0 [3.0] = 3.0","1 [696.0] != 107.0","7 [7.8292] <= 7.9104","1 [696.0] != 828.0","1 [696.0] != 391.0","0 [0.961038961038961, 0.03896103896103896]"] |
   |1309 | ["2 [0.0] = 0.0","0 [3.0] = 3.0","1 [1306.0] != 107.0","7 [22.3583] > 12.675","1 [1306.0] != 429.0","1 [1306.0] != 65.0","6 [117.0] != 481.0","6 [117.0] != 251.0","0 [0.9466666666666667, 0.05333333333333334]"] |
   |1308 | ["2 [0.0] = 0.0","0 [3.0] = 3.0","1 [1305.0] != 107.0","7 [8.05] > 7.987500000000001","1 [1305.0] != 429.0","1 [1305.0] != 65.0","7 [8.05] <= 8.08125","1 [1305.0] != 338.0","1 [1305.0] != 220.0","6 [889.0] != 558.0","0 [1.0, 0.0]"] |
   
   ```sql
   SELECT
     t.passengerid,
     decision_path(m.model_id, m.model, t.features, '-classification', array('pclass','name','sex','age','sibsp','parch','ticket','fare','cabin','embarked')) 
   FROM
     model_rf m
     LEFT OUTER JOIN -- CROSS JOIN
     test_rf t
   limit 3;
   ```
   
   |passengerid|path|
   |:-|:-|
   | 892  |  ["sex [0.0] = 0.0","pclass [3.0] = 3.0","name [696.0] != 107.0","fare [7.8292] <= 7.9104","name [696.0] != 828.0","name [696.0] != 391.0","0 [0.961038961038961, 0.03896103896103896]"] |
   | 1309 |  ["sex [0.0] = 0.0","pclass [3.0] = 3.0","name [1306.0] != 107.0","fare [22.3583] > 12.675","name [1306.0] != 429.0","name [1306.0] != 65.0","ticket [117.0] != 481.0","ticket [117.0] != 251.0","0 [0.9466666666666667, 0.05333333333333334]"] |
   | 1308 |  ["sex [0.0] = 0.0","pclass [3.0] = 3.0","name [1305.0] != 107.0","fare [8.05] > 7.987500000000001","name [1305.0] != 429.0","name [1305.0] != 65.0","fare [8.05] <= 8.08125","name [1305.0] != 338.0","name [1305.0] != 220.0","ticket [889.0] != 558.0","0 [1.0, 0.0]"] |
   
   ```sql
   SELECT
     t.passengerid,
     decision_path(m.model_id, m.model, t.features, '-classification', array('pclass','name','sex','age','sibsp','parch','ticket','fare','cabin','embarked'), array('no','yes')) 
   FROM
     model_rf m
     LEFT OUTER JOIN -- CROSS JOIN
     test_rf t
   limit 3;
   ```
   
   |passengerid|path|
   |:-|:-|
   | 892  |  ["sex [0.0] = 0.0","pclass [3.0] = 3.0","name [696.0] != 107.0","fare [7.8292] <= 7.9104","name [696.0] != 828.0","name [696.0] != 391.0","no [0.961038961038961, 0.03896103896103896]"] |
   | 1309 |  ["sex [0.0] = 0.0","pclass [3.0] = 3.0","name [1306.0] != 107.0","fare [22.3583] > 12.675","name [1306.0] != 429.0","name [1306.0] != 65.0","ticket [117.0] != 481.0","ticket [117.0] != 251.0","no [0.9466666666666667, 0.05333333333333334]"] |
   | 1308 |  ["sex [0.0] = 0.0","pclass [3.0] = 3.0","name [1305.0] != 107.0","fare [8.05] > 7.987500000000001","name [1305.0] != 429.0","name [1305.0] != 65.0","fare [8.05] <= 8.08125","name [1305.0] != 338.0","name [1305.0] != 220.0","ticket [889.0] != 558.0","no [1.0, 0.0]"] |
   
   ```sql
   SELECT
     t.passengerid,
     decision_path(m.model_id, m.model, t.features, '-classification -no_sumarize') 
   FROM
     model_rf m
     LEFT OUTER JOIN -- CROSS JOIN
     test_rf t
   limit 3;
   ```
   
   |passengerid|path|
   |:-|:-|
   | 892  |  ["2 [0.0] = 0.0","0 [3.0] = 3.0","1 [696.0] != 107.0","7 [7.8292] <= 7.9104","1 [696.0] != 828.0","1 [696.0] != 391.0","0 [0.961038961038961, 0.03896103896103896]"] |
   | 1309 |  ["2 [0.0] = 0.0","0 [3.0] = 3.0","1 [1306.0] != 107.0","7 [22.3583] > 7.9104","1 [1306.0] != 429.0","1 [1306.0] != 65.0","7 [22.3583] > 12.675","6 [117.0] != 481.0","6 [117.0] != 251.0","0 [0.9466666666666667, 0.05333333333333334]"] |
   | 1308 |    ["2 [0.0] = 0.0","0 [3.0] = 3.0","1 [1305.0] != 107.0","7 [8.05] > 7.9104","1 [1305.0] != 429.0","1 [1305.0] != 65.0","7 [8.05] <= 12.675","7 [8.05] <= 10.65205","1 [1305.0] != 338.0","1 [1305.0] != 220.0","6 [889.0] != 558.0","7 [8.05] <= 8.08125","7 [8.05] > 7.987500000000001","0 [1.0, 0.0]"] |
   
   ```sql
   SELECT
     t.passengerid,
     decision_path(m.model_id, m.model, t.features, '-classification -no_sumarize -no_verbose') 
   FROM
     model_rf m
     LEFT OUTER JOIN -- CROSS JOIN
     test_rf t
   limit 3;
   ```
   
   |passengerid|path|
   |:-|:-|
   | 892  |  ["2 = 0.0","0 = 3.0","1 != 107.0","7 <= 7.9104","1 != 828.0","1 != 391.0","0"] |
   | 1309 |  ["2 = 0.0","0 = 3.0","1 != 107.0","7 > 7.9104","1 != 429.0","1 != 65.0","7 > 12.675","6 != 481.0","6 != 251.0","0"] |
   | 1308 |  ["2 = 0.0","0 = 3.0","1 != 107.0","7 > 7.9104","1 != 429.0","1 != 65.0","7 <= 12.675","7 <= 10.65205","1 != 338.0","1 != 220.0","6 != 558.0","7 <= 8.08125","7 > 7.987500000000001","0"] |
   
   ```sql
   SELECT
     t.passengerid,
     decision_path(m.model_id, m.model, t.features, '-classification -no_sumarize -no_verbose -no_leaf') 
   FROM
     model_rf m
     LEFT OUTER JOIN -- CROSS JOIN
     test_rf t
   limit 3;
   ```
   
   |passengerid|path|
   |:-|:-|
   | 892  |  ["2 = 0.0","0 = 3.0","1 != 107.0","7 <= 7.9104","1 != 828.0","1 != 391.0"] |
   | 1309 |  ["2 = 0.0","0 = 3.0","1 != 107.0","7 > 7.9104","1 != 429.0","1 != 65.0","7 > 12.675","6 != 481.0","6 != 251.0"] |
   | 1308 |  ["2 = 0.0","0 = 3.0","1 != 107.0","7 > 7.9104","1 != 429.0","1 != 65.0","7 <= 12.675","7 <= 10.65205","1 != 338.0","1 != 220.0","6 != 558.0","7 <= 8.08125","7 > 7.987500000000001"] |
   
   ```sql
   SELECT
     t.passengerid,
     decision_path(m.model_id, m.model, t.features, '-classification -no_verbose -no_leaf') 
   FROM
     model_rf m
     LEFT OUTER JOIN -- CROSS JOIN
     test_rf t
   limit 3;
   ```
   
   |passengerid|path|
   |:-|:-|
   | 892  |  ["2 = 0.0","0 = 3.0","1 != 107.0","7 <= 7.9104","1 != 828.0","1 != 391.0"] |
   | 1309 |  ["2 = 0.0","0 = 3.0","1 != 107.0","7 > 12.675","1 != 429.0","1 != 65.0","6 != 481.0","6 != 251.0"] |
   | 1308 |  ["2 = 0.0","0 = 3.0","1 != 107.0","7 > 7.987500000000001","1 != 429.0","1 != 65.0","7 <= 8.08125","1 != 338.0","1 != 220.0","6 != 558.0"] |
   
   ```sql
   SELECT
     t.passengerid,
     decision_path(m.model_id, m.model, t.features, '-classification -no_verbose -no_leaf', array('pclass','name','sex','age','sibsp','parch','ticket','fare','cabin','embarked'), array('no','yes')) 
   FROM
     model_rf m
     LEFT OUTER JOIN -- CROSS JOIN
     test_rf t
   limit 3;
   ```
   
   |passengerid|path|
   |:-|:-|
   | 892  |  ["sex = 0.0","pclass = 3.0","name != 107.0","fare <= 7.9104","name != 828.0","name != 391.0"] |
   | 1309 |  ["sex = 0.0","pclass = 3.0","name != 107.0","fare > 12.675","name != 429.0","name != 65.0","ticket != 481.0","ticket != 251.0"] |
   | 1308 |  ["sex = 0.0","pclass = 3.0","name != 107.0","fare > 7.987500000000001","name != 429.0","name != 65.0","fare <= 8.08125","name != 338.0","name != 220.0","ticket != 558.0"] |
   
   ## Show frequently appearing branches
   
   ```sql
   WITH tmp as (
     SELECT
       decision_path(m.model_id, m.model, t.features, '-classification -no_verbose -no_leaf', array('pclass','name','sex','age','sibsp','parch','ticket','fare','cabin','embarked')) as path
     FROM
       model_rf m
       LEFT OUTER JOIN -- CROSS JOIN
       test_rf t
   )
   select
     r.branch,
     count(1) as cnt
   from
     tmp l
     LATERAL VIEW explode(l.path) r as branch
   group by
     r.branch
   order by
     cnt desc
   limit 100;
   ```
   
   |branch|cnt|
   |:-|:-|
   | sex = 0.0       | 112782 |
   | sex != 0.0      | 68905 |
   | pclass = 3.0    | 55387 |
   | pclass != 3.0   | 53513 |
   | embarked != 1.0 | 42619 |
   | parch <= 0.5    | 33110 |
   | pclass != 1.0   | 31776 |
   | ticket != 71.0  | 27084 |
   | ...             | ...  |

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services