You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@hivemall.apache.org by my...@apache.org on 2018/02/08 08:37:04 UTC
incubator-hivemall git commit: [HIVEMALL-172] Change tree_predict 3rd argument to accept string options
Repository: incubator-hivemall
Updated Branches:
refs/heads/v0.5.0 2958af0af -> c742ce58e
[HIVEMALL-172] Change tree_predict 3rd argument to accept string options
Project: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/commit/c742ce58
Tree: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/tree/c742ce58
Diff: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/diff/c742ce58
Branch: refs/heads/v0.5.0
Commit: c742ce58e94913bf446c3b296a24415676f9ac3b
Parents: 2958af0
Author: Makoto Yui <my...@apache.org>
Authored: Thu Feb 8 17:36:50 2018 +0900
Committer: Makoto Yui <my...@apache.org>
Committed: Thu Feb 8 17:36:50 2018 +0900
----------------------------------------------------------------------
.../hivemall/smile/tools/TreePredictUDF.java | 63 ++++++++++++++------
docs/gitbook/binaryclass/news20_rf.md | 5 +-
docs/gitbook/binaryclass/titanic_rf.md | 10 ++--
docs/gitbook/multiclass/iris_randomforest.md | 8 ++-
4 files changed, 60 insertions(+), 26 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/c742ce58/core/src/main/java/hivemall/smile/tools/TreePredictUDF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/smile/tools/TreePredictUDF.java b/core/src/main/java/hivemall/smile/tools/TreePredictUDF.java
index 46b8758..ea3bc29 100644
--- a/core/src/main/java/hivemall/smile/tools/TreePredictUDF.java
+++ b/core/src/main/java/hivemall/smile/tools/TreePredictUDF.java
@@ -18,6 +18,7 @@
*/
package hivemall.smile.tools;
+import hivemall.UDFWithOptions;
import hivemall.math.vector.DenseVector;
import hivemall.math.vector.SparseVector;
import hivemall.math.vector.Vector;
@@ -37,11 +38,12 @@ import java.util.List;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
+import org.apache.commons.cli.CommandLine;
+import org.apache.commons.cli.Options;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.udf.UDFType;
-import org.apache.hadoop.hive.ql.udf.generic.GenericUDF;
import org.apache.hadoop.hive.serde2.io.DoubleWritable;
import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
@@ -53,12 +55,12 @@ import org.apache.hadoop.hive.serde2.objectinspector.primitive.StringObjectInspe
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.Text;
-@Description(
- name = "tree_predict",
- value = "_FUNC_(string modelId, string model, array<double|string> features [, const boolean classification])"
- + " - Returns a prediction result of a random forest")
+@Description(name = "tree_predict",
+ value = "_FUNC_(string modelId, string model, array<double|string> features [, const string options | const boolean classification=false])"
+ + " - Returns a prediction result of a random forest"
+ + " in <int value, array<double> posteriori> for classification and <double> for regression")
@UDFType(deterministic = true, stateful = false)
-public final class TreePredictUDF extends GenericUDF {
+public final class TreePredictUDF extends UDFWithOptions {
private boolean classification;
private StringObjectInspector modelOI;
@@ -72,9 +74,25 @@ public final class TreePredictUDF extends GenericUDF {
private transient Evaluator evaluator;
@Override
+ protected Options getOptions() {
+ Options opts = new Options();
+ opts.addOption("c", "classification", false,
+ "Predict as classification [default: not enabled]");
+ return opts;
+ }
+
+ @Override
+ protected CommandLine processOptions(@Nonnull String optionValue) throws UDFArgumentException {
+ CommandLine cl = parseOptions(optionValue);
+
+ this.classification = cl.hasOption("classification");
+ return cl;
+ }
+
+ @Override
public ObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException {
if (argOIs.length != 3 && argOIs.length != 4) {
- throw new UDFArgumentException("_FUNC_ takes 3 or 4 arguments");
+ throw new UDFArgumentException("tree_predict takes 3 or 4 arguments");
}
this.modelOI = HiveUtils.asStringOI(argOIs[1]);
@@ -89,15 +107,25 @@ public final class TreePredictUDF extends GenericUDF {
this.denseInput = false;
} else {
throw new UDFArgumentException(
- "_FUNC_ takes array<double> or array<string> for the second argument: "
+ "tree_predict takes array<double> or array<string> for the second argument: "
+ listOI.getTypeName());
}
- boolean classification = false;
if (argOIs.length == 4) {
- classification = HiveUtils.getConstBoolean(argOIs[3]);
+ ObjectInspector argOI3 = argOIs[3];
+ if (HiveUtils.isConstBoolean(argOI3)) {
+ this.classification = HiveUtils.getConstBoolean(argOI3);
+ } else if (HiveUtils.isConstString(argOI3)) {
+ String opts = HiveUtils.getConstString(argOI3);
+ processOptions(opts);
+ } else {
+ throw new UDFArgumentException(
+ "tree_predict expects <const boolean> or <const string> for the fourth argument: "
+ + argOI3.getTypeName());
+ }
+ } else {
+ this.classification = false;
}
- this.classification = classification;
if (classification) {
List<String> fieldNames = new ArrayList<String>(2);
@@ -105,7 +133,8 @@ public final class TreePredictUDF extends GenericUDF {
fieldNames.add("value");
fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);
fieldNames.add("posteriori");
- fieldOIs.add(ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector));
+ fieldOIs.add(ObjectInspectorFactory.getStandardListObjectInspector(
+ PrimitiveObjectInspectorFactory.writableDoubleObjectInspector));
return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
} else {
return PrimitiveObjectInspectorFactory.writableDoubleObjectInspector;
@@ -116,7 +145,7 @@ public final class TreePredictUDF extends GenericUDF {
public Object evaluate(@Nonnull DeferredObject[] arguments) throws HiveException {
Object arg0 = arguments[0].get();
if (arg0 == null) {
- throw new HiveException("ModelId was null");
+ throw new HiveException("modelId should not be null");
}
// Not using string OI for backward compatibilities
String modelId = arg0.toString();
@@ -134,8 +163,8 @@ public final class TreePredictUDF extends GenericUDF {
this.featuresProbe = parseFeatures(arg2, featuresProbe);
if (evaluator == null) {
- this.evaluator = classification ? new ClassificationEvaluator()
- : new RegressionEvaluator();
+ this.evaluator =
+ classification ? new ClassificationEvaluator() : new RegressionEvaluator();
}
return evaluator.evaluate(modelId, model, featuresProbe);
}
@@ -192,8 +221,8 @@ public final class TreePredictUDF extends GenericUDF {
}
if (feature.indexOf(':') != -1) {
- throw new UDFArgumentException("Invaliad feature format `<index>:<value>`: "
- + col);
+ throw new UDFArgumentException(
+ "Invaliad feature format `<index>:<value>`: " + col);
}
final int colIndex = Integer.parseInt(feature);
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/c742ce58/docs/gitbook/binaryclass/news20_rf.md
----------------------------------------------------------------------
diff --git a/docs/gitbook/binaryclass/news20_rf.md b/docs/gitbook/binaryclass/news20_rf.md
index fd0b475..327939b 100644
--- a/docs/gitbook/binaryclass/news20_rf.md
+++ b/docs/gitbook/binaryclass/news20_rf.md
@@ -47,7 +47,7 @@ from
## Prediction
```sql
-SET hivevar:classification=true;
+-- SET hivevar:classification=true;
drop table rf_predicted;
create table rf_predicted
@@ -60,7 +60,8 @@ FROM (
SELECT
rowid,
m.model_weight,
- tree_predict(m.model_id, m.model, t.features, ${classification}) as predicted
+ tree_predict(m.model_id, m.model, t.features, "-classification") as predicted
+ -- tree_predict(m.model_id, m.model, t.features, ${classification}) as predicted
FROM
rf_model m
LEFT OUTER JOIN -- CROSS JOIN
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/c742ce58/docs/gitbook/binaryclass/titanic_rf.md
----------------------------------------------------------------------
diff --git a/docs/gitbook/binaryclass/titanic_rf.md b/docs/gitbook/binaryclass/titanic_rf.md
index 29784e0..2b54074 100644
--- a/docs/gitbook/binaryclass/titanic_rf.md
+++ b/docs/gitbook/binaryclass/titanic_rf.md
@@ -175,7 +175,7 @@ from
# Prediction
```sql
-SET hivevar:classification=true;
+-- SET hivevar:classification=true;
set hive.auto.convert.join=true;
SET hive.mapjoin.optimized.hashtable=false;
SET mapred.reduce.tasks=16;
@@ -202,7 +202,8 @@ FROM (
-- tree_predict(p.model_id, p.model_type, p.pred_model, t.features, ${classification}) as predicted
-- hivemall v0.5-rc.1 or later
p.model_weight,
- tree_predict(p.model_id, p.model, t.features, ${classification}) as predicted
+ tree_predict(p.model_id, p.model, t.features, "-classification") as predicted
+ -- tree_predict(p.model_id, p.model, t.features, ${classification}) as predicted
-- tree_predict_v1(p.model_id, p.model_type, p.pred_model, t.features, ${classification}) as predicted -- to use the old model in v0.5-rc.1 or later
FROM (
SELECT
@@ -319,7 +320,7 @@ from
> [116.12055542977338,960.8569891444097,291.08765260103837,469.74671636586226,163.721292772701,120.784769882858,847.9769298113661,554.4617571355476,346.3500941757221,97.42593940113392] 0.1838351822503962
```sql
-SET hivevar:classification=true;
+-- SET hivevar:classification=true;
SET hive.mapjoin.optimized.hashtable=false;
SET mapred.reduce.tasks=16;
@@ -345,7 +346,8 @@ FROM (
-- tree_predict(p.model_id, p.model_type, p.pred_model, t.features, ${classification}) as predicted
-- hivemall v0.5-rc.1 or later
p.model_weight,
- tree_predict(p.model_id, p.model, t.features, ${classification}) as predicted
+ tree_predict(p.model_id, p.model, t.features, "-classification") as predicted
+ -- tree_predict(p.model_id, p.model, t.features, ${classification}) as predicted
-- tree_predict_v1(p.model_id, p.model_type, p.pred_model, t.features, ${classification}) as predicted -- to use the old model in v0.5-rc.1 or later
FROM (
SELECT
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/c742ce58/docs/gitbook/multiclass/iris_randomforest.md
----------------------------------------------------------------------
diff --git a/docs/gitbook/multiclass/iris_randomforest.md b/docs/gitbook/multiclass/iris_randomforest.md
index b421297..bfc197f 100644
--- a/docs/gitbook/multiclass/iris_randomforest.md
+++ b/docs/gitbook/multiclass/iris_randomforest.md
@@ -206,7 +206,7 @@ from
# Prediction
```sql
-set hivevar:classification=true;
+-- set hivevar:classification=true;
set hive.auto.convert.join=true;
set hive.mapjoin.optimized.hashtable=false;
@@ -225,7 +225,8 @@ FROM (
-- tree_predict(p.model_id, p.model_type, p.pred_model, t.features, ${classification}) as predicted
-- hivemall v0.5-rc.1 or later
p.model_weight,
- tree_predict(p.model_id, p.model, t.features, ${classification}) as predicted
+ tree_predict(p.model_id, p.model, t.features, "-classification") as predicted
+ -- tree_predict(p.model_id, p.model, t.features, ${classification}) as predicted
-- tree_predict_v1(p.model_id, p.model_type, p.pred_model, t.features, ${classification}) as predicted -- to use the old model in v0.5-rc.1 or later
FROM
model p
@@ -265,7 +266,8 @@ FROM (
-- tree_predict(p.model_id, p.model_type, p.pred_model, t.features, ${classification}) as predicted
-- hivemall v0.5-rc.1 or later
p.model_weight,
- tree_predict(p.model_id, p.model, t.features, ${classification}) as predicted
+ tree_predict(p.model_id, p.model, t.features, "-classification") as predicted
+ -- tree_predict(p.model_id, p.model, t.features, ${classification}) as predicted
-- tree_predict_v1(p.model_id, p.model_type, p.pred_model, t.features, ${classification}) as predicted as predicted -- to use the old model in v0.5-rc.1 or later
FROM (
SELECT