You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@hivemall.apache.org by my...@apache.org on 2017/06/30 12:16:18 UTC
[2/2] incubator-hivemall git commit: Close #70: [HIVEMALL-75-2] Add
tree_export UDF and update RandomForest tutorial
Close #70: [HIVEMALL-75-2] Add tree_export UDF and update RandomForest tutorial
Project: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/commit/9876d063
Tree: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/tree/9876d063
Diff: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/diff/9876d063
Branch: refs/heads/master
Commit: 9876d06316ad6e4ef5b62511c0806d1c3d8c03ce
Parents: 9f01ebf
Author: Makoto Yui <my...@apache.org>
Authored: Fri Jun 30 21:15:54 2017 +0900
Committer: Makoto Yui <my...@apache.org>
Committed: Fri Jun 30 21:15:54 2017 +0900
----------------------------------------------------------------------
.../smile/classification/DecisionTree.java | 145 ++++++--
.../GradientTreeBoostingClassifierUDTF.java | 2 +-
.../RandomForestClassifierUDTF.java | 2 +-
.../regression/RandomForestRegressionUDTF.java | 2 +-
.../smile/regression/RegressionTree.java | 124 +++++--
.../hivemall/smile/tools/TreeExportUDF.java | 241 +++++++++++++
.../hivemall/smile/tools/TreePredictUDF.java | 4 +-
.../hivemall/smile/utils/SmileExtUtils.java | 41 +++
.../java/hivemall/utils/hadoop/HiveUtils.java | 19 ++
.../smile/classification/DecisionTreeTest.java | 79 ++++-
.../RandomForestClassifierUDTFTest.java | 4 +-
.../smile/regression/RegressionTreeTest.java | 60 +++-
.../smile/tools/TreePredictUDFTest.java | 4 +-
docs/gitbook/SUMMARY.md | 3 +-
docs/gitbook/binaryclass/news20_rf.md | 90 +++++
docs/gitbook/binaryclass/titanic_rf.md | 93 +++++-
docs/gitbook/ft_engineering/hashing.md | 45 ++-
docs/gitbook/multiclass/iris_dataset.md | 65 +---
docs/gitbook/multiclass/iris_randomforest.md | 259 ++++++++------
docs/gitbook/multiclass/iris_scw.md | 334 +++----------------
docs/gitbook/resources/images/iris.png | Bin 0 -> 92872 bytes
.../ddl/define-all-as-permanent.deprecated.hive | 6 -
resources/ddl/define-all-as-permanent.hive | 3 +
resources/ddl/define-all.deprecated.hive | 6 -
resources/ddl/define-all.hive | 3 +
resources/ddl/define-all.spark | 3 +
resources/ddl/define-udfs.td.hql | 1 +
27 files changed, 1076 insertions(+), 562 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/9876d063/core/src/main/java/hivemall/smile/classification/DecisionTree.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/smile/classification/DecisionTree.java b/core/src/main/java/hivemall/smile/classification/DecisionTree.java
index 2d086b9..fa97dba 100644
--- a/core/src/main/java/hivemall/smile/classification/DecisionTree.java
+++ b/core/src/main/java/hivemall/smile/classification/DecisionTree.java
@@ -33,6 +33,8 @@
*/
package hivemall.smile.classification;
+import static hivemall.smile.utils.SmileExtUtils.resolveFeatureName;
+import static hivemall.smile.utils.SmileExtUtils.resolveName;
import hivemall.annotations.VisibleForTesting;
import hivemall.math.matrix.Matrix;
import hivemall.math.matrix.ints.ColumnMajorIntMatrix;
@@ -47,6 +49,7 @@ import hivemall.smile.data.Attribute.AttributeType;
import hivemall.smile.utils.SmileExtUtils;
import hivemall.utils.collections.lists.IntArrayList;
import hivemall.utils.lang.ObjectUtils;
+import hivemall.utils.lang.mutable.MutableInt;
import hivemall.utils.sampling.IntReservoirSampler;
import java.io.Externalizable;
@@ -292,41 +295,114 @@ public final class DecisionTree implements Classifier<Vector> {
}
}
- public void jsCodegen(@Nonnull final StringBuilder builder, final int depth) {
+ public void exportJavascript(@Nonnull final StringBuilder builder,
+ @Nullable final String[] featureNames, @Nullable final String[] classNames,
+ final int depth) {
if (trueChild == null && falseChild == null) {
indent(builder, depth);
- builder.append("").append(output).append(";\n");
+ builder.append("").append(resolveName(output, classNames)).append(";\n");
} else {
+ indent(builder, depth);
if (splitFeatureType == AttributeType.NOMINAL) {
- indent(builder, depth);
- builder.append("if(x[")
- .append(splitFeature)
- .append("] == ")
- .append(splitValue)
- .append(") {\n");
- trueChild.jsCodegen(builder, depth + 1);
- indent(builder, depth);
- builder.append("} else {\n");
- falseChild.jsCodegen(builder, depth + 1);
- indent(builder, depth);
- builder.append("}\n");
+ if (featureNames == null) {
+ builder.append("if( x[")
+ .append(splitFeature)
+ .append("] == ")
+ .append(splitValue)
+ .append(" ) {\n");
+ } else {
+ builder.append("if( ")
+ .append(resolveFeatureName(splitFeature, featureNames))
+ .append(" == ")
+ .append(splitValue)
+ .append(" ) {\n");
+ }
} else if (splitFeatureType == AttributeType.NUMERIC) {
- indent(builder, depth);
- builder.append("if(x[")
- .append(splitFeature)
- .append("] <= ")
- .append(splitValue)
- .append(") {\n");
- trueChild.jsCodegen(builder, depth + 1);
- indent(builder, depth);
- builder.append("} else {\n");
- falseChild.jsCodegen(builder, depth + 1);
- indent(builder, depth);
- builder.append("}\n");
+ if (featureNames == null) {
+ builder.append("if( x[")
+ .append(splitFeature)
+ .append("] <= ")
+ .append(splitValue)
+ .append(" ) {\n");
+ } else {
+ builder.append("if( ")
+ .append(resolveFeatureName(splitFeature, featureNames))
+ .append(" <= ")
+ .append(splitValue)
+ .append(" ) {\n");
+ }
} else {
throw new IllegalStateException("Unsupported attribute type: "
+ splitFeatureType);
}
+ trueChild.exportJavascript(builder, featureNames, classNames, depth + 1);
+ indent(builder, depth);
+ builder.append("} else {\n");
+ falseChild.exportJavascript(builder, featureNames, classNames, depth + 1);
+ indent(builder, depth);
+ builder.append("}\n");
+ }
+ }
+
+ public void exportGraphviz(@Nonnull final StringBuilder builder,
+ @Nullable final String[] featureNames, @Nullable final String[] classNames,
+ @Nonnull final String outputName, @Nullable double[] colorBrew,
+ final @Nonnull MutableInt nodeIdGenerator, final int parentNodeId) {
+ final int myNodeId = nodeIdGenerator.getValue();
+
+ if (trueChild == null && falseChild == null) {
+ // fillcolor=h,s,v
+ // https://en.wikipedia.org/wiki/HSL_and_HSV
+ // http://www.graphviz.org/doc/info/attrs.html#k:colorList
+ String hsvColor = (colorBrew == null || output >= colorBrew.length) ? "#00000000"
+ : String.format("%.4f,1.000,1.000", colorBrew[output]);
+ builder.append(String.format(
+ " %d [label=<%s = %s>, fillcolor=\"%s\", shape=ellipse];\n", myNodeId,
+ outputName, resolveName(output, classNames), hsvColor));
+
+ if (myNodeId != parentNodeId) {
+ builder.append(' ').append(parentNodeId).append(" -> ").append(myNodeId);
+ if (parentNodeId == 0) {
+ if (myNodeId == 1) {
+ builder.append(" [labeldistance=2.5, labelangle=45, headlabel=\"True\"]");
+ } else {
+ builder.append(" [labeldistance=2.5, labelangle=-45, headlabel=\"False\"]");
+ }
+ }
+ builder.append(";\n");
+ }
+ } else {
+ if (splitFeatureType == AttributeType.NOMINAL) {
+ builder.append(String.format(
+ " %d [label=<%s = %s>, fillcolor=\"#00000000\"];\n", myNodeId,
+ resolveFeatureName(splitFeature, featureNames), Double.toString(splitValue)));
+ } else if (splitFeatureType == AttributeType.NUMERIC) {
+ builder.append(String.format(
+ " %d [label=<%s ≤ %s>, fillcolor=\"#00000000\"];\n", myNodeId,
+ resolveFeatureName(splitFeature, featureNames), Double.toString(splitValue)));
+ } else {
+ throw new IllegalStateException("Unsupported attribute type: "
+ + splitFeatureType);
+ }
+
+ if (myNodeId != parentNodeId) {
+ builder.append(' ').append(parentNodeId).append(" -> ").append(myNodeId);
+ if (parentNodeId == 0) {//only draw edge label on top
+ if (myNodeId == 1) {
+ builder.append(" [labeldistance=2.5, labelangle=45, headlabel=\"True\"]");
+ } else {
+ builder.append(" [labeldistance=2.5, labelangle=-45, headlabel=\"False\"]");
+ }
+ }
+ builder.append(";\n");
+ }
+
+ nodeIdGenerator.addValue(1);
+ trueChild.exportGraphviz(builder, featureNames, classNames, outputName, colorBrew,
+ nodeIdGenerator, myNodeId);
+ nodeIdGenerator.addValue(1);
+ falseChild.exportGraphviz(builder, featureNames, classNames, outputName, colorBrew,
+ nodeIdGenerator, myNodeId);
}
}
@@ -910,6 +986,11 @@ public final class DecisionTree implements Classifier<Vector> {
}
}
+ @VisibleForTesting
+ Node getRootNode() {
+ return _root;
+ }
+
private static void checkArgument(@Nonnull Matrix x, @Nonnull int[] y, int numVars,
int maxDepth, int maxLeafs, int minSplits, int minLeafSize) {
if (x.numRows() != y.length) {
@@ -965,14 +1046,15 @@ public final class DecisionTree implements Classifier<Vector> {
throw new UnsupportedOperationException("Not supported.");
}
- public String predictJsCodegen() {
+ public String predictJsCodegen(@Nonnull final String[] featureNames,
+ @Nonnull final String[] classNames) {
StringBuilder buf = new StringBuilder(1024);
- _root.jsCodegen(buf, 0);
+ _root.exportJavascript(buf, featureNames, classNames, 0);
return buf.toString();
}
@Nonnull
- public byte[] predictSerCodegen(boolean compress) throws HiveException {
+ public byte[] serialize(boolean compress) throws HiveException {
try {
if (compress) {
return ObjectUtils.toCompressedBytes(_root);
@@ -986,7 +1068,8 @@ public final class DecisionTree implements Classifier<Vector> {
}
}
- public static Node deserializeNode(final byte[] serializedObj, final int length,
+ @Nonnull
+ public static Node deserialize(@Nonnull final byte[] serializedObj, final int length,
final boolean compressed) throws HiveException {
final Node root = new Node();
try {
@@ -1006,7 +1089,7 @@ public final class DecisionTree implements Classifier<Vector> {
@Override
public String toString() {
- return _root == null ? "" : predictJsCodegen();
+ return _root == null ? "" : predictJsCodegen(null, null);
}
}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/9876d063/core/src/main/java/hivemall/smile/classification/GradientTreeBoostingClassifierUDTF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/smile/classification/GradientTreeBoostingClassifierUDTF.java b/core/src/main/java/hivemall/smile/classification/GradientTreeBoostingClassifierUDTF.java
index a380a11..adb405f 100644
--- a/core/src/main/java/hivemall/smile/classification/GradientTreeBoostingClassifierUDTF.java
+++ b/core/src/main/java/hivemall/smile/classification/GradientTreeBoostingClassifierUDTF.java
@@ -579,7 +579,7 @@ public final class GradientTreeBoostingClassifierUDTF extends UDTFWithOptions {
final int m = trees.length;
final Text[] models = new Text[m];
for (int i = 0; i < m; i++) {
- byte[] b = trees[i].predictSerCodegen(true);
+ byte[] b = trees[i].serialize(true);
b = Base91.encode(b);
models[i] = new Text(b);
}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/9876d063/core/src/main/java/hivemall/smile/classification/RandomForestClassifierUDTF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/smile/classification/RandomForestClassifierUDTF.java b/core/src/main/java/hivemall/smile/classification/RandomForestClassifierUDTF.java
index 5a831df..59f52d3 100644
--- a/core/src/main/java/hivemall/smile/classification/RandomForestClassifierUDTF.java
+++ b/core/src/main/java/hivemall/smile/classification/RandomForestClassifierUDTF.java
@@ -603,7 +603,7 @@ public final class RandomForestClassifierUDTF extends UDTFWithOptions {
@Nonnull
private static Text getModel(@Nonnull final DecisionTree tree) throws HiveException {
- byte[] b = tree.predictSerCodegen(true);
+ byte[] b = tree.serialize(true);
b = Base91.encode(b);
return new Text(b);
}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/9876d063/core/src/main/java/hivemall/smile/regression/RandomForestRegressionUDTF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/smile/regression/RandomForestRegressionUDTF.java b/core/src/main/java/hivemall/smile/regression/RandomForestRegressionUDTF.java
index 557df21..58151e4 100644
--- a/core/src/main/java/hivemall/smile/regression/RandomForestRegressionUDTF.java
+++ b/core/src/main/java/hivemall/smile/regression/RandomForestRegressionUDTF.java
@@ -499,7 +499,7 @@ public final class RandomForestRegressionUDTF extends UDTFWithOptions {
@Nonnull
private static Text getModel(@Nonnull final RegressionTree tree) throws HiveException {
- byte[] b = tree.predictSerCodegen(true);
+ byte[] b = tree.serialize(true);
b = Base91.encode(b);
return new Text(b);
}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/9876d063/core/src/main/java/hivemall/smile/regression/RegressionTree.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/smile/regression/RegressionTree.java b/core/src/main/java/hivemall/smile/regression/RegressionTree.java
index da7e80b..81b9ba8 100755
--- a/core/src/main/java/hivemall/smile/regression/RegressionTree.java
+++ b/core/src/main/java/hivemall/smile/regression/RegressionTree.java
@@ -33,6 +33,8 @@
*/
package hivemall.smile.regression;
+import static hivemall.smile.utils.SmileExtUtils.resolveFeatureName;
+import static hivemall.smile.utils.SmileExtUtils.resolveName;
import hivemall.annotations.VisibleForTesting;
import hivemall.math.matrix.Matrix;
import hivemall.math.matrix.ints.ColumnMajorIntMatrix;
@@ -48,6 +50,7 @@ import hivemall.utils.collections.lists.IntArrayList;
import hivemall.utils.collections.sets.IntArraySet;
import hivemall.utils.collections.sets.IntSet;
import hivemall.utils.lang.ObjectUtils;
+import hivemall.utils.lang.mutable.MutableInt;
import hivemall.utils.math.MathUtils;
import java.io.Externalizable;
@@ -246,35 +249,52 @@ public final class RegressionTree implements Regression<Vector> {
}
}
- public void jsCodegen(@Nonnull final StringBuilder builder, final int depth) {
+ public void exportJavascript(@Nonnull final StringBuilder builder,
+ @Nullable final String[] featureNames, final int depth) {
if (trueChild == null && falseChild == null) {
indent(builder, depth);
- builder.append("").append(output).append(";\n");
+ builder.append(output).append(";\n");
} else {
if (splitFeatureType == AttributeType.NOMINAL) {
indent(builder, depth);
- builder.append("if(x[")
- .append(splitFeature)
- .append("] == ")
- .append(splitValue)
- .append(") {\n");
- trueChild.jsCodegen(builder, depth + 1);
+ if (featureNames == null) {
+ builder.append("if( x[")
+ .append(splitFeature)
+ .append("] == ")
+ .append(splitValue)
+ .append(") {\n");
+ } else {
+ builder.append("if( ")
+ .append(resolveFeatureName(splitFeature, featureNames))
+ .append(" == ")
+ .append(splitValue)
+ .append(") {\n");
+ }
+ trueChild.exportJavascript(builder, featureNames, depth + 1);
indent(builder, depth);
builder.append("} else {\n");
- falseChild.jsCodegen(builder, depth + 1);
+ falseChild.exportJavascript(builder, featureNames, depth + 1);
indent(builder, depth);
builder.append("}\n");
} else if (splitFeatureType == AttributeType.NUMERIC) {
indent(builder, depth);
- builder.append("if(x[")
- .append(splitFeature)
- .append("] <= ")
- .append(splitValue)
- .append(") {\n");
- trueChild.jsCodegen(builder, depth + 1);
+ if (featureNames == null) {
+ builder.append("if( x[")
+ .append(splitFeature)
+ .append("] <= ")
+ .append(splitValue)
+ .append(") {\n");
+ } else {
+ builder.append("if( ")
+ .append(resolveFeatureName(splitFeature, featureNames))
+ .append(" <= ")
+ .append(splitValue)
+ .append(") {\n");
+ }
+ trueChild.exportJavascript(builder, featureNames, depth + 1);
indent(builder, depth);
- builder.append("} else {\n");
- falseChild.jsCodegen(builder, depth + 1);
+ builder.append("} else {\n");
+ falseChild.exportJavascript(builder, featureNames, depth + 1);
indent(builder, depth);
builder.append("}\n");
} else {
@@ -284,6 +304,63 @@ public final class RegressionTree implements Regression<Vector> {
}
}
+ public void exportGraphviz(@Nonnull final StringBuilder builder,
+ @Nullable final String[] featureNames, @Nonnull final String outputName,
+ final @Nonnull MutableInt nodeIdGenerator, final int parentNodeId) {
+ final int myNodeId = nodeIdGenerator.getValue();
+
+ if (trueChild == null && falseChild == null) {
+ builder.append(String.format(
+ " %d [label=<%s = %s>, fillcolor=\"#00000000\", shape=ellipse];\n", myNodeId,
+ outputName, Double.toString(output)));
+
+ if (myNodeId != parentNodeId) {
+ builder.append(' ').append(parentNodeId).append(" -> ").append(myNodeId);
+ if (parentNodeId == 0) {
+ if (myNodeId == 1) {
+ builder.append(" [labeldistance=2.5, labelangle=45, headlabel=\"True\"]");
+ } else {
+ builder.append(" [labeldistance=2.5, labelangle=-45, headlabel=\"False\"]");
+ }
+ }
+ builder.append(";\n");
+ }
+ } else {
+ if (splitFeatureType == AttributeType.NOMINAL) {
+ builder.append(String.format(
+ " %d [label=<%s = %s>, fillcolor=\"#00000000\"];\n", myNodeId,
+ resolveFeatureName(splitFeature, featureNames), Double.toString(splitValue)));
+ } else if (splitFeatureType == AttributeType.NUMERIC) {
+ builder.append(String.format(
+ " %d [label=<%s ≤ %s>, fillcolor=\"#00000000\"];\n", myNodeId,
+ resolveFeatureName(splitFeature, featureNames), Double.toString(splitValue)));
+ } else {
+ throw new IllegalStateException("Unsupported attribute type: "
+ + splitFeatureType);
+ }
+
+ if (myNodeId != parentNodeId) {
+ builder.append(' ').append(parentNodeId).append(" -> ").append(myNodeId);
+ if (parentNodeId == 0) {//only draw edge label on top
+ if (myNodeId == 1) {
+ builder.append(" [labeldistance=2.5, labelangle=45, headlabel=\"True\"]");
+ } else {
+ builder.append(" [labeldistance=2.5, labelangle=-45, headlabel=\"False\"]");
+ }
+ }
+ builder.append(";\n");
+ }
+
+ nodeIdGenerator.addValue(1);
+ trueChild.exportGraphviz(builder, featureNames, outputName, nodeIdGenerator,
+ myNodeId);
+ nodeIdGenerator.addValue(1);
+ falseChild.exportGraphviz(builder, featureNames, outputName, nodeIdGenerator,
+ myNodeId);
+ }
+ }
+
+
@Override
public void writeExternal(ObjectOutput out) throws IOException {
out.writeInt(splitFeature);
@@ -837,14 +914,14 @@ public final class RegressionTree implements Regression<Vector> {
return _root.predict(x);
}
- public String predictJsCodegen() {
+ public String predictJsCodegen(@Nonnull final String[] featureNames) {
StringBuilder buf = new StringBuilder(1024);
- _root.jsCodegen(buf, 0);
+ _root.exportJavascript(buf, featureNames, 0);
return buf.toString();
}
@Nonnull
- public byte[] predictSerCodegen(boolean compress) throws HiveException {
+ public byte[] serialize(boolean compress) throws HiveException {
try {
if (compress) {
return ObjectUtils.toCompressedBytes(_root);
@@ -858,7 +935,8 @@ public final class RegressionTree implements Regression<Vector> {
}
}
- public static Node deserializeNode(final byte[] serializedObj, final int length,
+ @Nonnull
+ public static Node deserialize(@Nonnull final byte[] serializedObj, final int length,
final boolean compressed) throws HiveException {
final Node root = new Node();
try {
@@ -876,4 +954,8 @@ public final class RegressionTree implements Regression<Vector> {
return root;
}
+ @Override
+ public String toString() {
+ return _root == null ? "" : predictJsCodegen(null);
+ }
}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/9876d063/core/src/main/java/hivemall/smile/tools/TreeExportUDF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/smile/tools/TreeExportUDF.java b/core/src/main/java/hivemall/smile/tools/TreeExportUDF.java
new file mode 100644
index 0000000..7d509ad
--- /dev/null
+++ b/core/src/main/java/hivemall/smile/tools/TreeExportUDF.java
@@ -0,0 +1,241 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.smile.tools;
+
+import hivemall.UDFWithOptions;
+import hivemall.smile.classification.DecisionTree;
+import hivemall.smile.regression.RegressionTree;
+import hivemall.smile.utils.SmileExtUtils;
+import hivemall.utils.codec.Base91;
+import hivemall.utils.hadoop.HiveUtils;
+import hivemall.utils.lang.mutable.MutableInt;
+
+import java.util.Arrays;
+
+import javax.annotation.Nonnull;
+import javax.annotation.Nullable;
+
+import org.apache.commons.cli.CommandLine;
+import org.apache.commons.cli.Options;
+import org.apache.hadoop.hive.ql.exec.Description;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
+import org.apache.hadoop.hive.ql.metadata.HiveException;
+import org.apache.hadoop.hive.ql.udf.UDFType;
+import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.StringObjectInspector;
+import org.apache.hadoop.io.Text;
+
+@Description(
+ name = "tree_export",
+ value = "_FUNC_(string model, const string options, optional array<string> featureNames=null, optional array<string> classNames=null)"
+ + " - exports a Decision Tree model as javascript/dot")
+@UDFType(deterministic = true, stateful = false)
+public final class TreeExportUDF extends UDFWithOptions {
+
+ private transient Evaluator evaluator;
+
+ private transient StringObjectInspector modelOI;
+ @Nullable
+ private transient ListObjectInspector featureNamesOI;
+ @Nullable
+ private transient ListObjectInspector classNamesOI;
+
+ @Override
+ protected Options getOptions() {
+ Options opts = new Options();
+ opts.addOption("t", "type", true,
+ "Type of output [default: js, javascript/js, graphvis/dot]");
+ opts.addOption("r", "regression", false, "Is regression tree or not");
+ opts.addOption("output_name", "outputName", true, "output name [default: predicted]");
+ return opts;
+ }
+
+ @Override
+ protected CommandLine processOptions(@Nonnull String opts) throws UDFArgumentException {
+ CommandLine cl = parseOptions(opts);
+
+ OutputType outputType = OutputType.resolve(cl.getOptionValue("type"));
+ boolean regression = cl.hasOption("regression");
+ String outputName = cl.getOptionValue("output_name", "predicted");
+ this.evaluator = new Evaluator(outputType, outputName, regression);
+
+ return cl;
+ }
+
+ @Override
+ public ObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException {
+ final int argLen = argOIs.length;
+ if (argLen < 2 || argLen > 4) {
+ throw new UDFArgumentException("_FUNC_ takes 2~4 arguments: " + argLen);
+ }
+
+ this.modelOI = HiveUtils.asStringOI(argOIs[0]);
+
+ String options = HiveUtils.getConstString(argOIs[1]);
+ processOptions(options);
+
+ if (argLen >= 3) {
+ this.featureNamesOI = HiveUtils.asListOI(argOIs[2]);
+ if (!HiveUtils.isStringOI(featureNamesOI.getListElementObjectInspector())) {
+ throw new UDFArgumentException("_FUNC_ expected array<string> for featureNames: "
+ + featureNamesOI.getTypeName());
+ }
+ if (argLen == 4) {
+ this.classNamesOI = HiveUtils.asListOI(argOIs[3]);
+ if (!HiveUtils.isStringOI(classNamesOI.getListElementObjectInspector())) {
+ throw new UDFArgumentException("_FUNC_ expected array<string> for classNames: "
+ + classNamesOI.getTypeName());
+ }
+ }
+ }
+
+ return PrimitiveObjectInspectorFactory.writableStringObjectInspector;
+ }
+
+ @Override
+ public Object evaluate(DeferredObject[] arguments) throws HiveException {
+ Object arg0 = arguments[0].get();
+ if (arg0 == null) {
+ return null;
+ }
+ Text model = modelOI.getPrimitiveWritableObject(arg0);
+
+ String[] featureNames = null, classNames = null;
+ if (arguments.length >= 3) {
+ featureNames = HiveUtils.asStringArray(arguments[2], featureNamesOI);
+ if (arguments.length >= 4) {
+ classNames = HiveUtils.asStringArray(arguments[3], classNamesOI);
+ }
+ }
+
+ try {
+ return evaluator.export(model, featureNames, classNames);
+ } catch (HiveException he) {
+ throw he;
+ } catch (Throwable e) {
+ throw new HiveException(e);
+ }
+ }
+
+ @Override
+ public String getDisplayString(String[] children) {
+ return "tree_export(" + Arrays.toString(children) + ")";
+ }
+
+ public enum OutputType {
+ javascript, graphvis;
+
+ @Nonnull
+ public static OutputType resolve(@Nonnull String name) throws UDFArgumentException {
+ if ("js".equalsIgnoreCase(name) || "javascript".equalsIgnoreCase(name)) {
+ return javascript;
+ } else if ("dot".equalsIgnoreCase(name) || "graphvis".equalsIgnoreCase(name)) {
+ return graphvis;
+ } else {
+ throw new UDFArgumentException(
+ "Please provide a valid `-type` option from [javascript, graphvis]: " + name);
+ }
+ }
+ }
+
+ public static class Evaluator {
+
+ @Nonnull
+ private final OutputType outputType;
+ @Nonnull
+ private final String outputName;
+ private final boolean regression;
+
+ public Evaluator(@Nonnull OutputType outputType, @Nonnull String outputName,
+ boolean regression) {
+ this.outputType = outputType;
+ this.outputName = outputName;
+ this.regression = regression;
+ }
+
+ @Nonnull
+ public Text export(@Nonnull Text model, @Nullable String[] featureNames,
+ @Nullable String[] classNames) throws HiveException {
+ int length = model.getLength();
+ byte[] b = model.getBytes();
+ b = Base91.decode(b, 0, length);
+
+ final String exported;
+ if (regression) {
+ exported = exportRegressor(b, featureNames);
+ } else {
+ exported = exportClassifier(b, featureNames, classNames);
+ }
+ return new Text(exported);
+ }
+
+ @Nonnull
+ private String exportClassifier(@Nonnull byte[] b, @Nullable String[] featureNames,
+ @Nullable String[] classNames) throws HiveException {
+ final DecisionTree.Node node = DecisionTree.deserialize(b, b.length, true);
+
+ final StringBuilder buf = new StringBuilder(8192);
+ switch (outputType) {
+ case javascript: {
+ node.exportJavascript(buf, featureNames, classNames, 0);
+ break;
+ }
+ case graphvis: {
+ buf.append("digraph Tree {\n node [shape=box, style=\"filled, rounded\", color=\"black\", fontname=helvetica];\n edge [fontname=helvetica];\n");
+ double[] colorBrew = (classNames == null) ? null
+ : SmileExtUtils.getColorBrew(classNames.length);
+ node.exportGraphviz(buf, featureNames, classNames, outputName, colorBrew,
+ new MutableInt(0), 0);
+ buf.append("}");
+ break;
+ }
+ default:
+ throw new HiveException("Unsupported outputType: " + outputType);
+ }
+ return buf.toString();
+ }
+
+ @Nonnull
+ private String exportRegressor(@Nonnull byte[] b, @Nullable String[] featureNames)
+ throws HiveException {
+ final RegressionTree.Node node = RegressionTree.deserialize(b, b.length, true);
+
+ final StringBuilder buf = new StringBuilder(8192);
+ switch (outputType) {
+ case javascript: {
+ node.exportJavascript(buf, featureNames, 0);
+ break;
+ }
+ case graphvis: {
+ buf.append("digraph Tree {\n node [shape=box, style=\"filled, rounded\", color=\"black\", fontname=helvetica];\n edge [fontname=helvetica];\n");
+ node.exportGraphviz(buf, featureNames, outputName, new MutableInt(0), 0);
+ buf.append("}");
+ break;
+ }
+ default:
+ throw new HiveException("Unsupported outputType: " + outputType);
+ }
+ return buf.toString();
+ }
+
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/9876d063/core/src/main/java/hivemall/smile/tools/TreePredictUDF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/smile/tools/TreePredictUDF.java b/core/src/main/java/hivemall/smile/tools/TreePredictUDF.java
index dc544ae..46b8758 100644
--- a/core/src/main/java/hivemall/smile/tools/TreePredictUDF.java
+++ b/core/src/main/java/hivemall/smile/tools/TreePredictUDF.java
@@ -249,7 +249,7 @@ public final class TreePredictUDF extends GenericUDF {
int length = script.getLength();
byte[] b = script.getBytes();
b = Base91.decode(b, 0, length);
- this.cNode = DecisionTree.deserializeNode(b, b.length, true);
+ this.cNode = DecisionTree.deserialize(b, b.length, true);
}
Arrays.fill(result, null);
@@ -287,7 +287,7 @@ public final class TreePredictUDF extends GenericUDF {
int length = script.getLength();
byte[] b = script.getBytes();
b = Base91.decode(b, 0, length);
- this.rNode = RegressionTree.deserializeNode(b, b.length, true);
+ this.rNode = RegressionTree.deserialize(b, b.length, true);
}
Preconditions.checkNotNull(rNode);
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/9876d063/core/src/main/java/hivemall/smile/utils/SmileExtUtils.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/smile/utils/SmileExtUtils.java b/core/src/main/java/hivemall/smile/utils/SmileExtUtils.java
index 74a3032..5e27e12 100644
--- a/core/src/main/java/hivemall/smile/utils/SmileExtUtils.java
+++ b/core/src/main/java/hivemall/smile/utils/SmileExtUtils.java
@@ -33,11 +33,13 @@ import hivemall.smile.data.Attribute.NominalAttribute;
import hivemall.smile.data.Attribute.NumericAttribute;
import hivemall.utils.collections.lists.DoubleArrayList;
import hivemall.utils.collections.lists.IntArrayList;
+import hivemall.utils.lang.Preconditions;
import hivemall.utils.lang.mutable.MutableInt;
import hivemall.utils.math.MathUtils;
import java.util.Arrays;
+import javax.annotation.Nonnegative;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
@@ -394,4 +396,43 @@ public final class SmileExtUtils {
return false;
}
+ @Nonnull
+ public static String resolveFeatureName(final int index, @Nullable final String[] names) {
+ if (names == null) {
+ return "feature#" + index;
+ }
+ if (index >= names.length) {
+ return "feature#" + index;
+ }
+ return names[index];
+ }
+
+ @Nonnull
+ public static String resolveName(final int index, @Nullable final String[] names) {
+ if (names == null) {
+ return String.valueOf(index);
+ }
+ if (index >= names.length) {
+ return String.valueOf(index);
+ }
+ return names[index];
+ }
+
+ /**
+ * Generates an evenly distributed range of hue values in the HSV color scale.
+ *
+ * @return colors
+ */
+ public static double[] getColorBrew(@Nonnegative int n) {
+ Preconditions.checkArgument(n >= 1);
+
+ final double hue_step = 360.d / n;
+
+ final double[] colors = new double[n];
+ for (int i = 0; i < n; i++) {
+ colors[i] = i * hue_step / 360.d;
+ }
+ return colors;
+ }
+
}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/9876d063/core/src/main/java/hivemall/utils/hadoop/HiveUtils.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/utils/hadoop/HiveUtils.java b/core/src/main/java/hivemall/utils/hadoop/HiveUtils.java
index 6c1b0d1..4ed1f12 100644
--- a/core/src/main/java/hivemall/utils/hadoop/HiveUtils.java
+++ b/core/src/main/java/hivemall/utils/hadoop/HiveUtils.java
@@ -187,6 +187,25 @@ public final class HiveUtils {
return Arrays.asList(ary);
}
+ @Nullable
+ public static String[] asStringArray(@Nonnull final DeferredObject arg,
+ @Nonnull final ListObjectInspector listOI) throws HiveException {
+ Object argObj = arg.get();
+ if (argObj == null) {
+ return null;
+ }
+ List<?> data = listOI.getList(argObj);
+ final int size = data.size();
+ final String[] arr = new String[size];
+ for (int i = 0; i < size; i++) {
+ Object o = data.get(i);
+ if (o != null) {
+ arr[i] = o.toString();
+ }
+ }
+ return arr;
+ }
+
@Nonnull
public static StructObjectInspector asStructOI(@Nonnull final ObjectInspector oi)
throws UDFArgumentException {
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/9876d063/core/src/test/java/hivemall/smile/classification/DecisionTreeTest.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/hivemall/smile/classification/DecisionTreeTest.java b/core/src/test/java/hivemall/smile/classification/DecisionTreeTest.java
index bb6de6b..897da0c 100644
--- a/core/src/test/java/hivemall/smile/classification/DecisionTreeTest.java
+++ b/core/src/test/java/hivemall/smile/classification/DecisionTreeTest.java
@@ -25,7 +25,10 @@ import hivemall.math.matrix.dense.RowMajorDenseMatrix2d;
import hivemall.math.random.RandomNumberGeneratorFactory;
import hivemall.smile.classification.DecisionTree.Node;
import hivemall.smile.data.Attribute;
+import hivemall.smile.tools.TreeExportUDF.Evaluator;
+import hivemall.smile.tools.TreeExportUDF.OutputType;
import hivemall.smile.utils.SmileExtUtils;
+import hivemall.utils.codec.Base91;
import java.io.BufferedInputStream;
import java.io.IOException;
@@ -36,6 +39,7 @@ import java.text.ParseException;
import javax.annotation.Nonnull;
import org.apache.hadoop.hive.ql.metadata.HiveException;
+import org.apache.hadoop.io.Text;
import org.junit.Assert;
import org.junit.Test;
@@ -106,6 +110,71 @@ public class DecisionTreeTest {
assertEquals(7, error);
}
+ @Test
+ public void testGraphvisOutputIris() throws IOException, ParseException, HiveException {
+ String datasetUrl = "https://gist.githubusercontent.com/myui/143fa9d05bd6e7db0114/raw/500f178316b802f1cade6e3bf8dc814a96e84b1e/iris.arff";
+ int responseIndex = 4;
+ int numLeafs = 4;
+ boolean dense = true;
+ String outputName = "class";
+ String[] featureNames = new String[] {"sepallength", "sepalwidth", "petallength",
+ "petalwidth"};
+ String[] classNames = new String[] {"setosa", "versicolor", "virginica"};
+
+ debugPrint(graphvisOutput(datasetUrl, responseIndex, numLeafs, dense, featureNames,
+ classNames, outputName));
+
+ featureNames = null;
+ classNames = null;
+ outputName = null;
+ debugPrint(graphvisOutput(datasetUrl, responseIndex, numLeafs, dense, featureNames,
+ classNames, outputName));
+ }
+
+ @Test
+ public void testGraphvisOutputWeather() throws IOException, ParseException, HiveException {
+ String datasetUrl = "https://gist.githubusercontent.com/myui/2c9df50db3de93a71b92/raw/3f6b4ecfd4045008059e1a2d1c4064fb8a3d372a/weather.nominal.arff";
+ int responseIndex = 4;
+ int numLeafs = 3;
+ boolean dense = true;
+ String[] featureNames = new String[] {"outlook", "temperature", "humidity", "windy"};
+ String[] classNames = new String[] {"yes", "no"};
+ String outputName = "play";
+
+ debugPrint(graphvisOutput(datasetUrl, responseIndex, numLeafs, dense, featureNames,
+ classNames, outputName));
+
+ featureNames = null;
+ classNames = null;
+ debugPrint(graphvisOutput(datasetUrl, responseIndex, numLeafs, dense, featureNames,
+ classNames, outputName));
+ }
+
+ private static String graphvisOutput(String datasetUrl, int responseIndex, int numLeafs,
+ boolean dense, String[] featureNames, String[] classNames, String outputName)
+ throws IOException, HiveException, ParseException {
+ URL url = new URL(datasetUrl);
+ InputStream is = new BufferedInputStream(url.openStream());
+
+ ArffParser arffParser = new ArffParser();
+ arffParser.setResponseIndex(responseIndex);
+
+ AttributeDataset ds = arffParser.parse(is);
+ double[][] x = ds.toArray(new double[ds.size()][]);
+ int[] y = ds.toArray(new int[ds.size()]);
+
+ Attribute[] attrs = SmileExtUtils.convertAttributeTypes(ds.attributes());
+ DecisionTree tree = new DecisionTree(attrs, matrix(x, dense), y, numLeafs,
+ RandomNumberGeneratorFactory.createPRNG(31));
+
+ Text model = new Text(Base91.encode(tree.serialize(true)));
+
+ Evaluator eval = new Evaluator(OutputType.graphvis, outputName, false);
+ Text exported = eval.export(model, featureNames, classNames);
+
+ return exported.toString();
+ }
+
private static int run(String datasetUrl, int responseIndex, int numLeafs, boolean dense)
throws IOException, ParseException {
URL url = new URL(datasetUrl);
@@ -185,8 +254,8 @@ public class DecisionTreeTest {
Attribute[] attrs = SmileExtUtils.convertAttributeTypes(iris.attributes());
DecisionTree tree = new DecisionTree(attrs, matrix(trainx, true), trainy, 4);
- byte[] b = tree.predictSerCodegen(false);
- Node node = DecisionTree.deserializeNode(b, b.length, false);
+ byte[] b = tree.serialize(false);
+ Node node = DecisionTree.deserialize(b, b.length, false);
assertEquals(tree.predict(x[loocv.test[i]]), node.predict(x[loocv.test[i]]));
}
}
@@ -212,11 +281,11 @@ public class DecisionTreeTest {
Attribute[] attrs = SmileExtUtils.convertAttributeTypes(iris.attributes());
DecisionTree tree = new DecisionTree(attrs, matrix(trainx, true), trainy, 4);
- byte[] b1 = tree.predictSerCodegen(true);
- byte[] b2 = tree.predictSerCodegen(false);
+ byte[] b1 = tree.serialize(true);
+ byte[] b2 = tree.serialize(false);
Assert.assertTrue("b1.length = " + b1.length + ", b2.length = " + b2.length,
b1.length < b2.length);
- Node node = DecisionTree.deserializeNode(b1, b1.length, true);
+ Node node = DecisionTree.deserialize(b1, b1.length, true);
assertEquals(tree.predict(x[loocv.test[i]]), node.predict(x[loocv.test[i]]));
}
}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/9876d063/core/src/test/java/hivemall/smile/classification/RandomForestClassifierUDTFTest.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/hivemall/smile/classification/RandomForestClassifierUDTFTest.java b/core/src/test/java/hivemall/smile/classification/RandomForestClassifierUDTFTest.java
index d682093..578689c 100644
--- a/core/src/test/java/hivemall/smile/classification/RandomForestClassifierUDTFTest.java
+++ b/core/src/test/java/hivemall/smile/classification/RandomForestClassifierUDTFTest.java
@@ -208,7 +208,7 @@ public class RandomForestClassifierUDTFTest {
Assert.assertNotNull(modelTxt);
byte[] b = Base91.decode(modelTxt.getBytes(), 0, modelTxt.getLength());
- DecisionTree.Node node = DecisionTree.deserializeNode(b, b.length, true);
+ DecisionTree.Node node = DecisionTree.deserialize(b, b.length, true);
return node;
}
@@ -257,7 +257,7 @@ public class RandomForestClassifierUDTFTest {
Assert.assertNotNull(modelTxt);
byte[] b = Base91.decode(modelTxt.getBytes(), 0, modelTxt.getLength());
- DecisionTree.Node node = DecisionTree.deserializeNode(b, b.length, true);
+ DecisionTree.Node node = DecisionTree.deserialize(b, b.length, true);
return node;
}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/9876d063/core/src/test/java/hivemall/smile/regression/RegressionTreeTest.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/hivemall/smile/regression/RegressionTreeTest.java b/core/src/test/java/hivemall/smile/regression/RegressionTreeTest.java
index eae625d..f3eb5e5 100644
--- a/core/src/test/java/hivemall/smile/regression/RegressionTreeTest.java
+++ b/core/src/test/java/hivemall/smile/regression/RegressionTreeTest.java
@@ -24,12 +24,18 @@ import hivemall.math.matrix.dense.RowMajorDenseMatrix2d;
import hivemall.math.random.RandomNumberGeneratorFactory;
import hivemall.smile.data.Attribute;
import hivemall.smile.data.Attribute.NumericAttribute;
+import hivemall.smile.tools.TreeExportUDF.Evaluator;
+import hivemall.smile.tools.TreeExportUDF.OutputType;
+import hivemall.utils.codec.Base91;
+import java.io.IOException;
+import java.text.ParseException;
import java.util.Arrays;
import javax.annotation.Nonnull;
import org.apache.hadoop.hive.ql.metadata.HiveException;
+import org.apache.hadoop.io.Text;
import org.junit.Assert;
import org.junit.Test;
@@ -37,6 +43,7 @@ import smile.math.Math;
import smile.validation.LOOCV;
public class RegressionTreeTest {
+ private static final boolean DEBUG = false;
@Test
public void testPredictDense() {
@@ -158,8 +165,8 @@ public class RegressionTreeTest {
int maxLeafs = Integer.MAX_VALUE;
RegressionTree tree = new RegressionTree(attrs, matrix(trainx, true), trainy, maxLeafs);
- byte[] b = tree.predictSerCodegen(true);
- RegressionTree.Node node = RegressionTree.deserializeNode(b, b.length, true);
+ byte[] b = tree.serialize(true);
+ RegressionTree.Node node = RegressionTree.deserialize(b, b.length, true);
double expected = tree.predict(longley[loocv.test[i]]);
double actual = node.predict(longley[loocv.test[i]]);
@@ -168,6 +175,49 @@ public class RegressionTreeTest {
}
}
+ @Test
+ public void testGraphvizOutput() throws HiveException, IOException, ParseException {
+ int maxLeafts = 10;
+ String outputName = "predicted";
+
+ double[][] x = { {234.289, 235.6, 159.0, 107.608, 1947, 60.323},
+ {259.426, 232.5, 145.6, 108.632, 1948, 61.122},
+ {258.054, 368.2, 161.6, 109.773, 1949, 60.171},
+ {284.599, 335.1, 165.0, 110.929, 1950, 61.187},
+ {328.975, 209.9, 309.9, 112.075, 1951, 63.221},
+ {346.999, 193.2, 359.4, 113.270, 1952, 63.639},
+ {365.385, 187.0, 354.7, 115.094, 1953, 64.989},
+ {363.112, 357.8, 335.0, 116.219, 1954, 63.761},
+ {397.469, 290.4, 304.8, 117.388, 1955, 66.019},
+ {419.180, 282.2, 285.7, 118.734, 1956, 67.857},
+ {442.769, 293.6, 279.8, 120.445, 1957, 68.169},
+ {444.546, 468.1, 263.7, 121.950, 1958, 66.513},
+ {482.704, 381.3, 255.2, 123.366, 1959, 68.655},
+ {502.601, 393.1, 251.4, 125.368, 1960, 69.564},
+ {518.173, 480.6, 257.2, 127.852, 1961, 69.331},
+ {554.894, 400.7, 282.7, 130.081, 1962, 70.551}};
+
+ double[] y = {83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8,
+ 112.6, 114.2, 115.7, 116.9};
+
+ debugPrint(graphvisOutput(x, y, maxLeafts, true, null, outputName));
+ }
+
+ private static String graphvisOutput(double[][] x, double[] y, int maxLeafts, boolean dense,
+ String[] featureNames, String outputName) throws IOException, HiveException,
+ ParseException {
+ Attribute[] attrs = new Attribute[x[0].length];
+ Arrays.fill(attrs, new NumericAttribute());
+ RegressionTree tree = new RegressionTree(attrs, matrix(x, dense), y, maxLeafts);
+
+ Text model = new Text(Base91.encode(tree.serialize(true)));
+
+ Evaluator eval = new Evaluator(OutputType.graphvis, outputName, true);
+ Text exported = eval.export(model, featureNames, null);
+
+ return exported.toString();
+ }
+
@Nonnull
private static Matrix matrix(@Nonnull final double[][] x, boolean dense) {
if (dense) {
@@ -182,4 +232,10 @@ public class RegressionTreeTest {
}
}
+ private static void debugPrint(String msg) {
+ if (DEBUG) {
+ System.out.println(msg);
+ }
+ }
+
}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/9876d063/core/src/test/java/hivemall/smile/tools/TreePredictUDFTest.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/hivemall/smile/tools/TreePredictUDFTest.java b/core/src/test/java/hivemall/smile/tools/TreePredictUDFTest.java
index 65feeeb..31713d9 100644
--- a/core/src/test/java/hivemall/smile/tools/TreePredictUDFTest.java
+++ b/core/src/test/java/hivemall/smile/tools/TreePredictUDFTest.java
@@ -164,7 +164,7 @@ public class TreePredictUDFTest {
}
private static int evalPredict(DecisionTree tree, double[] x) throws HiveException, IOException {
- byte[] b = tree.predictSerCodegen(true);
+ byte[] b = tree.serialize(true);
byte[] encoded = Base91.encode(b);
Text model = new Text(encoded);
@@ -186,7 +186,7 @@ public class TreePredictUDFTest {
private static double evalPredict(RegressionTree tree, double[] x) throws HiveException,
IOException {
- byte[] b = tree.predictSerCodegen(true);
+ byte[] b = tree.serialize(true);
byte[] encoded = Base91.encode(b);
Text model = new Text(encoded);
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/9876d063/docs/gitbook/SUMMARY.md
----------------------------------------------------------------------
diff --git a/docs/gitbook/SUMMARY.md b/docs/gitbook/SUMMARY.md
index 32b0150..cc7f622 100644
--- a/docs/gitbook/SUMMARY.md
+++ b/docs/gitbook/SUMMARY.md
@@ -92,6 +92,7 @@
* [Perceptron, Passive Aggressive](binaryclass/news20_pa.md)
* [CW, AROW, SCW](binaryclass/news20_scw.md)
* [AdaGradRDA, AdaGrad, AdaDelta](binaryclass/news20_adagrad.md)
+ * [Random Forest](binaryclass/news20_rf.md)
* [KDD2010a tutorial](binaryclass/kdd2010a.md)
* [Data preparation](binaryclass/kdd2010a_dataset.md)
@@ -121,7 +122,7 @@
* [Iris tutorial](multiclass/iris.md)
* [Data preparation](multiclass/iris_dataset.md)
* [SCW](multiclass/iris_scw.md)
- * [RandomForest](multiclass/iris_randomforest.md)
+ * [Random Forest](multiclass/iris_randomforest.md)
## Part VIII - Regression
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/9876d063/docs/gitbook/binaryclass/news20_rf.md
----------------------------------------------------------------------
diff --git a/docs/gitbook/binaryclass/news20_rf.md b/docs/gitbook/binaryclass/news20_rf.md
new file mode 100644
index 0000000..fd0b475
--- /dev/null
+++ b/docs/gitbook/binaryclass/news20_rf.md
@@ -0,0 +1,90 @@
+<!--
+ Licensed to the Apache Software Foundation (ASF) under one
+ or more contributor license agreements. See the NOTICE file
+ distributed with this work for additional information
+ regarding copyright ownership. The ASF licenses this file
+ to you under the Apache License, Version 2.0 (the
+ "License"); you may not use this file except in compliance
+ with the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing,
+ software distributed under the License is distributed on an
+ "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ KIND, either express or implied. See the License for the
+ specific language governing permissions and limitations
+ under the License.
+-->
+
+Hivemall Random Forest supports libsvm-like sparse inputs.
+
+> #### Note
+> This feature, i.e., sparse input support in Random Forest, is available in Hivemall v0.5-rc.1 or later.
+> [`feature_hashing`](http://hivemall.incubator.apache.org/userguide/ft_engineering/hashing.html#featurehashing-function) function is useful to prepare feature vectors for Random Forest.
+
+<!-- toc -->
+
+## Training
+
+```sql
+drop table rf_model;
+create table rf_model
+as
+select
+ train_randomforest_classifier(
+ features,
+ convert_label(label), -- convert -1/1 to 0/1
+ '-trees 50 -seed 71' -- hyperparameters
+ )
+from
+ train;
+```
+
+> #### Caution
+> label must be in `[0, k)` where `k` is the number of classes.
+
+## Prediction
+
+```sql
+SET hivevar:classification=true;
+
+drop table rf_predicted;
+create table rf_predicted
+as
+SELECT
+ rowid,
+ rf_ensemble(predicted.value, predicted.posteriori, model_weight) as predicted
+ -- rf_ensemble(predicted.value, predicted.posteriori) as predicted -- avoid OOB accuracy (i.e., model_weight)
+FROM (
+ SELECT
+ rowid,
+ m.model_weight,
+ tree_predict(m.model_id, m.model, t.features, ${classification}) as predicted
+ FROM
+ rf_model m
+ LEFT OUTER JOIN -- CROSS JOIN
+ test t
+) t1
+group by
+ rowid
+;
+```
+
+## Evaluation
+
+```sql
+WITH submit as (
+ select
+ convert_label(t.label) as actual,
+ p.predicted.label as predicted
+ from
+ test t
+ JOIN rf_predicted p on (t.rowid = p.rowid)
+)
+select count(1) / 4996.0
+from submit
+where actual = predicted;
+```
+
+> 0.8112489991993594
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/9876d063/docs/gitbook/binaryclass/titanic_rf.md
----------------------------------------------------------------------
diff --git a/docs/gitbook/binaryclass/titanic_rf.md b/docs/gitbook/binaryclass/titanic_rf.md
index 1a9786e..64502b9 100644
--- a/docs/gitbook/binaryclass/titanic_rf.md
+++ b/docs/gitbook/binaryclass/titanic_rf.md
@@ -47,10 +47,14 @@ ROW FORMAT DELIMITED
FIELDS TERMINATED BY '|'
LINES TERMINATED BY '\n'
STORED AS TEXTFILE LOCATION '/dataset/titanic/train';
+```
+```sh
hadoop fs -rm /dataset/titanic/train/train.csv
awk '{ FPAT="([^,]*)|(\"[^\"]+\")";OFS="|"; } NR >1 {$1=$1;$4=substr($4,2,length($4)-2);print $0}' train.csv | hadoop fs -put - /dataset/titanic/train/train.csv
+```
+```sql
drop table test_raw;
create external table test_raw (
passengerid int,
@@ -69,7 +73,9 @@ ROW FORMAT DELIMITED
FIELDS TERMINATED BY '|'
LINES TERMINATED BY '\n'
STORED AS TEXTFILE LOCATION '/dataset/titanic/test_raw';
+```
+```sh
hadoop fs -rm /dataset/titanic/test_raw/test.csv
awk '{ FPAT="([^,]*)|(\"[^\"]+\")";OFS="|"; } NR >1 {$1=$1;$3=substr($3,2,length($3)-2);print $0}' test.csv | hadoop fs -put - /dataset/titanic/test_raw/test.csv
```
@@ -163,9 +169,8 @@ select
sum(oob_errors) / sum(oob_tests) as oob_err_rate
from
model_rf;
-
-> [137.00242639169272,1194.2140119834373,328.78017188176966,628.2568660509628,200.31275032394072,160.12876797647078,1083.5987543408116,664.1234312561456,422.89449844090393,130.72019667694784] 0.18742985409652077
```
+> [137.00242639169272,1194.2140119834373,328.78017188176966,628.2568660509628,200.31275032394072,160.12876797647078,1083.5987543408116,664.1234312561456,422.89449844090393,130.72019667694784] 0.18742985409652077
# Prediction
@@ -186,16 +191,27 @@ SELECT
FROM (
SELECT
passengerid,
- rf_ensemble(predicted) as predicted
+ -- rf_ensemble(predicted) as predicted
+ -- hivemall v0.5-rc.1 or later
+ rf_ensemble(predicted.value, predicted.posteriori, model_weight) as predicted
+ -- rf_ensemble(predicted.value, predicted.posteriori) as predicted -- avoid OOB accuracy (i.e., model_weight)
FROM (
SELECT
t.passengerid,
-- hivemall v0.4.1-alpha.2 or before
-- tree_predict(p.model, t.features, ${classification}) as predicted
- -- hivemall v0.4.1-alpha.3 or later
- tree_predict(p.model_id, p.model_type, p.pred_model, t.features, ${classification}) as predicted
+ -- hivemall v0.4.1-alpha.3 or later
+ -- tree_predict(p.model_id, p.model_type, p.pred_model, t.features, ${classification}) as predicted
+ -- hivemall v0.5-rc.1 or later
+ p.model_weight,
+ tree_predict(p.model_id, p.model, t.features, ${classification}) as predicted
FROM (
- SELECT model_id, model_type, pred_model FROM model_rf
+ SELECT
+ -- model_id, pred_model
+ -- hivemall v0.5-rc.1 or later
+ model_id, model_weight, model
+ FROM
+ model_rf
DISTRIBUTE BY rand(1)
) p
LEFT OUTER JOIN test_rf t
@@ -223,12 +239,49 @@ ORDER BY passengerid ASC;
```sh
hadoop fs -getmerge /user/hive/warehouse/titanic.db/predicted_rf_submit predicted_rf_submit.csv
-
sed -i -e "1i PassengerId,Survived" predicted_rf_submit.csv
```
Accuracy would gives `0.76555` for a Kaggle submission.
+# Graphviz export
+
+> #### Note
+> The `tree_export` feature is supported in Hivemall v0.5-rc.1 or later.
+> It is better to limit the tree depth at training time with the `-depth` option when plotting a Decision Tree.
+
+Hivemall provides `tree_export` to export a decision tree into [Graphviz](http://www.graphviz.org/) or human-readable JavaScript format. You can find its usage by issuing the following query:
+
+```
+> select tree_export("","-help");
+
+usage: tree_export(string model, const string options, optional
+ array<string> featureNames=null, optional array<string>
+ classNames=null) - exports a Decision Tree model as javascript/dot]
+ [-help] [-output_name <arg>] [-r] [-t <arg>]
+ -help Show function help
+ -output_name,--outputName <arg> output name [default: predicted]
+ -r,--regression Is regression tree or not
+ -t,--type <arg> Type of output [default: js,
+ javascript/js, graphvis/dot
+```
+
+```sql
+CREATE TABLE model_exported
+ STORED AS ORC tblproperties("orc.compress"="SNAPPY")
+AS
+select
+ model_id,
+ tree_export(model, "-type javascript -output_name survived", array('pclass','name','sex','age','sibsp','parch','ticket','fare','cabin','embarked'), array('no','yes')) as js,
+ tree_export(model, "-type graphvis -output_name survived", array('pclass','name','sex','age','sibsp','parch','ticket','fare','cabin','embarked'), array('no','yes')) as dot
+from
+ model_rf
+-- limit 1
+;
+```
+
+[Here is an example](https://gist.github.com/myui/a83ba3795bad9b278cf8bcc59f946e2c#file-titanic-dot) of plotting a decision tree using Graphviz or [Viz.js](http://viz-js.com/).
+
---
# Test by dividing training dataset
@@ -259,8 +312,10 @@ select
sum(oob_errors) / sum(oob_tests) as oob_err_rate
from
model_rf_07;
+```
> [116.12055542977338,960.8569891444097,291.08765260103837,469.74671636586226,163.721292772701,120.784769882858,847.9769298113661,554.4617571355476,346.3500941757221,97.42593940113392] 0.1838351822503962
+```sql
SET hivevar:classification=true;
SET hive.mapjoin.optimized.hashtable=false;
SET mapred.reduce.tasks=16;
@@ -276,16 +331,27 @@ SELECT
FROM (
SELECT
passengerid,
- rf_ensemble(predicted) as predicted
+ -- rf_ensemble(predicted) as predicted
+ -- hivemall v0.5-rc.1 or later
+ rf_ensemble(predicted.value, predicted.posteriori, model_weight) as predicted
+ -- rf_ensemble(predicted.value, predicted.posteriori) as predicted -- avoid OOB accuracy (i.e., model_weight)
FROM (
SELECT
t.passengerid,
-- hivemall v0.4.1-alpha.2 or before
-- tree_predict(p.model, t.features, ${classification}) as predicted
-- hivemall v0.4.1-alpha.3 or later
- tree_predict(p.model_id, p.model_type, p.pred_model, t.features, ${classification}) as predicted
+ -- tree_predict(p.model_id, p.model_type, p.pred_model, t.features, ${classification}) as predicted
+ -- hivemall v0.5-rc.1 or later
+ p.model_weight,
+ tree_predict(p.model_id, p.model, t.features, ${classification}) as predicted
FROM (
- SELECT model_id, model_type, pred_model FROM model_rf_07
+ SELECT
+ -- model_id, model_type, pred_model
+ -- hivemall v0.5-rc.1 or later
+ model_id, model_weight, model
+ FROM
+ model_rf_07
DISTRIBUTE BY rand(1)
) p
LEFT OUTER JOIN test_rf_03 t
@@ -306,13 +372,16 @@ from
;
select count(1) from test_rf_03;
+```
> 260
+
+```sql
set hivevar:testcnt=260;
select count(1)/${testcnt} as accuracy
from rf_submit_03
where actual = predicted;
-
-> 0.8
```
+> 0.8153846153846154
+
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/9876d063/docs/gitbook/ft_engineering/hashing.md
----------------------------------------------------------------------
diff --git a/docs/gitbook/ft_engineering/hashing.md b/docs/gitbook/ft_engineering/hashing.md
index baa4cd4..8a08b8c 100644
--- a/docs/gitbook/ft_engineering/hashing.md
+++ b/docs/gitbook/ft_engineering/hashing.md
@@ -28,40 +28,54 @@ Find the differences in the following examples.
```sql
select feature_hashing('aaa');
+```
> 4063537
+```sql
select feature_hashing('aaa','-features 3');
+```
> 2
+```sql
select feature_hashing(array('aaa','bbb'));
+```
> ["4063537","8459207"]
+```sql
select feature_hashing(array('aaa','bbb'),'-features 10');
+```
> ["7","1"]
+```sql
select feature_hashing(array('aaa:1.0','aaa','bbb:2.0'));
+```
> ["4063537:1.0","4063537","8459207:2.0"]
+```sql
select feature_hashing(array(1,2,3));
+```
> ["11293631","3322224","4331412"]
+```sql
select feature_hashing(array('1','2','3'));
+```
> ["11293631","3322224","4331412"]
+```sql
select feature_hashing(array('1:0.1','2:0.2','3:0.3'));
+```
> ["11293631:0.1","3322224:0.2","4331412:0.3"]
+```sql
select feature_hashing(features), features from training_fm limit 2;
-
+```
> ["1803454","6630176"] ["userid#5689","movieid#3072"]
> ["1828616","6238429"] ["userid#4505","movieid#2331"]
+```sql
select feature_hashing(array("userid#4505:3.3","movieid#2331:4.999", "movieid#2331"));
-
-> ["1828616:3.3","6238429:4.999","6238429"]
```
-
-_Note: The hash value is starting from 1 and 0 is system reserved for a bias clause. The default number of features are 16777217 (2^24). You can control the number of features by `-num_features` (or `-features`) option._
+> ["1828616:3.3","6238429:4.999","6238429"]
```sql
select feature_hashing(null,'-help');
@@ -74,49 +88,50 @@ usage: feature_hashing(array<string> features [, const string options]) -
-help Show function help
```
+> #### Note
+> The hash value starts from 1; 0 is reserved by the system for a bias clause. The default number of features is 16777217 (2^24).
+> You can control the number of features with the `-num_features` (or `-features`) option.
+
## `mhash` function
```sql
describe function extended mhash;
-> mhash(string word) returns a murmurhash3 INT value starting from 1
```
+> mhash(string word) returns a murmurhash3 INT value starting from 1
```sql
-
select mhash('aaa');
-> 4063537
```
+> 4063537
_Note: The default number of features are `16777216 (2^24)`._
```sql
set hivevar:num_features=16777216;
-
select mhash('aaa',${num_features});
->4063537
```
+>4063537
_Note: `mhash` returns a `+1'd` murmurhash3 value starting from 1. Never returns 0 (It's a system reserved number)._
```sql
set hivevar:num_features=1;
-
select mhash('aaa',${num_features});
-> 1
```
+> 1
_Note: `mhash` does not considers feature values._
```sql
select mhash('aaa:2.0');
-> 2746618
```
+> 2746618
_Note: `mhash` always returns a scalar INT value._
```sql
select mhash(array('aaa','bbb'));
-> 9566153
```
+> 9566153
_Note: `mhash` value of an array is element order-sentitive._
```sql
select mhash(array('bbb','aaa'));
+```
> 3874068
-```
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/9876d063/docs/gitbook/multiclass/iris_dataset.md
----------------------------------------------------------------------
diff --git a/docs/gitbook/multiclass/iris_dataset.md b/docs/gitbook/multiclass/iris_dataset.md
index e67737e..8dae7c9 100644
--- a/docs/gitbook/multiclass/iris_dataset.md
+++ b/docs/gitbook/multiclass/iris_dataset.md
@@ -126,13 +126,13 @@ select rand(${rand_seed}) as rnd, * from iris_scaled;
-- 80% for training
create table train80p as
-select * from iris_shuffled
+select * from iris_shuffled
order by rnd DESC
limit 120;
-- 20% for testing
create table test20p as
-select * from iris_shuffled
+select * from iris_shuffled
order by rnd ASC
limit 30;
@@ -159,64 +159,3 @@ select
from
train80p;
```
-
-# Training (multiclass classification)
-
-```sql
-create table model_scw1 as
-select
- label,
- feature,
- argmin_kld(weight, covar) as weight
-from
- (select
- train_multiclass_scw(features, label) as (label, feature, weight, covar)
- from
- training_x10
- ) t
-group by label, feature;
-```
-
-# Predict
-
-```sql
-create or replace view predict_scw1
-as
-select
- rowid,
- m.col0 as score,
- m.col1 as label
-from (
-select
- rowid,
- maxrow(score, label) as m
-from (
- select
- t.rowid,
- m.label,
- sum(m.weight * t.value) as score
- from
- test20p_exploded t LEFT OUTER JOIN
- model_scw1 m ON (t.feature = m.feature)
- group by
- t.rowid, m.label
-) t1
-group by rowid
-) t2;
-```
-
-# Evaluation
-
-```sql
-create or replace view eval_scw1 as
-select
- t.label as actual,
- p.label as predicted
-from
- test20p t JOIN predict_scw1 p
- on (t.rowid = p.rowid);
-
-select count(1)/30 from eval_scw1
-where actual = predicted;
-```
-> 0.9666666666666667
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/9876d063/docs/gitbook/multiclass/iris_randomforest.md
----------------------------------------------------------------------
diff --git a/docs/gitbook/multiclass/iris_randomforest.md b/docs/gitbook/multiclass/iris_randomforest.md
index 4b0750c..d0e8e8c 100644
--- a/docs/gitbook/multiclass/iris_randomforest.md
+++ b/docs/gitbook/multiclass/iris_randomforest.md
@@ -89,7 +89,7 @@ from
```sql
CREATE TABLE model
-STORED AS SEQUENCEFILE
+ STORED AS SEQUENCEFILE
AS
select
train_randomforest_classifier(features, label)
@@ -100,60 +100,72 @@ select
from
training;
```
-*Note: The default TEXTFILE should not be used for model table when using Javascript output through "-output javascript" option.*
+> #### Caution
+> The default `TEXTFILE` should not be used for the model table when using JavaScript output through the `-output javascript` option.
+
+```sql
+hive> desc extended model;
```
-hive> desc model;
-model_id int
-model_type int
-pred_model string
-var_importance array<double>
-oob_errors int
-oob_tests int
-```
+
+| col_name | data_type
+|:-:|:-:|
+| model_id | string |
+| model_weight | double |
+| model | string |
+| var_importance | array<double> |
+| oob_errors | int |
+| oob_tests | int |
+
## Training options
-"-help" option shows usage of the function.
+`-help` option shows usage of the function.
-```
+```sql
select train_randomforest_classifier(features, label, "-help") from training;
> FAILED: UDFArgumentException
-usage: train_randomforest_classifier(double[] features, int label [,
- string options]) - Returns a relation consists of <int model_id,
- int model_type, string pred_model, array<double> var_importance,
- int oob_errors, int oob_tests> [-attrs <arg>] [-depth <arg>]
- [-disable_compression] [-help] [-leafs <arg>] [-output <arg>]
- [-rule <arg>] [-seed <arg>] [-splits <arg>] [-trees <arg>] [-vars
- <arg>]
- -attrs,--attribute_types <arg> Comma separated attribute types (Q for
- quantitative variable and C for
- categorical variable. e.g., [Q,C,Q,C])
- -depth,--max_depth <arg> The maximum number of the tree depth
- [default: Integer.MAX_VALUE]
- -disable_compression Whether to disable compression of the
- output script [default: false]
- -help Show function help
- -leafs,--max_leaf_nodes <arg> The maximum number of leaf nodes
- [default: Integer.MAX_VALUE]
- -output,--output_type <arg> The output type (serialization/ser or
- opscode/vm or javascript/js) [default:
- serialization]
- -rule,--split_rule <arg> Split algorithm [default: GINI, ENTROPY]
- -seed <arg> seed value in long [default: -1
- (random)]
- -splits,--min_split <arg> A node that has greater than or equals
- to `min_split` examples will split
- [default: 2]
- -trees,--num_trees <arg> The number of trees for each task
- [default: 50]
- -vars,--num_variables <arg> The number of random selected features
- [default: ceil(sqrt(x[0].length))].
- int(num_variables * x[0].length) is
- considered if num_variable is (0,1]
+usage: train_randomforest_classifier(array<double|string> features, int
+ label [, const array<double> classWeights, const string options]) -
+ Returns a relation consists of <int model_id, int model_type,
+ string pred_model, array<double> var_importance, int oob_errors,
+ int oob_tests, double weight> [-attrs <arg>] [-depth <arg>] [-help]
+ [-leafs <arg>] [-min_samples_leaf <arg>] [-rule <arg>] [-seed
+ <arg>] [-splits <arg>] [-stratified] [-subsample <arg>] [-trees
+ <arg>] [-vars <arg>]
+ -attrs,--attribute_types <arg> Comma separated attribute types (Q
+ for quantitative variable and C for
+ categorical variable. e.g.,
+ [Q,C,Q,C])
+ -depth,--max_depth <arg> The maximum number of the tree depth
+ [default: Integer.MAX_VALUE]
+ -help Show function help
+ -leafs,--max_leaf_nodes <arg> The maximum number of leaf nodes
+ [default: Integer.MAX_VALUE]
+ -min_samples_leaf <arg> The minimum number of samples in a
+ leaf node [default: 1]
+ -rule,--split_rule <arg> Split algorithm [default: GINI,
+ ENTROPY]
+ -seed <arg> seed value in long [default: -1
+ (random)]
+ -splits,--min_split <arg> A node that has greater than or
+ equals to `min_split` examples will
+ split [default: 2]
+ -stratified,--stratified_sampling Enable Stratified sampling for
+ unbalanced data
+ -subsample <arg> Sampling rate in range (0.0,1.0]
+ -trees,--num_trees <arg> The number of trees for each task
+ [default: 50]
+ -vars,--num_variables <arg> The number of random selected
+ features [default:
+ ceil(sqrt(x[0].length))].
+ int(num_variables * x[0].length) is
+ considered if num_variable is (0,1
```
-*Caution: "-num_trees" controls the number of trees for each task, not the total number of trees.*
+
+> #### Caution
+> `-num_trees` controls the number of trees for each task, not the total number of trees.
### Parallelize Training
@@ -161,7 +173,8 @@ To parallelize RandomForest training, you can use UNION ALL as follows:
```sql
CREATE TABLE model
-STORED AS SEQUENCEFILE
+ STORED AS ORC tblproperties("orc.compress"="SNAPPY")
+ -- STORED AS SEQUENCEFILE
AS
select
train_randomforest_classifier(features, label, '-trees 25')
@@ -186,51 +199,7 @@ select
from
model;
```
-> [2.81010338879605,0.4970357753626371,23.790369091407698,14.315316390235273] 0.05333333333333334
-
-### Output prediction model by Javascipt
-
-```sql
-CREATE TABLE model_javascript
-STORED AS SEQUENCEFILE
-AS
-select train_randomforest_classifier(features, label, "-output_type js -disable_compression")
-from training;
-
-select model from model_javascript limit 1;
-```
-
-```js
-if(x[3] <= 0.5) {
- 0;
-} else {
- if(x[2] <= 4.5) {
- if(x[3] <= 1.5) {
- if(x[0] <= 4.5) {
- 1;
- } else {
- if(x[0] <= 5.5) {
- 1;
- } else {
- if(x[1] <= 2.5) {
- 1;
- } else {
- 1;
- }
- }
- }
- } else {
- 2;
- }
- } else {
- if(x[3] <= 1.5) {
- 2;
- } else {
- 2;
- }
- }
-}
-```
+> [6.837674865013268,4.1317115752776665,24.331571871930226,25.677497925673062] 0.056666666666666664
# Prediction
@@ -239,18 +208,24 @@ set hivevar:classification=true;
set hive.auto.convert.join=true;
set hive.mapjoin.optimized.hashtable=false;
-create table predicted_vm
+create table predicted
as
SELECT
rowid,
- rf_ensemble(predicted) as predicted
+ -- rf_ensemble(predicted) as predicted
+ -- hivemall v0.5-rc.1 or later
+ rf_ensemble(predicted.value, predicted.posteriori, model_weight) as predicted
+ -- rf_ensemble(predicted.value, predicted.posteriori) as predicted -- avoid OOB accuracy (i.e., model_weight)
FROM (
SELECT
rowid,
-- hivemall v0.4.1-alpha.2 and before
-- tree_predict(p.model, t.features, ${classification}) as predicted
-- hivemall v0.4.1 and later
- tree_predict(p.model_id, p.model_type, p.pred_model, t.features, ${classification}) as predicted
+ -- tree_predict(p.model_id, p.model_type, p.pred_model, t.features, ${classification}) as predicted
+ -- hivemall v0.5-rc.1 or later
+ p.model_weight,
+ tree_predict(p.model_id, p.model, t.features, ${classification}) as predicted
FROM
model p
LEFT OUTER JOIN -- CROSS JOIN
@@ -260,7 +235,6 @@ group by
rowid
;
```
-_Note: Javascript outputs can be evaluated by `js_tree_predict`._
### Parallelize Prediction
@@ -272,20 +246,29 @@ set hive.auto.convert.join=true;
SET hive.mapjoin.optimized.hashtable=false;
SET mapred.reduce.tasks=8;
-create table predicted_vm
+create table predicted
as
SELECT
rowid,
- rf_ensemble(predicted) as predicted
+ -- rf_ensemble(predicted) as predicted
+ -- hivemall v0.5-rc.1 or later
+ rf_ensemble(predicted.value, predicted.posteriori, model_weight) as predicted
+ -- rf_ensemble(predicted.value, predicted.posteriori) as predicted -- avoid OOB accuracy (i.e., model_weight)
FROM (
SELECT
t.rowid,
-- hivemall v0.4.1-alpha.2 and before
-- tree_predict(p.pred_model, t.features, ${classification}) as predicted
-- hivemall v0.4.1 and later
- tree_predict(p.model_id, p.model_type, p.pred_model, t.features, ${classification}) as predicted
+ -- tree_predict(p.model_id, p.model_type, p.pred_model, t.features, ${classification}) as predicted
+ -- hivemall v0.5-rc.1 or later
+ p.model_weight,
+ tree_predict(p.model_id, p.model, t.features, ${classification}) as predicted
FROM (
- SELECT model_id, model_type, pred_model
+ SELECT
+ -- model_id, model_type, pred_model
+ -- hivemall v0.5-rc.1 or later
+ model_id, model_weight, model
FROM model
DISTRIBUTE BY rand(1)
) p
@@ -300,8 +283,10 @@ group by
```sql
select count(1) from training;
+```
> 150
+```sql
set hivevar:total_cnt=150;
WITH t1 as (
@@ -310,7 +295,7 @@ SELECT
t.label as actual,
p.predicted.label as predicted
FROM
- predicted_vm p
+ predicted p
LEFT OUTER JOIN training t ON (t.rowid = p.rowid)
)
SELECT
@@ -321,4 +306,76 @@ WHERE
actual = predicted
;
```
-> 0.9533333333333334
+> 0.98
+
+# Graphviz export
+
+> #### Note
+> `tree_export` feature is supported from Hivemall v0.5-rc.1 or later.
+> Better to limit tree depth on training by `-depth` option to plot a Decision Tree.
+
+Hivemall provides `tree_export` to export a decision tree into [Graphviz](http://www.graphviz.org/) or human-readable Javascript format. You can find the usage by issuing the following query:
+
+```
+> select tree_export("","-help");
+
+usage: tree_export(string model, const string options, optional
+ array<string> featureNames=null, optional array<string>
+ classNames=null) - exports a Decision Tree model as javascript/dot]
+ [-help] [-output_name <arg>] [-r] [-t <arg>]
+ -help Show function help
+ -output_name,--outputName <arg> output name [default: predicted]
+ -r,--regression Is regression tree or not
+ -t,--type <arg> Type of output [default: js,
+ javascript/js, graphvis/dot
+```
+
+```sql
+CREATE TABLE model_exported 
+  STORED AS ORC tblproperties("orc.compress"="SNAPPY")
+AS
+select
+  model_id,
+  tree_export(model, "-type javascript", array('sepal_length','sepal_width','petal_length','petal_width'), array('Setosa','Versicolour','Virginica')) as js,
+  tree_export(model, "-type graphvis", array('sepal_length','sepal_width','petal_length','petal_width'), array('Setosa','Versicolour','Virginica')) as dot
+from
+  model
+-- limit 1
+;
+```
+
+```
+digraph Tree {
+ node [shape=box, style="filled, rounded", color="black", fontname=helvetica];
+ edge [fontname=helvetica];
+ 0 [label=<petal_length &le; 2.599999964237213>, fillcolor="#00000000"];
+ 1 [label=<predicted = Setosa>, fillcolor="0.0000,1.000,1.000", shape=ellipse];
+ 0 -> 1 [labeldistance=2.5, labelangle=45, headlabel="True"];
+ 2 [label=<petal_length &le; 4.950000047683716>, fillcolor="#00000000"];
+ 0 -> 2 [labeldistance=2.5, labelangle=-45, headlabel="False"];
+ 3 [label=<petal_width &le; 1.6500000357627869>, fillcolor="#00000000"];
+ 2 -> 3;
+ 4 [label=<predicted = Versicolour>, fillcolor="0.3333,1.000,1.000", shape=ellipse];
+ 3 -> 4;
+ 5 [label=<sepal_width &le; 3.100000023841858>, fillcolor="#00000000"];
+ 3 -> 5;
+ 6 [label=<predicted = Virginica>, fillcolor="0.6667,1.000,1.000", shape=ellipse];
+ 5 -> 6;
+ 7 [label=<predicted = Versicolour>, fillcolor="0.3333,1.000,1.000", shape=ellipse];
+ 5 -> 7;
+ 8 [label=<petal_width &le; 1.75>, fillcolor="#00000000"];
+ 2 -> 8;
+ 9 [label=<petal_length &le; 5.299999952316284>, fillcolor="#00000000"];
+ 8 -> 9;
+ 10 [label=<predicted = Versicolour>, fillcolor="0.3333,1.000,1.000", shape=ellipse];
+ 9 -> 10;
+ 11 [label=<predicted = Virginica>, fillcolor="0.6667,1.000,1.000", shape=ellipse];
+ 9 -> 11;
+ 12 [label=<predicted = Virginica>, fillcolor="0.6667,1.000,1.000", shape=ellipse];
+ 8 -> 12;
+}
+```
+
+<img src="../resources/images/iris.png" alt="Iris Graphviz output"/>
+
+You can draw a graph by `dot -Tpng iris.dot -o iris.png` or using [Viz.js](http://viz-js.com/).
\ No newline at end of file