You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@hivemall.apache.org by my...@apache.org on 2021/03/29 07:43:05 UTC

[incubator-hivemall] branch master updated: [HIVEMALL-301] Remove macros and replace them with UDF

This is an automated email from the ASF dual-hosted git repository.

myui pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-hivemall.git


The following commit(s) were added to refs/heads/master by this push:
     new cd655ec  [HIVEMALL-301] Remove macros and replace them with UDF
cd655ec is described below

commit cd655eca60ffe4f9ee2022137bc42374380907b1
Author: Makoto Yui <my...@apache.org>
AuthorDate: Mon Mar 29 16:42:58 2021 +0900

    [HIVEMALL-301] Remove macros and replace them with UDF
    
    ## What changes were proposed in this pull request?
    
    Remove macros and replace them with UDF
    
    ## What type of PR is it?
    
    Improvement, Refactoring
    
    ## What is the Jira issue?
    
    https://issues.apache.org/jira/browse/HIVEMALL-301
    
    ## How was this patch tested?
    
    manual tests
    
    ## Checklist
    
    - [x] Did you apply source code formatter, i.e., `./bin/format_code.sh`, for your commit?
    - [x] Did you run system tests on Hive (or Spark)?
    
    Author: Makoto Yui <my...@apache.org>
    
    Closes #232 from myui/HIVEMALL-301-tfidf.
---
 .../main/java/hivemall/ftvec/text/TfIdfUDF.java    | 104 +++++++++++++++++++++
 .../main/java/hivemall/utils/hadoop/HiveUtils.java |  31 ++++++
 docs/gitbook/misc/funcs.md                         |   4 +
 resources/ddl/define-all-as-permanent.hive         |   4 +
 resources/ddl/define-all.hive                      |  31 +-----
 resources/ddl/define-all.spark                     |   3 +
 resources/ddl/define-macros.hive                   |  59 ------------
 7 files changed, 149 insertions(+), 87 deletions(-)

diff --git a/core/src/main/java/hivemall/ftvec/text/TfIdfUDF.java b/core/src/main/java/hivemall/ftvec/text/TfIdfUDF.java
new file mode 100644
index 0000000..ee7a218
--- /dev/null
+++ b/core/src/main/java/hivemall/ftvec/text/TfIdfUDF.java
@@ -0,0 +1,104 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.ftvec.text;
+
+import hivemall.utils.hadoop.HiveUtils;
+import hivemall.utils.lang.StringUtils;
+
+import javax.annotation.Nonnegative;
+import javax.annotation.Nonnull;
+
+import org.apache.hadoop.hive.ql.exec.Description;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentLengthException;
+import org.apache.hadoop.hive.ql.metadata.HiveException;
+import org.apache.hadoop.hive.ql.udf.UDFType;
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDF;
+import org.apache.hadoop.hive.serde2.io.DoubleWritable;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
+
+/**
+ * Hive UDF that multiplies a term frequency by a smoothed inverse document frequency,
+ * i.e. {@code tf * (log10(totalNumDocs / max(1, numDocs)) + 1)}.
+ */
+@Description(name = "tfidf",
+        value = "_FUNC_(double termFrequency, long numDocs, const long totalNumDocs) "
+                + "- Return a smoothed TFIDF score in double.")
+@UDFType(deterministic = true, stateful = false)
+public final class TfIdfUDF extends GenericUDF {
+
+    // Inspectors for the three arguments; assigned in initialize().
+    private PrimitiveObjectInspector tfOI;
+    private PrimitiveObjectInspector numDocsOI;
+    private PrimitiveObjectInspector totalNumDocsOI;
+
+    // Output holder that evaluate() overwrites and returns on every call.
+    @Nonnull
+    private final DoubleWritable result = new DoubleWritable();
+
+    @Override
+    public ObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException {
+        final int numArgs = argOIs.length;
+        if (numArgs != 3) {
+            throw new UDFArgumentLengthException(
+                "tfidf takes exactly three arguments but got " + numArgs);
+        }
+        this.tfOI = HiveUtils.asDoubleCompatibleOI(argOIs, 0);
+        this.numDocsOI = HiveUtils.asIntegerOI(argOIs, 1);
+        this.totalNumDocsOI = HiveUtils.asIntegerOI(argOIs, 2);
+        return PrimitiveObjectInspectorFactory.writableDoubleObjectInspector;
+    }
+
+    @Override
+    public Object evaluate(DeferredObject[] arguments) throws HiveException {
+        // Fetch all three arguments up front; a null in any position is rejected.
+        final Object tfArg = getObject(arguments, 0);
+        final Object numDocsArg = getObject(arguments, 1);
+        final Object totalArg = getObject(arguments, 2);
+        final double tf = PrimitiveObjectInspectorUtils.getDouble(tfArg, tfOI);
+        // The counts are read as long but kept in double so the ratio below is not truncated.
+        final double docFreq = PrimitiveObjectInspectorUtils.getLong(numDocsArg, numDocsOI);
+        final double corpusSize = PrimitiveObjectInspectorUtils.getLong(totalArg, totalNumDocsOI);
+
+        // Smoothed IDF: log10(N / max(1, n_t)) + 1, where N = corpusSize and n_t = docFreq.
+        // max(1, n_t) guards against division by zero; the trailing +1 is the smoothing term.
+        final double idf = Math.log10(corpusSize / Math.max(1.d, docFreq)) + 1.0d;
+        result.set(tf * idf);
+        return result;
+    }
+
+    /** Returns the {@code index}-th argument's value, or throws if that value is null. */
+    @Nonnull
+    private static Object getObject(@Nonnull final DeferredObject[] arguments,
+            @Nonnegative final int index) throws HiveException {
+        final Object value = arguments[index].get();
+        if (value != null) {
+            return value;
+        }
+        throw new UDFArgumentException(String.format("%d-th argument MUST not be null", index));
+    }
+
+    @Override
+    public String getDisplayString(String[] children) {
+        return "tfidf(" + StringUtils.join(children, ',') + ")";
+    }
+}
diff --git a/core/src/main/java/hivemall/utils/hadoop/HiveUtils.java b/core/src/main/java/hivemall/utils/hadoop/HiveUtils.java
index b82f6d4..91b6ecd 100644
--- a/core/src/main/java/hivemall/utils/hadoop/HiveUtils.java
+++ b/core/src/main/java/hivemall/utils/hadoop/HiveUtils.java
@@ -1197,6 +1197,37 @@ public final class HiveUtils {
     }
 
     @Nonnull
+    public static PrimitiveObjectInspector asLongCompatibleOI(
+            @Nonnull final ObjectInspector[] argOIs, final int argIndex)
+            throws UDFArgumentException {
+        // Accepts only a primitive inspector whose category is one of the cases listed below.
+        final ObjectInspector inspector = getObjectInspector(argOIs, argIndex);
+        if (inspector.getCategory() != Category.PRIMITIVE) {
+            final String typeName = inspector.getTypeName();
+            throw new UDFArgumentTypeException(argIndex,
+                "Only primitive type arguments are accepted but " + typeName + " is passed.");
+        }
+        final PrimitiveObjectInspector primitiveOI = (PrimitiveObjectInspector) inspector;
+        switch (primitiveOI.getPrimitiveCategory()) {
+            case LONG:
+            case INT:
+            case SHORT:
+            case BYTE:
+            case BOOLEAN:
+            case FLOAT:
+            case DOUBLE:
+            case DECIMAL:
+            case STRING:
+            case TIMESTAMP:
+                return primitiveOI;
+            default:
+                throw new UDFArgumentTypeException(argIndex,
+                    "Unexpected type '" + inspector.getTypeName() + "' is passed.");
+        }
+    }
+
+
+    @Nonnull
     public static PrimitiveObjectInspector asIntegerOI(@Nonnull final ObjectInspector argOI)
             throws UDFArgumentTypeException {
         if (argOI.getCategory() != Category.PRIMITIVE) {
diff --git a/docs/gitbook/misc/funcs.md b/docs/gitbook/misc/funcs.md
index a478b52..b3e1006 100644
--- a/docs/gitbook/misc/funcs.md
+++ b/docs/gitbook/misc/funcs.md
@@ -1048,6 +1048,8 @@ Reference: <a href="https://papers.nips.cc/paper/3848-adaptive-regularization-of
 
 - `tf(string text)` - Return a term frequency in &lt;string, float&gt;
 
+- `tfidf(double termFrequency, long numDocs, const long totalNumDocs)` - Return a smoothed TFIDF score in double.
+
 # Others
 
 - `hivemall_version()` - Returns the version of Hivemall
@@ -1064,3 +1066,5 @@ Reference: <a href="https://papers.nips.cc/paper/3848-adaptive-regularization-of
 
 - `tf(string text)` - Return a term frequency in &lt;string, float&gt;
 
+- `tfidf(double termFrequency, long numDocs, const long totalNumDocs)` - Return a smoothed TFIDF score in double.
+
diff --git a/resources/ddl/define-all-as-permanent.hive b/resources/ddl/define-all-as-permanent.hive
index f995d55..c5f2669 100644
--- a/resources/ddl/define-all-as-permanent.hive
+++ b/resources/ddl/define-all-as-permanent.hive
@@ -349,6 +349,9 @@ CREATE FUNCTION tf as 'hivemall.ftvec.text.TermFrequencyUDAF' USING JAR '${hivem
 DROP FUNCTION IF EXISTS bm25;
 CREATE FUNCTION bm25 as 'hivemall.ftvec.text.OkapiBM25UDF' USING JAR '${hivemall_jar}';
 
+DROP FUNCTION IF EXISTS tfidf;
+CREATE FUNCTION tfidf as 'hivemall.ftvec.text.TfIdfUDF' USING JAR '${hivemall_jar}';
+
 --------------------------
 -- Regression functions --
 --------------------------
@@ -920,3 +923,4 @@ CREATE FUNCTION xgboost_predict_one AS 'hivemall.xgboost.XGBoostPredictOneUDTF'
 
 DROP FUNCTION xgboost_predict_triple;
 CREATE FUNCTION xgboost_predict_triple AS 'hivemall.xgboost.XGBoostPredictTripleUDTF' USING JAR '${hivemall_jar}';
+
diff --git a/resources/ddl/define-all.hive b/resources/ddl/define-all.hive
index bf9bc7c..8bf36e8 100644
--- a/resources/ddl/define-all.hive
+++ b/resources/ddl/define-all.hive
@@ -345,6 +345,9 @@ create temporary function tf as 'hivemall.ftvec.text.TermFrequencyUDAF';
 drop temporary function if exists bm25;
 create temporary function bm25 as 'hivemall.ftvec.text.OkapiBM25UDF';
 
+drop temporary function if exists tfidf;
+create temporary function tfidf as 'hivemall.ftvec.text.TfIdfUDF';
+
 --------------------------
 -- Regression functions --
 --------------------------
@@ -890,31 +893,3 @@ create temporary function min_by as 'hivemall.tools.aggr.MinByUDAF';
 
 drop temporary function if exists majority_vote;
 create temporary function majority_vote as 'hivemall.tools.aggr.MajorityVoteUDAF';
-
-
---------------------------------------------------------------------------------------------------
--- macros available from hive 0.12.0
--- see https://issues.apache.org/jira/browse/HIVE-2655
-
---------------------
--- General Macros --
---------------------
-
-create temporary macro java_min(x DOUBLE, y DOUBLE)
-reflect("java.lang.Math", "min", x, y);
-
-create temporary macro max2(x DOUBLE, y DOUBLE)
-if(x>y,x,y);
-
-create temporary macro min2(x DOUBLE, y DOUBLE)
-if(x<y,x,y);
-
---------------------------
--- Statistics functions --
---------------------------
-
-create temporary macro idf(df_t DOUBLE, n_docs DOUBLE)
-log(10, n_docs / max2(1,df_t)) + 1.0;
-
-create temporary macro tfidf(tf FLOAT, df_t DOUBLE, n_docs DOUBLE)
-tf * (log(10, n_docs / max2(1,df_t)) + 1.0);
diff --git a/resources/ddl/define-all.spark b/resources/ddl/define-all.spark
index 8529134..91c6350 100644
--- a/resources/ddl/define-all.spark
+++ b/resources/ddl/define-all.spark
@@ -348,6 +348,9 @@ sqlContext.sql("CREATE TEMPORARY FUNCTION tf AS 'hivemall.ftvec.text.TermFrequen
 sqlContext.sql("DROP TEMPORARY FUNCTION IF EXISTS bm25")
 sqlContext.sql("CREATE TEMPORARY FUNCTION bm25 AS 'hivemall.ftvec.text.OkapiBM25UDF'")
 
+sqlContext.sql("DROP TEMPORARY FUNCTION IF EXISTS tfidf")
+sqlContext.sql("CREATE TEMPORARY FUNCTION tfidf AS 'hivemall.ftvec.text.TfIdfUDF'")
+
 /**
  * Regression functions
  */
diff --git a/resources/ddl/define-macros.hive b/resources/ddl/define-macros.hive
deleted file mode 100644
index ff36a44..0000000
--- a/resources/ddl/define-macros.hive
+++ /dev/null
@@ -1,59 +0,0 @@
------------------------------------------------------------------------------
--- Hivemall: Hive scalable Machine Learning Library
------------------------------------------------------------------------------
-
--- macros available from hive 0.12.0
--- see https://issues.apache.org/jira/browse/HIVE-2655
-
---------------------
--- General Macros --
---------------------
-
-create temporary macro java_min(x DOUBLE, y DOUBLE)
-reflect("java.lang.Math", "min", x, y);
-
-create temporary macro max2(x DOUBLE, y DOUBLE)
-if(x>y,x,y);
-
-create temporary macro min2(x DOUBLE, y DOUBLE)
-if(x<y,x,y);
-
---------------------------
--- Statistics functions --
---------------------------
-
-create temporary macro idf(df_t DOUBLE, n_docs DOUBLE)
-log(10, n_docs / max2(1,df_t)) + 1.0;
-
-create temporary macro tfidf(tf FLOAT, df_t DOUBLE, n_docs DOUBLE)
-tf * (log(10, n_docs / max2(1,df_t)) + 1.0);
-
---------------------------
--- Evaluation functions --
---------------------------
-
--- CAUTION: UDAF in macro is not yet support supported in Hive
-
--- Root Mean Squared Error
--- create temporary macro RMSE(predicted FLOAT, actual FLOAT)
--- sqrt(sum(pow(predicted - actual,2.0))/count(1));
-
--- Mean Squared Error
--- create temporary macro MSE(predicted FLOAT, actual FLOAT)
--- sum(pow(predicted - actual,2.0))/count(1);
-
--- Mean Absolute Error
--- create temporary macro MAE(predicted FLOAT, actual FLOAT)
--- sum(abs(predicted - actual))/count(1);
-
--- sum of squared errors
--- create temporary macro SSE(predicted FLOAT, actual FLOAT)
--- sum(pow(actual - predicted,2.0));
-
--- sum of squared total
--- create temporary macro SST(actual FLOAT, mean_actual FLOAT)
--- sum(pow(actual-mean_actual,2.0));
-
--- coefficient of determination (R^2)
--- create temporary macro R2(predicted FLOAT, actual FLOAT, mean_actual FLOAT)
--- 1 - (SSE(predicted, actual) / SST(actual, mean_actual));
\ No newline at end of file