You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@hivemall.apache.org by my...@apache.org on 2019/11/26 06:39:40 UTC
[incubator-hivemall] branch master updated: Minor refactoring
This is an automated email from the ASF dual-hosted git repository.
myui pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-hivemall.git
The following commit(s) were added to refs/heads/master by this push:
new 7c01f49 Minor refactoring
7c01f49 is described below
commit 7c01f490c8d135101ce8e5d48c8573d70ba3c159
Author: Makoto Yui <my...@apache.org>
AuthorDate: Tue Nov 26 15:39:30 2019 +0900
Minor refactoring
---
.../main/java/hivemall/GeneralLearnerBaseUDTF.java | 4 +--
.../classifier/BinaryOnlineClassifierUDTF.java | 6 ++--
.../classifier/ConfidenceWeightedUDTF.java | 4 +--
.../factorization/fm/FactorizationMachineUDTF.java | 2 +-
.../mf/BPRMatrixFactorizationUDTF.java | 2 +-
.../mf/OnlineMatrixFactorizationUDTF.java | 2 +-
.../GradientTreeBoostingClassifierUDTF.java | 4 +--
.../classification/RandomForestClassifierUDTF.java | 6 ++--
.../regression/RandomForestRegressionUDTF.java | 6 ++--
.../smile/tools/RandomForestEnsembleUDAF.java | 2 +-
.../java/hivemall/smile/tools/TreeExportUDF.java | 10 +++----
.../java/hivemall/smile/tools/TreePredictUDF.java | 6 ++--
.../hivemall/smile/tools/TreePredictUDFv1.java | 10 +++----
.../main/java/hivemall/utils/hadoop/HiveUtils.java | 35 ++++++++++++++++++++++
14 files changed, 67 insertions(+), 32 deletions(-)
diff --git a/core/src/main/java/hivemall/GeneralLearnerBaseUDTF.java b/core/src/main/java/hivemall/GeneralLearnerBaseUDTF.java
index a251840..af7648e 100644
--- a/core/src/main/java/hivemall/GeneralLearnerBaseUDTF.java
+++ b/core/src/main/java/hivemall/GeneralLearnerBaseUDTF.java
@@ -144,9 +144,9 @@ public abstract class GeneralLearnerBaseUDTF extends LearnerBaseUDTF {
showHelp(
"_FUNC_ takes two or three arguments: List<Int|BigInt|Text> features, float target [, constant string options]");
}
- this.featureListOI = HiveUtils.asListOI(argOIs[0]);
+ this.featureListOI = HiveUtils.asListOI(argOIs, 0);
this.featureType = getFeatureType(featureListOI);
- this.targetOI = HiveUtils.asDoubleCompatibleOI(argOIs[1]);
+ this.targetOI = HiveUtils.asDoubleCompatibleOI(argOIs, 1);
processOptions(argOIs);
diff --git a/core/src/main/java/hivemall/classifier/BinaryOnlineClassifierUDTF.java b/core/src/main/java/hivemall/classifier/BinaryOnlineClassifierUDTF.java
index ffe3186..aad4b6d 100644
--- a/core/src/main/java/hivemall/classifier/BinaryOnlineClassifierUDTF.java
+++ b/core/src/main/java/hivemall/classifier/BinaryOnlineClassifierUDTF.java
@@ -76,11 +76,11 @@ public abstract class BinaryOnlineClassifierUDTF extends LearnerBaseUDTF {
@Override
public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException {
if (argOIs.length < 2) {
- throw new UDFArgumentException(getClass().getSimpleName()
- + " takes 2 arguments: List<Int|BigInt|Text> features, int label [, constant string options]");
+ showHelp(
+ "_FUNC_ takes 2 arguments: List<Int|BigInt|Text> features, int label [, constant string options]");
}
PrimitiveObjectInspector featureInputOI = processFeaturesOI(argOIs[0]);
- this.labelOI = HiveUtils.asIntCompatibleOI(argOIs[1]);
+ this.labelOI = HiveUtils.asIntCompatibleOI(argOIs, 1);
processOptions(argOIs);
diff --git a/core/src/main/java/hivemall/classifier/ConfidenceWeightedUDTF.java b/core/src/main/java/hivemall/classifier/ConfidenceWeightedUDTF.java
index 98c35c5..339cfb6 100644
--- a/core/src/main/java/hivemall/classifier/ConfidenceWeightedUDTF.java
+++ b/core/src/main/java/hivemall/classifier/ConfidenceWeightedUDTF.java
@@ -56,8 +56,8 @@ public final class ConfidenceWeightedUDTF extends BinaryOnlineClassifierUDTF {
public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException {
final int numArgs = argOIs.length;
if (numArgs != 2 && numArgs != 3) {
- throw new UDFArgumentException(
- "ConfidenceWeightedUDTF takes 2 or 3 arguments: List<String|Int|BitInt> features, Int label [, constant String options]");
+ showHelp(
+ "_FUNC_ takes 2 or 3 arguments: List<String|Int|BitInt> features, Int label [, constant String options]");
}
return super.initialize(argOIs);
diff --git a/core/src/main/java/hivemall/factorization/fm/FactorizationMachineUDTF.java b/core/src/main/java/hivemall/factorization/fm/FactorizationMachineUDTF.java
index a837239..f993820 100644
--- a/core/src/main/java/hivemall/factorization/fm/FactorizationMachineUDTF.java
+++ b/core/src/main/java/hivemall/factorization/fm/FactorizationMachineUDTF.java
@@ -202,7 +202,7 @@ public class FactorizationMachineUDTF extends UDTFWithOptions {
CommandLine cl = null;
if (argOIs.length >= 3) {
- String rawArgs = HiveUtils.getConstString(argOIs[2]);
+ String rawArgs = HiveUtils.getConstString(argOIs, 2);
cl = parseOptions(rawArgs);
params.processOptions(cl);
}
diff --git a/core/src/main/java/hivemall/factorization/mf/BPRMatrixFactorizationUDTF.java b/core/src/main/java/hivemall/factorization/mf/BPRMatrixFactorizationUDTF.java
index 3d594c8..3a987a8 100644
--- a/core/src/main/java/hivemall/factorization/mf/BPRMatrixFactorizationUDTF.java
+++ b/core/src/main/java/hivemall/factorization/mf/BPRMatrixFactorizationUDTF.java
@@ -195,7 +195,7 @@ public final class BPRMatrixFactorizationUDTF extends UDTFWithOptions implements
double convergenceRate = 0.005d;
if (argOIs.length >= 4) {
- String rawArgs = HiveUtils.getConstString(argOIs[3]);
+ String rawArgs = HiveUtils.getConstString(argOIs, 3);
cl = parseOptions(rawArgs);
if (cl.hasOption("factor")) {
diff --git a/core/src/main/java/hivemall/factorization/mf/OnlineMatrixFactorizationUDTF.java b/core/src/main/java/hivemall/factorization/mf/OnlineMatrixFactorizationUDTF.java
index 3b850a1..7f017e5 100644
--- a/core/src/main/java/hivemall/factorization/mf/OnlineMatrixFactorizationUDTF.java
+++ b/core/src/main/java/hivemall/factorization/mf/OnlineMatrixFactorizationUDTF.java
@@ -141,7 +141,7 @@ public abstract class OnlineMatrixFactorizationUDTF extends UDTFWithOptions
double convergenceRate = 0.005d;
if (argOIs.length >= 4) {
- String rawArgs = HiveUtils.getConstString(argOIs[3]);
+ String rawArgs = HiveUtils.getConstString(argOIs, 3);
cl = parseOptions(rawArgs);
if (cl.hasOption("factors")) {
this.factor = Primitives.parseInt(cl.getOptionValue("factors"), 10);
diff --git a/core/src/main/java/hivemall/smile/classification/GradientTreeBoostingClassifierUDTF.java b/core/src/main/java/hivemall/smile/classification/GradientTreeBoostingClassifierUDTF.java
index 7b1a036..185a2dd 100644
--- a/core/src/main/java/hivemall/smile/classification/GradientTreeBoostingClassifierUDTF.java
+++ b/core/src/main/java/hivemall/smile/classification/GradientTreeBoostingClassifierUDTF.java
@@ -209,7 +209,7 @@ public final class GradientTreeBoostingClassifierUDTF extends UDTFWithOptions {
+ argOIs.length);
}
- ListObjectInspector listOI = HiveUtils.asListOI(argOIs[0]);
+ ListObjectInspector listOI = HiveUtils.asListOI(argOIs, 0);
ObjectInspector elemOI = listOI.getListElementObjectInspector();
this.featureListOI = listOI;
if (HiveUtils.isNumberOI(elemOI)) {
@@ -225,7 +225,7 @@ public final class GradientTreeBoostingClassifierUDTF extends UDTFWithOptions {
"_FUNC_ takes double[] or string[] for the first argument: "
+ listOI.getTypeName());
}
- this.labelOI = HiveUtils.asIntCompatibleOI(argOIs[1]);
+ this.labelOI = HiveUtils.asIntCompatibleOI(argOIs, 1);
processOptions(argOIs);
diff --git a/core/src/main/java/hivemall/smile/classification/RandomForestClassifierUDTF.java b/core/src/main/java/hivemall/smile/classification/RandomForestClassifierUDTF.java
index 73bf691..fda4306 100644
--- a/core/src/main/java/hivemall/smile/classification/RandomForestClassifierUDTF.java
+++ b/core/src/main/java/hivemall/smile/classification/RandomForestClassifierUDTF.java
@@ -169,7 +169,7 @@ public final class RandomForestClassifierUDTF extends UDTFWithOptions {
CommandLine cl = null;
if (argOIs.length >= 3) {
- String rawArgs = HiveUtils.getConstString(argOIs[2]);
+ String rawArgs = HiveUtils.getConstString(argOIs, 2);
cl = parseOptions(rawArgs);
trees = Primitives.parseInt(cl.getOptionValue("num_trees"), trees);
@@ -242,7 +242,7 @@ public final class RandomForestClassifierUDTF extends UDTFWithOptions {
+ argOIs.length);
}
- ListObjectInspector listOI = HiveUtils.asListOI(argOIs[0]);
+ ListObjectInspector listOI = HiveUtils.asListOI(argOIs, 0);
ObjectInspector elemOI = listOI.getListElementObjectInspector();
this.featureListOI = listOI;
if (HiveUtils.isNumberOI(elemOI)) {
@@ -258,7 +258,7 @@ public final class RandomForestClassifierUDTF extends UDTFWithOptions {
"_FUNC_ takes double[] or string[] for the first argument: "
+ listOI.getTypeName());
}
- this.labelOI = HiveUtils.asIntCompatibleOI(argOIs[1]);
+ this.labelOI = HiveUtils.asIntCompatibleOI(argOIs, 1);
processOptions(argOIs);
diff --git a/core/src/main/java/hivemall/smile/regression/RandomForestRegressionUDTF.java b/core/src/main/java/hivemall/smile/regression/RandomForestRegressionUDTF.java
index 75af989..a8b24ea 100644
--- a/core/src/main/java/hivemall/smile/regression/RandomForestRegressionUDTF.java
+++ b/core/src/main/java/hivemall/smile/regression/RandomForestRegressionUDTF.java
@@ -154,7 +154,7 @@ public final class RandomForestRegressionUDTF extends UDTFWithOptions {
CommandLine cl = null;
if (argOIs.length >= 3) {
- String rawArgs = HiveUtils.getConstString(argOIs[2]);
+ String rawArgs = HiveUtils.getConstString(argOIs, 2);
cl = parseOptions(rawArgs);
trees = Primitives.parseInt(cl.getOptionValue("num_trees"), trees);
@@ -202,7 +202,7 @@ public final class RandomForestRegressionUDTF extends UDTFWithOptions {
+ argOIs.length);
}
- ListObjectInspector listOI = HiveUtils.asListOI(argOIs[0]);
+ ListObjectInspector listOI = HiveUtils.asListOI(argOIs, 0);
ObjectInspector elemOI = listOI.getListElementObjectInspector();
this.featureListOI = listOI;
if (HiveUtils.isNumberOI(elemOI)) {
@@ -218,7 +218,7 @@ public final class RandomForestRegressionUDTF extends UDTFWithOptions {
"_FUNC_ takes double[] or string[] for the first argument: "
+ listOI.getTypeName());
}
- this.targetOI = HiveUtils.asDoubleCompatibleOI(argOIs[1]);
+ this.targetOI = HiveUtils.asDoubleCompatibleOI(argOIs, 1);
processOptions(argOIs);
diff --git a/core/src/main/java/hivemall/smile/tools/RandomForestEnsembleUDAF.java b/core/src/main/java/hivemall/smile/tools/RandomForestEnsembleUDAF.java
index a63f5c1..28b62da 100644
--- a/core/src/main/java/hivemall/smile/tools/RandomForestEnsembleUDAF.java
+++ b/core/src/main/java/hivemall/smile/tools/RandomForestEnsembleUDAF.java
@@ -122,7 +122,7 @@ public final class RandomForestEnsembleUDAF extends AbstractGenericUDAFResolver
// initialize input
if (mode == Mode.PARTIAL1 || mode == Mode.COMPLETE) {// from original data
- this.yhatOI = HiveUtils.asIntegerOI(argOIs[0]);
+ this.yhatOI = HiveUtils.asIntegerOI(argOIs, 0);
} else {// from partial aggregation
this.internalMergeOI = (StandardMapObjectInspector) argOIs[0];
this.keyOI = HiveUtils.asIntOI(internalMergeOI.getMapKeyObjectInspector());
diff --git a/core/src/main/java/hivemall/smile/tools/TreeExportUDF.java b/core/src/main/java/hivemall/smile/tools/TreeExportUDF.java
index 86389d1..f696002 100644
--- a/core/src/main/java/hivemall/smile/tools/TreeExportUDF.java
+++ b/core/src/main/java/hivemall/smile/tools/TreeExportUDF.java
@@ -83,22 +83,22 @@ public final class TreeExportUDF extends UDFWithOptions {
public ObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException {
final int argLen = argOIs.length;
if (argLen < 2 || argLen > 4) {
- throw new UDFArgumentException("_FUNC_ takes 2~4 arguments: " + argLen);
+ showHelp("tree_export UDF takes 2~4 arguments: " + argLen);
}
- this.modelOI = HiveUtils.asStringOI(argOIs[0]);
+ this.modelOI = HiveUtils.asStringOI(argOIs, 0);
- String options = HiveUtils.getConstString(argOIs[1]);
+ String options = HiveUtils.getConstString(argOIs, 1);
processOptions(options);
if (argLen >= 3) {
- this.featureNamesOI = HiveUtils.asListOI(argOIs[2]);
+ this.featureNamesOI = HiveUtils.asListOI(argOIs, 2);
if (!HiveUtils.isStringOI(featureNamesOI.getListElementObjectInspector())) {
throw new UDFArgumentException("_FUNC_ expected array<string> for featureNames: "
+ featureNamesOI.getTypeName());
}
if (argLen == 4) {
- this.classNamesOI = HiveUtils.asListOI(argOIs[3]);
+ this.classNamesOI = HiveUtils.asListOI(argOIs, 3);
if (!HiveUtils.isStringOI(classNamesOI.getListElementObjectInspector())) {
throw new UDFArgumentException("_FUNC_ expected array<string> for classNames: "
+ classNamesOI.getTypeName());
diff --git a/core/src/main/java/hivemall/smile/tools/TreePredictUDF.java b/core/src/main/java/hivemall/smile/tools/TreePredictUDF.java
index 08360f1..a909c75 100644
--- a/core/src/main/java/hivemall/smile/tools/TreePredictUDF.java
+++ b/core/src/main/java/hivemall/smile/tools/TreePredictUDF.java
@@ -92,11 +92,11 @@ public final class TreePredictUDF extends UDFWithOptions {
@Override
public ObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException {
if (argOIs.length != 3 && argOIs.length != 4) {
- throw new UDFArgumentException("tree_predict takes 3 or 4 arguments");
+ showHelp("tree_predict takes 3 or 4 arguments");
}
- this.modelOI = HiveUtils.asStringOI(argOIs[1]);
- ListObjectInspector listOI = HiveUtils.asListOI(argOIs[2]);
+ this.modelOI = HiveUtils.asStringOI(argOIs, 1);
+ ListObjectInspector listOI = HiveUtils.asListOI(argOIs, 2);
this.featureListOI = listOI;
ObjectInspector elemOI = listOI.getListElementObjectInspector();
if (HiveUtils.isNumberOI(elemOI)) {
diff --git a/core/src/main/java/hivemall/smile/tools/TreePredictUDFv1.java b/core/src/main/java/hivemall/smile/tools/TreePredictUDFv1.java
index 64d5c3b..a357dd6 100644
--- a/core/src/main/java/hivemall/smile/tools/TreePredictUDFv1.java
+++ b/core/src/main/java/hivemall/smile/tools/TreePredictUDFv1.java
@@ -98,19 +98,19 @@ public final class TreePredictUDFv1 extends GenericUDF {
@Override
public ObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException {
if (argOIs.length != 4 && argOIs.length != 5) {
- throw new UDFArgumentException("_FUNC_ takes 4 or 5 arguments");
+ throw new UDFArgumentException("tree_predict_v1 takes 4 or 5 arguments");
}
- this.modelTypeOI = HiveUtils.asIntegerOI(argOIs[1]);
- this.stringOI = HiveUtils.asStringOI(argOIs[2]);
- ListObjectInspector listOI = HiveUtils.asListOI(argOIs[3]);
+ this.modelTypeOI = HiveUtils.asIntegerOI(argOIs, 1);
+ this.stringOI = HiveUtils.asStringOI(argOIs, 2);
+ ListObjectInspector listOI = HiveUtils.asListOI(argOIs, 3);
this.featureListOI = listOI;
ObjectInspector elemOI = listOI.getListElementObjectInspector();
this.featureElemOI = HiveUtils.asDoubleCompatibleOI(elemOI);
boolean classification = false;
if (argOIs.length == 5) {
- classification = HiveUtils.getConstBoolean(argOIs[4]);
+ classification = HiveUtils.getConstBoolean(argOIs, 4);
}
this.classification = classification;
diff --git a/core/src/main/java/hivemall/utils/hadoop/HiveUtils.java b/core/src/main/java/hivemall/utils/hadoop/HiveUtils.java
index 5047720..26c0e0b 100644
--- a/core/src/main/java/hivemall/utils/hadoop/HiveUtils.java
+++ b/core/src/main/java/hivemall/utils/hadoop/HiveUtils.java
@@ -685,6 +685,18 @@ public final class HiveUtils {
return v.get();
}
+
+ public static boolean getConstBoolean(@Nonnull final ObjectInspector[] argOIs,
+ final int argIndex) throws UDFArgumentException {
+ final ObjectInspector oi = getObjectInspector(argOIs, argIndex);
+ if (!isBooleanOI(oi)) {
+ throw new UDFArgumentException("argument must be a Boolean value: "
+ + TypeInfoUtils.getTypeInfoFromObjectInspector(oi));
+ }
+ BooleanWritable v = getConstValue(oi);
+ return v.get();
+ }
+
public static int getConstInt(@Nonnull final ObjectInspector oi) throws UDFArgumentException {
if (!isIntOI(oi)) {
throw new UDFArgumentException("argument must be a Int value: "
@@ -1195,6 +1207,29 @@ public final class HiveUtils {
}
@Nonnull
+ public static PrimitiveObjectInspector asIntegerOI(@Nonnull final ObjectInspector[] argOIs,
+ final int argIndex) throws UDFArgumentException {
+ final ObjectInspector argOI = getObjectInspector(argOIs, argIndex);
+ if (argOI.getCategory() != Category.PRIMITIVE) {
+ throw new UDFArgumentTypeException(argIndex,
+ "Only primitive type arguments are accepted but " + argOI.getTypeName()
+ + " is passed.");
+ }
+ final PrimitiveObjectInspector oi = (PrimitiveObjectInspector) argOI;
+ switch (oi.getPrimitiveCategory()) {
+ case INT:
+ case SHORT:
+ case LONG:
+ case BYTE:
+ break;
+ default:
+ throw new UDFArgumentTypeException(argIndex,
+ "Unexpected type '" + argOI.getTypeName() + "' is passed.");
+ }
+ return oi;
+ }
+
+ @Nonnull
public static PrimitiveObjectInspector asDoubleCompatibleOI(
@Nonnull final ObjectInspector argOI) throws UDFArgumentTypeException {
if (argOI.getCategory() != Category.PRIMITIVE) {