You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@hivemall.apache.org by my...@apache.org on 2019/11/26 06:39:40 UTC

[incubator-hivemall] branch master updated: Minor refactoring

This is an automated email from the ASF dual-hosted git repository.

myui pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-hivemall.git


The following commit(s) were added to refs/heads/master by this push:
     new 7c01f49  Minor refactoring
7c01f49 is described below

commit 7c01f490c8d135101ce8e5d48c8573d70ba3c159
Author: Makoto Yui <my...@apache.org>
AuthorDate: Tue Nov 26 15:39:30 2019 +0900

    Minor refactoring
---
 .../main/java/hivemall/GeneralLearnerBaseUDTF.java |  4 +--
 .../classifier/BinaryOnlineClassifierUDTF.java     |  6 ++--
 .../classifier/ConfidenceWeightedUDTF.java         |  4 +--
 .../factorization/fm/FactorizationMachineUDTF.java |  2 +-
 .../mf/BPRMatrixFactorizationUDTF.java             |  2 +-
 .../mf/OnlineMatrixFactorizationUDTF.java          |  2 +-
 .../GradientTreeBoostingClassifierUDTF.java        |  4 +--
 .../classification/RandomForestClassifierUDTF.java |  6 ++--
 .../regression/RandomForestRegressionUDTF.java     |  6 ++--
 .../smile/tools/RandomForestEnsembleUDAF.java      |  2 +-
 .../java/hivemall/smile/tools/TreeExportUDF.java   | 10 +++----
 .../java/hivemall/smile/tools/TreePredictUDF.java  |  6 ++--
 .../hivemall/smile/tools/TreePredictUDFv1.java     | 10 +++----
 .../main/java/hivemall/utils/hadoop/HiveUtils.java | 35 ++++++++++++++++++++++
 14 files changed, 67 insertions(+), 32 deletions(-)

diff --git a/core/src/main/java/hivemall/GeneralLearnerBaseUDTF.java b/core/src/main/java/hivemall/GeneralLearnerBaseUDTF.java
index a251840..af7648e 100644
--- a/core/src/main/java/hivemall/GeneralLearnerBaseUDTF.java
+++ b/core/src/main/java/hivemall/GeneralLearnerBaseUDTF.java
@@ -144,9 +144,9 @@ public abstract class GeneralLearnerBaseUDTF extends LearnerBaseUDTF {
             showHelp(
                 "_FUNC_ takes two or three arguments: List<Int|BigInt|Text> features, float target [, constant string options]");
         }
-        this.featureListOI = HiveUtils.asListOI(argOIs[0]);
+        this.featureListOI = HiveUtils.asListOI(argOIs, 0);
         this.featureType = getFeatureType(featureListOI);
-        this.targetOI = HiveUtils.asDoubleCompatibleOI(argOIs[1]);
+        this.targetOI = HiveUtils.asDoubleCompatibleOI(argOIs, 1);
 
         processOptions(argOIs);
 
diff --git a/core/src/main/java/hivemall/classifier/BinaryOnlineClassifierUDTF.java b/core/src/main/java/hivemall/classifier/BinaryOnlineClassifierUDTF.java
index ffe3186..aad4b6d 100644
--- a/core/src/main/java/hivemall/classifier/BinaryOnlineClassifierUDTF.java
+++ b/core/src/main/java/hivemall/classifier/BinaryOnlineClassifierUDTF.java
@@ -76,11 +76,11 @@ public abstract class BinaryOnlineClassifierUDTF extends LearnerBaseUDTF {
     @Override
     public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException {
         if (argOIs.length < 2) {
-            throw new UDFArgumentException(getClass().getSimpleName()
-                    + " takes 2 arguments: List<Int|BigInt|Text> features, int label [, constant string options]");
+            showHelp(
+                "_FUNC_ takes 2 arguments: List<Int|BigInt|Text> features, int label [, constant string options]");
         }
         PrimitiveObjectInspector featureInputOI = processFeaturesOI(argOIs[0]);
-        this.labelOI = HiveUtils.asIntCompatibleOI(argOIs[1]);
+        this.labelOI = HiveUtils.asIntCompatibleOI(argOIs, 1);
 
         processOptions(argOIs);
 
diff --git a/core/src/main/java/hivemall/classifier/ConfidenceWeightedUDTF.java b/core/src/main/java/hivemall/classifier/ConfidenceWeightedUDTF.java
index 98c35c5..339cfb6 100644
--- a/core/src/main/java/hivemall/classifier/ConfidenceWeightedUDTF.java
+++ b/core/src/main/java/hivemall/classifier/ConfidenceWeightedUDTF.java
@@ -56,8 +56,8 @@ public final class ConfidenceWeightedUDTF extends BinaryOnlineClassifierUDTF {
     public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException {
         final int numArgs = argOIs.length;
         if (numArgs != 2 && numArgs != 3) {
-            throw new UDFArgumentException(
-                "ConfidenceWeightedUDTF takes 2 or 3 arguments: List<String|Int|BitInt> features, Int label [, constant String options]");
+            showHelp(
+                "_FUNC_ takes 2 or 3 arguments: List<String|Int|BitInt> features, Int label [, constant String options]");
         }
 
         return super.initialize(argOIs);
diff --git a/core/src/main/java/hivemall/factorization/fm/FactorizationMachineUDTF.java b/core/src/main/java/hivemall/factorization/fm/FactorizationMachineUDTF.java
index a837239..f993820 100644
--- a/core/src/main/java/hivemall/factorization/fm/FactorizationMachineUDTF.java
+++ b/core/src/main/java/hivemall/factorization/fm/FactorizationMachineUDTF.java
@@ -202,7 +202,7 @@ public class FactorizationMachineUDTF extends UDTFWithOptions {
 
         CommandLine cl = null;
         if (argOIs.length >= 3) {
-            String rawArgs = HiveUtils.getConstString(argOIs[2]);
+            String rawArgs = HiveUtils.getConstString(argOIs, 2);
             cl = parseOptions(rawArgs);
             params.processOptions(cl);
         }
diff --git a/core/src/main/java/hivemall/factorization/mf/BPRMatrixFactorizationUDTF.java b/core/src/main/java/hivemall/factorization/mf/BPRMatrixFactorizationUDTF.java
index 3d594c8..3a987a8 100644
--- a/core/src/main/java/hivemall/factorization/mf/BPRMatrixFactorizationUDTF.java
+++ b/core/src/main/java/hivemall/factorization/mf/BPRMatrixFactorizationUDTF.java
@@ -195,7 +195,7 @@ public final class BPRMatrixFactorizationUDTF extends UDTFWithOptions implements
         double convergenceRate = 0.005d;
 
         if (argOIs.length >= 4) {
-            String rawArgs = HiveUtils.getConstString(argOIs[3]);
+            String rawArgs = HiveUtils.getConstString(argOIs, 3);
             cl = parseOptions(rawArgs);
 
             if (cl.hasOption("factor")) {
diff --git a/core/src/main/java/hivemall/factorization/mf/OnlineMatrixFactorizationUDTF.java b/core/src/main/java/hivemall/factorization/mf/OnlineMatrixFactorizationUDTF.java
index 3b850a1..7f017e5 100644
--- a/core/src/main/java/hivemall/factorization/mf/OnlineMatrixFactorizationUDTF.java
+++ b/core/src/main/java/hivemall/factorization/mf/OnlineMatrixFactorizationUDTF.java
@@ -141,7 +141,7 @@ public abstract class OnlineMatrixFactorizationUDTF extends UDTFWithOptions
         double convergenceRate = 0.005d;
 
         if (argOIs.length >= 4) {
-            String rawArgs = HiveUtils.getConstString(argOIs[3]);
+            String rawArgs = HiveUtils.getConstString(argOIs, 3);
             cl = parseOptions(rawArgs);
             if (cl.hasOption("factors")) {
                 this.factor = Primitives.parseInt(cl.getOptionValue("factors"), 10);
diff --git a/core/src/main/java/hivemall/smile/classification/GradientTreeBoostingClassifierUDTF.java b/core/src/main/java/hivemall/smile/classification/GradientTreeBoostingClassifierUDTF.java
index 7b1a036..185a2dd 100644
--- a/core/src/main/java/hivemall/smile/classification/GradientTreeBoostingClassifierUDTF.java
+++ b/core/src/main/java/hivemall/smile/classification/GradientTreeBoostingClassifierUDTF.java
@@ -209,7 +209,7 @@ public final class GradientTreeBoostingClassifierUDTF extends UDTFWithOptions {
                     + argOIs.length);
         }
 
-        ListObjectInspector listOI = HiveUtils.asListOI(argOIs[0]);
+        ListObjectInspector listOI = HiveUtils.asListOI(argOIs, 0);
         ObjectInspector elemOI = listOI.getListElementObjectInspector();
         this.featureListOI = listOI;
         if (HiveUtils.isNumberOI(elemOI)) {
@@ -225,7 +225,7 @@ public final class GradientTreeBoostingClassifierUDTF extends UDTFWithOptions {
                 "_FUNC_ takes double[] or string[] for the first argument: "
                         + listOI.getTypeName());
         }
-        this.labelOI = HiveUtils.asIntCompatibleOI(argOIs[1]);
+        this.labelOI = HiveUtils.asIntCompatibleOI(argOIs, 1);
 
         processOptions(argOIs);
 
diff --git a/core/src/main/java/hivemall/smile/classification/RandomForestClassifierUDTF.java b/core/src/main/java/hivemall/smile/classification/RandomForestClassifierUDTF.java
index 73bf691..fda4306 100644
--- a/core/src/main/java/hivemall/smile/classification/RandomForestClassifierUDTF.java
+++ b/core/src/main/java/hivemall/smile/classification/RandomForestClassifierUDTF.java
@@ -169,7 +169,7 @@ public final class RandomForestClassifierUDTF extends UDTFWithOptions {
 
         CommandLine cl = null;
         if (argOIs.length >= 3) {
-            String rawArgs = HiveUtils.getConstString(argOIs[2]);
+            String rawArgs = HiveUtils.getConstString(argOIs, 2);
             cl = parseOptions(rawArgs);
 
             trees = Primitives.parseInt(cl.getOptionValue("num_trees"), trees);
@@ -242,7 +242,7 @@ public final class RandomForestClassifierUDTF extends UDTFWithOptions {
                         + argOIs.length);
         }
 
-        ListObjectInspector listOI = HiveUtils.asListOI(argOIs[0]);
+        ListObjectInspector listOI = HiveUtils.asListOI(argOIs, 0);
         ObjectInspector elemOI = listOI.getListElementObjectInspector();
         this.featureListOI = listOI;
         if (HiveUtils.isNumberOI(elemOI)) {
@@ -258,7 +258,7 @@ public final class RandomForestClassifierUDTF extends UDTFWithOptions {
                 "_FUNC_ takes double[] or string[] for the first argument: "
                         + listOI.getTypeName());
         }
-        this.labelOI = HiveUtils.asIntCompatibleOI(argOIs[1]);
+        this.labelOI = HiveUtils.asIntCompatibleOI(argOIs, 1);
 
         processOptions(argOIs);
 
diff --git a/core/src/main/java/hivemall/smile/regression/RandomForestRegressionUDTF.java b/core/src/main/java/hivemall/smile/regression/RandomForestRegressionUDTF.java
index 75af989..a8b24ea 100644
--- a/core/src/main/java/hivemall/smile/regression/RandomForestRegressionUDTF.java
+++ b/core/src/main/java/hivemall/smile/regression/RandomForestRegressionUDTF.java
@@ -154,7 +154,7 @@ public final class RandomForestRegressionUDTF extends UDTFWithOptions {
 
         CommandLine cl = null;
         if (argOIs.length >= 3) {
-            String rawArgs = HiveUtils.getConstString(argOIs[2]);
+            String rawArgs = HiveUtils.getConstString(argOIs, 2);
             cl = parseOptions(rawArgs);
 
             trees = Primitives.parseInt(cl.getOptionValue("num_trees"), trees);
@@ -202,7 +202,7 @@ public final class RandomForestRegressionUDTF extends UDTFWithOptions {
                     + argOIs.length);
         }
 
-        ListObjectInspector listOI = HiveUtils.asListOI(argOIs[0]);
+        ListObjectInspector listOI = HiveUtils.asListOI(argOIs, 0);
         ObjectInspector elemOI = listOI.getListElementObjectInspector();
         this.featureListOI = listOI;
         if (HiveUtils.isNumberOI(elemOI)) {
@@ -218,7 +218,7 @@ public final class RandomForestRegressionUDTF extends UDTFWithOptions {
                 "_FUNC_ takes double[] or string[] for the first argument: "
                         + listOI.getTypeName());
         }
-        this.targetOI = HiveUtils.asDoubleCompatibleOI(argOIs[1]);
+        this.targetOI = HiveUtils.asDoubleCompatibleOI(argOIs, 1);
 
         processOptions(argOIs);
 
diff --git a/core/src/main/java/hivemall/smile/tools/RandomForestEnsembleUDAF.java b/core/src/main/java/hivemall/smile/tools/RandomForestEnsembleUDAF.java
index a63f5c1..28b62da 100644
--- a/core/src/main/java/hivemall/smile/tools/RandomForestEnsembleUDAF.java
+++ b/core/src/main/java/hivemall/smile/tools/RandomForestEnsembleUDAF.java
@@ -122,7 +122,7 @@ public final class RandomForestEnsembleUDAF extends AbstractGenericUDAFResolver
 
             // initialize input
             if (mode == Mode.PARTIAL1 || mode == Mode.COMPLETE) {// from original data
-                this.yhatOI = HiveUtils.asIntegerOI(argOIs[0]);
+                this.yhatOI = HiveUtils.asIntegerOI(argOIs, 0);
             } else {// from partial aggregation
                 this.internalMergeOI = (StandardMapObjectInspector) argOIs[0];
                 this.keyOI = HiveUtils.asIntOI(internalMergeOI.getMapKeyObjectInspector());
diff --git a/core/src/main/java/hivemall/smile/tools/TreeExportUDF.java b/core/src/main/java/hivemall/smile/tools/TreeExportUDF.java
index 86389d1..f696002 100644
--- a/core/src/main/java/hivemall/smile/tools/TreeExportUDF.java
+++ b/core/src/main/java/hivemall/smile/tools/TreeExportUDF.java
@@ -83,22 +83,22 @@ public final class TreeExportUDF extends UDFWithOptions {
     public ObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException {
         final int argLen = argOIs.length;
         if (argLen < 2 || argLen > 4) {
-            throw new UDFArgumentException("_FUNC_ takes 2~4 arguments: " + argLen);
+            showHelp("tree_export UDF takes 2~4 arguments: " + argLen);
         }
 
-        this.modelOI = HiveUtils.asStringOI(argOIs[0]);
+        this.modelOI = HiveUtils.asStringOI(argOIs, 0);
 
-        String options = HiveUtils.getConstString(argOIs[1]);
+        String options = HiveUtils.getConstString(argOIs, 1);
         processOptions(options);
 
         if (argLen >= 3) {
-            this.featureNamesOI = HiveUtils.asListOI(argOIs[2]);
+            this.featureNamesOI = HiveUtils.asListOI(argOIs, 2);
             if (!HiveUtils.isStringOI(featureNamesOI.getListElementObjectInspector())) {
                 throw new UDFArgumentException("_FUNC_ expected array<string> for featureNames: "
                         + featureNamesOI.getTypeName());
             }
             if (argLen == 4) {
-                this.classNamesOI = HiveUtils.asListOI(argOIs[3]);
+                this.classNamesOI = HiveUtils.asListOI(argOIs, 3);
                 if (!HiveUtils.isStringOI(classNamesOI.getListElementObjectInspector())) {
                     throw new UDFArgumentException("_FUNC_ expected array<string> for classNames: "
                             + classNamesOI.getTypeName());
diff --git a/core/src/main/java/hivemall/smile/tools/TreePredictUDF.java b/core/src/main/java/hivemall/smile/tools/TreePredictUDF.java
index 08360f1..a909c75 100644
--- a/core/src/main/java/hivemall/smile/tools/TreePredictUDF.java
+++ b/core/src/main/java/hivemall/smile/tools/TreePredictUDF.java
@@ -92,11 +92,11 @@ public final class TreePredictUDF extends UDFWithOptions {
     @Override
     public ObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException {
         if (argOIs.length != 3 && argOIs.length != 4) {
-            throw new UDFArgumentException("tree_predict takes 3 or 4 arguments");
+            showHelp("tree_predict takes 3 or 4 arguments");
         }
 
-        this.modelOI = HiveUtils.asStringOI(argOIs[1]);
-        ListObjectInspector listOI = HiveUtils.asListOI(argOIs[2]);
+        this.modelOI = HiveUtils.asStringOI(argOIs, 1);
+        ListObjectInspector listOI = HiveUtils.asListOI(argOIs, 2);
         this.featureListOI = listOI;
         ObjectInspector elemOI = listOI.getListElementObjectInspector();
         if (HiveUtils.isNumberOI(elemOI)) {
diff --git a/core/src/main/java/hivemall/smile/tools/TreePredictUDFv1.java b/core/src/main/java/hivemall/smile/tools/TreePredictUDFv1.java
index 64d5c3b..a357dd6 100644
--- a/core/src/main/java/hivemall/smile/tools/TreePredictUDFv1.java
+++ b/core/src/main/java/hivemall/smile/tools/TreePredictUDFv1.java
@@ -98,19 +98,19 @@ public final class TreePredictUDFv1 extends GenericUDF {
     @Override
     public ObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException {
         if (argOIs.length != 4 && argOIs.length != 5) {
-            throw new UDFArgumentException("_FUNC_ takes 4 or 5 arguments");
+            throw new UDFArgumentException("tree_predict_v1 takes 4 or 5 arguments");
         }
 
-        this.modelTypeOI = HiveUtils.asIntegerOI(argOIs[1]);
-        this.stringOI = HiveUtils.asStringOI(argOIs[2]);
-        ListObjectInspector listOI = HiveUtils.asListOI(argOIs[3]);
+        this.modelTypeOI = HiveUtils.asIntegerOI(argOIs, 1);
+        this.stringOI = HiveUtils.asStringOI(argOIs, 2);
+        ListObjectInspector listOI = HiveUtils.asListOI(argOIs, 3);
         this.featureListOI = listOI;
         ObjectInspector elemOI = listOI.getListElementObjectInspector();
         this.featureElemOI = HiveUtils.asDoubleCompatibleOI(elemOI);
 
         boolean classification = false;
         if (argOIs.length == 5) {
-            classification = HiveUtils.getConstBoolean(argOIs[4]);
+            classification = HiveUtils.getConstBoolean(argOIs, 4);
         }
         this.classification = classification;
 
diff --git a/core/src/main/java/hivemall/utils/hadoop/HiveUtils.java b/core/src/main/java/hivemall/utils/hadoop/HiveUtils.java
index 5047720..26c0e0b 100644
--- a/core/src/main/java/hivemall/utils/hadoop/HiveUtils.java
+++ b/core/src/main/java/hivemall/utils/hadoop/HiveUtils.java
@@ -685,6 +685,18 @@ public final class HiveUtils {
         return v.get();
     }
 
+
+    public static boolean getConstBoolean(@Nonnull final ObjectInspector[] argOIs,
+            final int argIndex) throws UDFArgumentException {
+        final ObjectInspector oi = getObjectInspector(argOIs, argIndex);
+        if (!isBooleanOI(oi)) {
+            throw new UDFArgumentException("argument must be a Boolean value: "
+                    + TypeInfoUtils.getTypeInfoFromObjectInspector(oi));
+        }
+        BooleanWritable v = getConstValue(oi);
+        return v.get();
+    }
+
     public static int getConstInt(@Nonnull final ObjectInspector oi) throws UDFArgumentException {
         if (!isIntOI(oi)) {
             throw new UDFArgumentException("argument must be a Int value: "
@@ -1195,6 +1207,29 @@ public final class HiveUtils {
     }
 
     @Nonnull
+    public static PrimitiveObjectInspector asIntegerOI(@Nonnull final ObjectInspector[] argOIs,
+            final int argIndex) throws UDFArgumentException {
+        final ObjectInspector argOI = getObjectInspector(argOIs, argIndex);
+        if (argOI.getCategory() != Category.PRIMITIVE) {
+            throw new UDFArgumentTypeException(argIndex,
+                "Only primitive type arguments are accepted but " + argOI.getTypeName()
+                        + " is passed.");
+        }
+        final PrimitiveObjectInspector oi = (PrimitiveObjectInspector) argOI;
+        switch (oi.getPrimitiveCategory()) {
+            case INT:
+            case SHORT:
+            case LONG:
+            case BYTE:
+                break;
+            default:
+                throw new UDFArgumentTypeException(argIndex,
+                    "Unexpected type '" + argOI.getTypeName() + "' is passed.");
+        }
+        return oi;
+    }
+
+    @Nonnull
     public static PrimitiveObjectInspector asDoubleCompatibleOI(
             @Nonnull final ObjectInspector argOI) throws UDFArgumentTypeException {
         if (argOI.getCategory() != Category.PRIMITIVE) {