You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@hivemall.apache.org by my...@apache.org on 2018/02/08 08:37:04 UTC

incubator-hivemall git commit: [HIVEMALL-172] Change tree_predict 3rd argument to accept string options

Repository: incubator-hivemall
Updated Branches:
  refs/heads/v0.5.0 2958af0af -> c742ce58e


[HIVEMALL-172] Change tree_predict 3rd argument to accept string options


Project: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/commit/c742ce58
Tree: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/tree/c742ce58
Diff: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/diff/c742ce58

Branch: refs/heads/v0.5.0
Commit: c742ce58e94913bf446c3b296a24415676f9ac3b
Parents: 2958af0
Author: Makoto Yui <my...@apache.org>
Authored: Thu Feb 8 17:36:50 2018 +0900
Committer: Makoto Yui <my...@apache.org>
Committed: Thu Feb 8 17:36:50 2018 +0900

----------------------------------------------------------------------
 .../hivemall/smile/tools/TreePredictUDF.java    | 63 ++++++++++++++------
 docs/gitbook/binaryclass/news20_rf.md           |  5 +-
 docs/gitbook/binaryclass/titanic_rf.md          | 10 ++--
 docs/gitbook/multiclass/iris_randomforest.md    |  8 ++-
 4 files changed, 60 insertions(+), 26 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/c742ce58/core/src/main/java/hivemall/smile/tools/TreePredictUDF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/smile/tools/TreePredictUDF.java b/core/src/main/java/hivemall/smile/tools/TreePredictUDF.java
index 46b8758..ea3bc29 100644
--- a/core/src/main/java/hivemall/smile/tools/TreePredictUDF.java
+++ b/core/src/main/java/hivemall/smile/tools/TreePredictUDF.java
@@ -18,6 +18,7 @@
  */
 package hivemall.smile.tools;
 
+import hivemall.UDFWithOptions;
 import hivemall.math.vector.DenseVector;
 import hivemall.math.vector.SparseVector;
 import hivemall.math.vector.Vector;
@@ -37,11 +38,12 @@ import java.util.List;
 import javax.annotation.Nonnull;
 import javax.annotation.Nullable;
 
+import org.apache.commons.cli.CommandLine;
+import org.apache.commons.cli.Options;
 import org.apache.hadoop.hive.ql.exec.Description;
 import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
 import org.apache.hadoop.hive.ql.metadata.HiveException;
 import org.apache.hadoop.hive.ql.udf.UDFType;
-import org.apache.hadoop.hive.ql.udf.generic.GenericUDF;
 import org.apache.hadoop.hive.serde2.io.DoubleWritable;
 import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
 import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
@@ -53,12 +55,12 @@ import org.apache.hadoop.hive.serde2.objectinspector.primitive.StringObjectInspe
 import org.apache.hadoop.io.IntWritable;
 import org.apache.hadoop.io.Text;
 
-@Description(
-        name = "tree_predict",
-        value = "_FUNC_(string modelId, string model, array<double|string> features [, const boolean classification])"
-                + " - Returns a prediction result of a random forest")
+@Description(name = "tree_predict",
+        value = "_FUNC_(string modelId, string model, array<double|string> features [, const string options | const boolean classification=false])"
+                + " - Returns a prediction result of a random forest"
+                + " in <int value, array<double> posteriori> for classification and <double> for regression")
 @UDFType(deterministic = true, stateful = false)
-public final class TreePredictUDF extends GenericUDF {
+public final class TreePredictUDF extends UDFWithOptions {
 
     private boolean classification;
     private StringObjectInspector modelOI;
@@ -72,9 +74,25 @@ public final class TreePredictUDF extends GenericUDF {
     private transient Evaluator evaluator;
 
     @Override
+    protected Options getOptions() {
+        Options opts = new Options();
+        opts.addOption("c", "classification", false,
+            "Predict as classification [default: not enabled]");
+        return opts;
+    }
+
+    @Override
+    protected CommandLine processOptions(@Nonnull String optionValue) throws UDFArgumentException {
+        CommandLine cl = parseOptions(optionValue);
+
+        this.classification = cl.hasOption("classification");
+        return cl;
+    }
+
+    @Override
     public ObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException {
         if (argOIs.length != 3 && argOIs.length != 4) {
-            throw new UDFArgumentException("_FUNC_ takes 3 or 4 arguments");
+            throw new UDFArgumentException("tree_predict takes 3 or 4 arguments");
         }
 
         this.modelOI = HiveUtils.asStringOI(argOIs[1]);
@@ -89,15 +107,25 @@ public final class TreePredictUDF extends GenericUDF {
             this.denseInput = false;
         } else {
             throw new UDFArgumentException(
-                "_FUNC_ takes array<double> or array<string> for the second argument: "
+                "tree_predict takes array<double> or array<string> for the second argument: "
                         + listOI.getTypeName());
         }
 
-        boolean classification = false;
         if (argOIs.length == 4) {
-            classification = HiveUtils.getConstBoolean(argOIs[3]);
+            ObjectInspector argOI3 = argOIs[3];
+            if (HiveUtils.isConstBoolean(argOI3)) {
+                this.classification = HiveUtils.getConstBoolean(argOI3);
+            } else if (HiveUtils.isConstString(argOI3)) {
+                String opts = HiveUtils.getConstString(argOI3);
+                processOptions(opts);
+            } else {
+                throw new UDFArgumentException(
+                    "tree_predict expects <const boolean> or <const string> for the fourth argument: "
+                            + argOI3.getTypeName());
+            }
+        } else {
+            this.classification = false;
         }
-        this.classification = classification;
 
         if (classification) {
             List<String> fieldNames = new ArrayList<String>(2);
@@ -105,7 +133,8 @@ public final class TreePredictUDF extends GenericUDF {
             fieldNames.add("value");
             fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);
             fieldNames.add("posteriori");
-            fieldOIs.add(ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector));
+            fieldOIs.add(ObjectInspectorFactory.getStandardListObjectInspector(
+                PrimitiveObjectInspectorFactory.writableDoubleObjectInspector));
             return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
         } else {
             return PrimitiveObjectInspectorFactory.writableDoubleObjectInspector;
@@ -116,7 +145,7 @@ public final class TreePredictUDF extends GenericUDF {
     public Object evaluate(@Nonnull DeferredObject[] arguments) throws HiveException {
         Object arg0 = arguments[0].get();
         if (arg0 == null) {
-            throw new HiveException("ModelId was null");
+            throw new HiveException("modelId should not be null");
         }
         // Not using string OI for backward compatibilities
         String modelId = arg0.toString();
@@ -134,8 +163,8 @@ public final class TreePredictUDF extends GenericUDF {
         this.featuresProbe = parseFeatures(arg2, featuresProbe);
 
         if (evaluator == null) {
-            this.evaluator = classification ? new ClassificationEvaluator()
-                    : new RegressionEvaluator();
+            this.evaluator =
+                    classification ? new ClassificationEvaluator() : new RegressionEvaluator();
         }
         return evaluator.evaluate(modelId, model, featuresProbe);
     }
@@ -192,8 +221,8 @@ public final class TreePredictUDF extends GenericUDF {
                 }
 
                 if (feature.indexOf(':') != -1) {
-                    throw new UDFArgumentException("Invaliad feature format `<index>:<value>`: "
-                            + col);
+                    throw new UDFArgumentException(
+                        "Invaliad feature format `<index>:<value>`: " + col);
                 }
 
                 final int colIndex = Integer.parseInt(feature);

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/c742ce58/docs/gitbook/binaryclass/news20_rf.md
----------------------------------------------------------------------
diff --git a/docs/gitbook/binaryclass/news20_rf.md b/docs/gitbook/binaryclass/news20_rf.md
index fd0b475..327939b 100644
--- a/docs/gitbook/binaryclass/news20_rf.md
+++ b/docs/gitbook/binaryclass/news20_rf.md
@@ -47,7 +47,7 @@ from
 ## Prediction
 
 ```sql
-SET hivevar:classification=true;
+-- SET hivevar:classification=true;
 
 drop table rf_predicted;
 create table rf_predicted
@@ -60,7 +60,8 @@ FROM (
   SELECT
     rowid, 
     m.model_weight,
-    tree_predict(m.model_id, m.model, t.features, ${classification}) as predicted
+    tree_predict(m.model_id, m.model, t.features, "-classification") as predicted
+    -- tree_predict(m.model_id, m.model, t.features, ${classification}) as predicted
   FROM
     rf_model m
     LEFT OUTER JOIN -- CROSS JOIN

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/c742ce58/docs/gitbook/binaryclass/titanic_rf.md
----------------------------------------------------------------------
diff --git a/docs/gitbook/binaryclass/titanic_rf.md b/docs/gitbook/binaryclass/titanic_rf.md
index 29784e0..2b54074 100644
--- a/docs/gitbook/binaryclass/titanic_rf.md
+++ b/docs/gitbook/binaryclass/titanic_rf.md
@@ -175,7 +175,7 @@ from
 # Prediction
 
 ```sql
-SET hivevar:classification=true;
+-- SET hivevar:classification=true;
 set hive.auto.convert.join=true;
 SET hive.mapjoin.optimized.hashtable=false;
 SET mapred.reduce.tasks=16;
@@ -202,7 +202,8 @@ FROM (
       -- tree_predict(p.model_id, p.model_type, p.pred_model, t.features, ${classification}) as predicted
       -- hivemall v0.5-rc.1 or later
       p.model_weight,
-      tree_predict(p.model_id, p.model, t.features, ${classification}) as predicted
+	  tree_predict(p.model_id, p.model, t.features, "-classification") as predicted
+	  -- tree_predict(p.model_id, p.model, t.features, ${classification}) as predicted
       -- tree_predict_v1(p.model_id, p.model_type, p.pred_model, t.features, ${classification}) as predicted -- to use the old model in v0.5-rc.1 or later
     FROM (
       SELECT 
@@ -319,7 +320,7 @@ from
 > [116.12055542977338,960.8569891444097,291.08765260103837,469.74671636586226,163.721292772701,120.784769882858,847.9769298113661,554.4617571355476,346.3500941757221,97.42593940113392]    0.1838351822503962
 
 ```sql
-SET hivevar:classification=true;
+-- SET hivevar:classification=true;
 SET hive.mapjoin.optimized.hashtable=false;
 SET mapred.reduce.tasks=16;
 
@@ -345,7 +346,8 @@ FROM (
       -- tree_predict(p.model_id, p.model_type, p.pred_model, t.features, ${classification}) as predicted
       -- hivemall v0.5-rc.1 or later
       p.model_weight,
-      tree_predict(p.model_id, p.model, t.features, ${classification}) as predicted
+      tree_predict(p.model_id, p.model, t.features, "-classification") as predicted
+      -- tree_predict(p.model_id, p.model, t.features, ${classification}) as predicted
       -- tree_predict_v1(p.model_id, p.model_type, p.pred_model, t.features, ${classification}) as predicted -- to use the old model in v0.5-rc.1 or later
     FROM (
       SELECT 

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/c742ce58/docs/gitbook/multiclass/iris_randomforest.md
----------------------------------------------------------------------
diff --git a/docs/gitbook/multiclass/iris_randomforest.md b/docs/gitbook/multiclass/iris_randomforest.md
index b421297..bfc197f 100644
--- a/docs/gitbook/multiclass/iris_randomforest.md
+++ b/docs/gitbook/multiclass/iris_randomforest.md
@@ -206,7 +206,7 @@ from
 # Prediction
 
 ```sql
-set hivevar:classification=true;
+-- set hivevar:classification=true;
 set hive.auto.convert.join=true;
 set hive.mapjoin.optimized.hashtable=false;
 
@@ -225,7 +225,8 @@ FROM (
     -- tree_predict(p.model_id, p.model_type, p.pred_model, t.features, ${classification}) as predicted
     -- hivemall v0.5-rc.1 or later
     p.model_weight,
-    tree_predict(p.model_id, p.model, t.features, ${classification}) as predicted
+    tree_predict(p.model_id, p.model, t.features, "-classification") as predicted
+    -- tree_predict(p.model_id, p.model, t.features, ${classification}) as predicted
     -- tree_predict_v1(p.model_id, p.model_type, p.pred_model, t.features, ${classification}) as predicted -- to use the old model in v0.5-rc.1 or later
   FROM
     model p
@@ -265,7 +266,8 @@ FROM (
     -- tree_predict(p.model_id, p.model_type, p.pred_model, t.features, ${classification}) as predicted
     -- hivemall v0.5-rc.1 or later
     p.model_weight,
-    tree_predict(p.model_id, p.model, t.features, ${classification}) as predicted
+    tree_predict(p.model_id, p.model, t.features, "-classification") as predicted
+    -- tree_predict(p.model_id, p.model, t.features, ${classification}) as predicted
     -- tree_predict_v1(p.model_id, p.model_type, p.pred_model, t.features, ${classification}) as predicted as predicted -- to use the old model in v0.5-rc.1 or later
   FROM (
     SELECT