You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@hivemall.apache.org by my...@apache.org on 2019/06/20 10:35:48 UTC

[incubator-hivemall] branch master updated: [HIVEMALL-258] Add UDF to convert feature/label in Libsvm format

This is an automated email from the ASF dual-hosted git repository.

myui pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-hivemall.git


The following commit(s) were added to refs/heads/master by this push:
     new 3827b6c  [HIVEMALL-258] Add UDF to convert feature/label in Libsvm format
3827b6c is described below

commit 3827b6caa18c0ea687a0c7b576079e4d0ea9b100
Author: Makoto Yui <my...@apache.org>
AuthorDate: Thu Jun 20 19:35:42 2019 +0900

    [HIVEMALL-258] Add UDF to convert feature/label in Libsvm format
    
    ## What changes were proposed in this pull request?
    
    Add a UDF that converts a feature array (and an optional target label) into LIBSVM format
    
    ## What type of PR is it?
    
    Feature
    
    ## What is the Jira issue?
    
    https://issues.apache.org/jira/browse/HIVEMALL-258
    
    ## How was this patch tested?
    
    unit tests and manual tests
    
    ## How to use this feature?
    
    ```sql
    Usage:
     select to_libsvm_format(array('apple:3.4','orange:2.1'))
     > 6284535:3.4 8104713:2.1
     select to_libsvm_format(array('apple:3.4','orange:2.1'), '-features 10')
     > 3:2.1 7:3.4
     select to_libsvm_format(array('7:3.4','3:2.1'), 5.0)
     > 5.0 3:2.1 7:3.4
    ```
    
    ## Checklist
    
    (Please remove this section if not needed; check `x` for YES, blank for NO)
    
    - [x] Did you apply source code formatter, i.e., `./bin/format_code.sh`, for your commit?
    - [x] Did you run system tests on Hive (or Spark)?
    
    Author: Makoto Yui <my...@apache.org>
    
    Closes #194 from myui/libsvm.
---
 .../hivemall/ftvec/conv/ToLibSVMFormatUDF.java     | 228 +++++++++++++++++++++
 .../hivemall/ftvec/conv/ToLibSVMFormatUDFTest.java |  99 +++++++++
 docs/gitbook/misc/funcs.md                         |  33 ++-
 resources/ddl/define-all-as-permanent.hive         |   3 +
 resources/ddl/define-all.hive                      |   4 +-
 resources/ddl/define-all.spark                     |   3 +
 6 files changed, 364 insertions(+), 6 deletions(-)

diff --git a/core/src/main/java/hivemall/ftvec/conv/ToLibSVMFormatUDF.java b/core/src/main/java/hivemall/ftvec/conv/ToLibSVMFormatUDF.java
new file mode 100644
index 0000000..723cb0b
--- /dev/null
+++ b/core/src/main/java/hivemall/ftvec/conv/ToLibSVMFormatUDF.java
@@ -0,0 +1,228 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.ftvec.conv;
+
+import hivemall.UDFWithOptions;
+import hivemall.utils.hadoop.HiveUtils;
+import hivemall.utils.hashing.MurmurHash3;
+import hivemall.utils.lang.NumberUtils;
+import hivemall.utils.lang.Primitives;
+import hivemall.utils.lang.StringUtils;
+import hivemall.utils.struct.Pair;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.Comparator;
+import java.util.List;
+
+import javax.annotation.Nonnegative;
+import javax.annotation.Nonnull;
+import javax.annotation.Nullable;
+
+import org.apache.commons.cli.CommandLine;
+import org.apache.commons.cli.Options;
+import org.apache.hadoop.hive.ql.exec.Description;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
+import org.apache.hadoop.hive.ql.metadata.HiveException;
+import org.apache.hadoop.hive.ql.udf.UDFType;
+import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
+
+/**
+ * Converts an {@code array<string>} of features, plus an optional numeric target, into a single
+ * line of LIBSVM text: {@code [target ]index:value index:value ...}, sorted by feature index.
+ * <p>
+ * Each feature is either {@code name} (value 1.0) or {@code name:value}. Names made only of digits
+ * are used as the index directly; any other name is hashed into {@code [1, num_features]}, where
+ * {@code num_features} is set with the {@code -features} option.
+ */
+// @formatter:off
+@Description(name = "to_libsvm_format",
+        value = "_FUNC_(array<string> features [, double/integer target, const string options])"
+                + " - Returns a string representation of libsvm",
+                extended = "Usage:\n" +
+                        " select to_libsvm_format(array('apple:3.4','orange:2.1'))\n" +
+                        " > 6284535:3.4 8104713:2.1\n" +
+                        " select to_libsvm_format(array('apple:3.4','orange:2.1'), '-features 10')\n" +
+                        " > 3:2.1 7:3.4\n" +
+                        " select to_libsvm_format(array('7:3.4','3:2.1'), 5.0)\n" +
+                        " > 5.0 3:2.1 7:3.4")
+// @formatter:on
+@UDFType(deterministic = true, stateful = false)
+public final class ToLibSVMFormatUDF extends UDFWithOptions {
+
+    /** Inspector for the 1st argument, the feature array. */
+    private ListObjectInspector _featuresOI;
+    /** Inspector for the numeric target; {@code null} when the call has no target argument. */
+    @Nullable
+    private PrimitiveObjectInspector _targetOI = null;
+    /** Modulus for hashing non-numeric feature names; overridden by {@code -features}. */
+    private int _numFeatures = MurmurHash3.DEFAULT_NUM_FEATURES;
+    /** Output buffer, cleared and reused on every {@link #evaluate} call. */
+    private StringBuilder _buf;
+
+    @Override
+    protected Options getOptions() {
+        Options opts = new Options();
+        opts.addOption("features", "num_features", true,
+            "The number of features [default: 16777216 (2^24)]");
+        return opts;
+    }
+
+    /**
+     * Parses the constant option string and stores {@code -features} into {@link #_numFeatures}.
+     *
+     * @throws UDFArgumentException if the options cannot be parsed or num_features is not positive
+     */
+    @Override
+    protected CommandLine processOptions(@Nonnull String optionValue) throws UDFArgumentException {
+        CommandLine cl = parseOptions(optionValue);
+        this._numFeatures = Primitives.parseInt(cl.getOptionValue("num_features"),
+            MurmurHash3.DEFAULT_NUM_FEATURES);
+        assumeTrue(_numFeatures > 0, "num_features must be greater than 0: " + _numFeatures);
+        return cl;
+    }
+
+    /**
+     * Accepts {@code (features)}, {@code (features, target)}, {@code (features, const options)}
+     * or {@code (features, target, const options)} and always returns a string inspector.
+     */
+    @Override
+    public ObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException {
+        // Both bounds must hold: with zero arguments argOIs[0] below would otherwise throw
+        // ArrayIndexOutOfBoundsException, and surplus arguments would be silently ignored.
+        assumeTrue(argOIs.length >= 1 && argOIs.length <= 3,
+            "to_libsvm_format UDF takes 1~3 arguments");
+
+        this._featuresOI = HiveUtils.asListOI(argOIs[0]);
+        if (argOIs.length == 2) {
+            // With two arguments the 2nd is either the numeric target or the constant options.
+            ObjectInspector argOI1 = argOIs[1];
+            if (HiveUtils.isNumberOI(argOI1)) {
+                this._targetOI = HiveUtils.asNumberOI(argOI1);
+            } else if (HiveUtils.isConstString(argOI1)) { // no target
+                processOptions(HiveUtils.getConstString(argOI1));
+            } else {
+                throw new UDFArgumentException(
+                    "Unexpected argument type for 2nd argument: " + argOI1.getTypeName());
+            }
+        } else if (argOIs.length == 3) {
+            this._targetOI = HiveUtils.asNumberOI(argOIs[1]);
+            processOptions(HiveUtils.getConstString(argOIs[2]));
+        }
+
+        this._buf = new StringBuilder();
+
+        return PrimitiveObjectInspectorFactory.javaStringObjectInspector;
+    }
+
+    /**
+     * Formats one row as LIBSVM text. Returns {@code null} when the feature array is {@code null}.
+     *
+     * @throws HiveException if a target argument was declared but its value is {@code null}, or a
+     *         feature token cannot be parsed
+     */
+    @Nullable
+    @Override
+    public String evaluate(DeferredObject[] args) throws HiveException {
+        final StringBuilder buf = this._buf;
+        StringUtils.clear(buf);
+
+        final Object arg0 = args[0].get();
+        if (arg0 == null) {
+            return null;
+        }
+        final List<Pair<Integer, Double>> features = parseFeatures(arg0);
+
+        if (_targetOI != null) {
+            final Object arg1 = args[1].get();
+            if (arg1 == null) {
+                throw new HiveException("Detected NULL for the 2nd argument");
+            }
+            // integer targets are printed without a fractional part (5 rather than 5.0)
+            if (HiveUtils.isIntegerOI(_targetOI)) {
+                buf.append(HiveUtils.getInt(arg1, _targetOI));
+            } else {
+                buf.append(HiveUtils.getDouble(arg1, _targetOI));
+            }
+        }
+        for (Pair<Integer, Double> fv : features) {
+            // separator only between tokens, so a target with no features has no trailing space
+            if (buf.length() != 0) {
+                buf.append(' ');
+            }
+            buf.append(fv.getKey().intValue()).append(':').append(fv.getValue().doubleValue());
+        }
+        return buf.toString();
+    }
+
+    /**
+     * Parses every non-null element of the feature array and sorts the pairs by ascending index.
+     */
+    @Nonnull
+    private List<Pair<Integer, Double>> parseFeatures(@Nonnull final Object featuresObj)
+            throws UDFArgumentException {
+        final int size = _featuresOI.getListLength(featuresObj);
+        final List<Pair<Integer, Double>> features = new ArrayList<>(size);
+        for (int i = 0; i < size; i++) {
+            Object e = _featuresOI.getListElement(featuresObj, i);
+            if (e == null) {
+                continue; // null elements are skipped
+            }
+            features.add(parse(e.toString(), _numFeatures));
+        }
+        // NOTE(review): repeated names or hash collisions can produce duplicate indices, which are
+        // kept as-is; the LIBSVM format expects indices in ascending order — consider merging.
+        Collections.sort(features, comparator);
+        return features;
+    }
+
+    /**
+     * Parses one feature token into an (index, value) pair.
+     *
+     * @param fv a feature of the form {@code name} (value defaults to 1.0) or {@code name:value}
+     * @param numFeatures modulus used to hash non-digit names into {@code [1, numFeatures]}
+     * @return the feature index and its value
+     * @throws UDFArgumentException if {@code fv} contains more than one ':', the value is not a
+     *         parsable double, or a digit-only name does not fit in an int
+     */
+    @Nonnull
+    public static Pair<Integer, Double> parse(@Nonnull final String fv,
+            @Nonnegative final int numFeatures) throws UDFArgumentException {
+        final int headPos = fv.indexOf(':');
+        if (headPos == -1) {
+            if (NumberUtils.isDigits(fv)) {
+                final int f;
+                try {
+                    f = Integer.parseInt(fv);
+                } catch (NumberFormatException e) {
+                    throw new UDFArgumentException("Invalid feature value: " + fv);
+                }
+                return new Pair<>(f, 1.d);
+            } else {
+                return new Pair<>(mhash(fv, numFeatures), 1.d);
+            }
+        } else {
+            final int tailPos = fv.lastIndexOf(':');
+            if (headPos != tailPos) {
+                throw new UDFArgumentException("Unsupported feature format: " + fv);
+            }
+            String f = fv.substring(0, headPos);
+            String v = fv.substring(headPos + 1);
+            final double d;
+            try {
+                d = Double.parseDouble(v);
+            } catch (NumberFormatException e) {
+                throw new UDFArgumentException("Invalid feature value: " + fv);
+            }
+            if (NumberUtils.isDigits(f)) {
+                final int i;
+                try {
+                    i = Integer.parseInt(f);
+                } catch (NumberFormatException e) {
+                    throw new UDFArgumentException("Invalid feature value: " + fv);
+                }
+                return new Pair<>(i, d);
+            } else {
+                return new Pair<>(mhash(f, numFeatures), d);
+            }
+        }
+    }
+
+    /**
+     * Hashes a feature name with MurmurHash3 (x86_32, seed 0x9747b28c) into {@code [1, numFeatures]}.
+     */
+    private static int mhash(@Nonnull final String word, final int numFeatures) {
+        int r = MurmurHash3.murmurhash3_x86_32(word, 0, word.length(), 0x9747b28c) % numFeatures;
+        if (r < 0) {
+            r += numFeatures; // Java's % keeps the sign of the dividend
+        }
+        return r + 1; // shift to 1-based indices
+    }
+
+    /** Orders (index, value) pairs by ascending index. */
+    private static final Comparator<Pair<Integer, Double>> comparator =
+            new Comparator<Pair<Integer, Double>>() {
+                @Override
+                public int compare(Pair<Integer, Double> l, Pair<Integer, Double> r) {
+                    return l.getKey().compareTo(r.getKey());
+                }
+            };
+
+    @Override
+    public String getDisplayString(String[] args) {
+        return "to_libsvm_format( " + StringUtils.join(args, ',') + " )";
+    }
+}
diff --git a/core/src/test/java/hivemall/ftvec/conv/ToLibSVMFormatUDFTest.java b/core/src/test/java/hivemall/ftvec/conv/ToLibSVMFormatUDFTest.java
new file mode 100644
index 0000000..6a59058
--- /dev/null
+++ b/core/src/test/java/hivemall/ftvec/conv/ToLibSVMFormatUDFTest.java
@@ -0,0 +1,99 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.ftvec.conv;
+
+import java.io.IOException;
+import java.util.Arrays;
+
+import org.apache.hadoop.hive.ql.metadata.HiveException;
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDF.DeferredJavaObject;
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDF.DeferredObject;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
+import org.junit.Assert;
+import org.junit.Test;
+
+public class ToLibSVMFormatUDFTest {
+
+    /** Inspector for the {@code array<string>} feature argument. */
+    private static ObjectInspector stringListOI() {
+        return ObjectInspectorFactory.getStandardListObjectInspector(
+            PrimitiveObjectInspectorFactory.javaStringObjectInspector);
+    }
+
+    /** Constant-string inspector carrying the UDF option string. */
+    private static ObjectInspector optionsOI(String options) {
+        return ObjectInspectorUtils.getConstantObjectInspector(
+            PrimitiveObjectInspectorFactory.javaStringObjectInspector, options);
+    }
+
+    /** Wraps the given feature tokens as a deferred {@code array<string>} argument. */
+    private static DeferredObject features(String... tokens) {
+        return new DeferredJavaObject(Arrays.asList(tokens));
+    }
+
+    @Test
+    public void testFeatureOnly() throws IOException, HiveException {
+        final ToLibSVMFormatUDF udf = new ToLibSVMFormatUDF();
+        udf.initialize(new ObjectInspector[] {stringListOI(), optionsOI("-features 10")});
+
+        // names are hashed (apple -> 7, orange -> 3 with 10 features) and sorted by index
+        Assert.assertEquals("3:2.1 7:3.4",
+            udf.evaluate(new DeferredObject[] {features("apple:3.4", "orange:2.1")}));
+        // digit-only names are used as indices verbatim, then sorted
+        Assert.assertEquals("3:2.1 7:3.4",
+            udf.evaluate(new DeferredObject[] {features("7:3.4", "3:2.1")}));
+
+        udf.close();
+    }
+
+    @Test
+    public void testFeatureAndIntLabel() throws IOException, HiveException {
+        final ToLibSVMFormatUDF udf = new ToLibSVMFormatUDF();
+        udf.initialize(new ObjectInspector[] {stringListOI(),
+                PrimitiveObjectInspectorFactory.javaIntObjectInspector, optionsOI("-features 10")});
+
+        // an int target is printed without a fractional part
+        Assert.assertEquals("5 3:2.1 7:3.4", udf.evaluate(new DeferredObject[] {
+                features("apple:3.4", "orange:2.1"), new DeferredJavaObject(5)}));
+
+        udf.close();
+    }
+
+    @Test
+    public void testFeatureAndFloatLabel() throws IOException, HiveException {
+        final ToLibSVMFormatUDF udf = new ToLibSVMFormatUDF();
+        udf.initialize(new ObjectInspector[] {stringListOI(),
+                PrimitiveObjectInspectorFactory.javaFloatObjectInspector,
+                optionsOI("-features 10")});
+
+        // a float target is read as a double when printed
+        Assert.assertEquals("5.0 3:2.1 7:3.4", udf.evaluate(new DeferredObject[] {
+                features("7:3.4", "3:2.1"), new DeferredJavaObject(5.f)}));
+
+        udf.close();
+    }
+
+}
diff --git a/docs/gitbook/misc/funcs.md b/docs/gitbook/misc/funcs.md
index ade9ee3..1b1b280 100644
--- a/docs/gitbook/misc/funcs.md
+++ b/docs/gitbook/misc/funcs.md
@@ -65,13 +65,25 @@ Reference: <a href="https://papers.nips.cc/paper/3848-adaptive-regularization-of
   GROUP BY feature
   ```
 
-- `train_pa1_regr(array<int|bigint|string> features, float target [, constant string options])` - PA-1 regressor that returns a relation consists of `&lt;int|bigint|string&gt; feature, float weight`. Find PA-1 algorithm detail in http://jmlr.csail.mit.edu/papers/volume7/crammer06a/crammer06a.pdf
-
-- `train_pa1a_regr(array<int|bigint|string> features, float target [, constant string options])` - Returns a relation consists of `&lt;int|bigint|string&gt; feature, float weight`.
+- `train_pa1_regr(array<int|bigint|string> features, float target [, constant string options])` - PA-1 regressor that returns a relation consisting of `(int|bigint|string) feature, float weight`.
+  ```sql
+  SELECT 
+   feature,
+   avg(weight) as weight
+  FROM 
+   (SELECT 
+       train_pa1_regr(features,label) as (feature,weight)
+    FROM 
+       training_data
+   ) t 
+  GROUP BY feature
+  ```
+Reference: <a href="http://jmlr.csail.mit.edu/papers/volume7/crammer06a/crammer06a.pdf" target="_blank">Koby Crammer et al., Online Passive-Aggressive Algorithms. Journal of Machine Learning Research, 2006.</a><br/>
+- `train_pa1a_regr(array<int|bigint|string> features, float target [, constant string options])` - Returns a relation consisting of `(int|bigint|string) feature, float weight`.
 
-- `train_pa2_regr(array<int|bigint|string> features, float target [, constant string options])` - Returns a relation consists of `&lt;int|bigint|string&gt; feature, float weight`.
+- `train_pa2_regr(array<int|bigint|string> features, float target [, constant string options])` - Returns a relation consisting of `(int|bigint|string) feature, float weight`.
 
-- `train_pa2a_regr(array<int|bigint|string> features, float target [, constant string options])` - Returns a relation consists of `&lt;int|bigint|string&gt; feature, float weight`.
+- `train_pa2a_regr(array<int|bigint|string> features, float target [, constant string options])` - Returns a relation consisting of `(int|bigint|string) feature, float weight`.
 
 - `train_regressor(list<string|int|bigint> features, double label [, const string options])` - Returns a relation consists of &lt;string|int|bigint feature, float weight&gt;
   ```
@@ -261,6 +273,17 @@ Reference: <a href="https://papers.nips.cc/paper/3848-adaptive-regularization-of
 
 - `to_dense_features(array<string> feature_vector, int dimensions)` - Returns a dense feature in array&lt;float&gt;
 
+- `to_libsvm_format(array<string> features [, double/integer target, const string options])` - Returns a string representation of libsvm
+  ```sql
+  Usage:
+   select to_libsvm_format(array('apple:3.4','orange:2.1'))
+   > 6284535:3.4 8104713:2.1
+   select to_libsvm_format(array('apple:3.4','orange:2.1'), '-features 10')
+   > 3:2.1 7:3.4
+   select to_libsvm_format(array('7:3.4','3:2.1'), 5.0)
+   > 5.0 3:2.1 7:3.4
+  ```
+
 - `to_sparse_features(array<float> feature_vector)` - Returns a sparse feature in array&lt;string&gt;
 
 ## Feature hashing
diff --git a/resources/ddl/define-all-as-permanent.hive b/resources/ddl/define-all-as-permanent.hive
index 0c836f2..ff20c8c 100644
--- a/resources/ddl/define-all-as-permanent.hive
+++ b/resources/ddl/define-all-as-permanent.hive
@@ -288,6 +288,9 @@ CREATE FUNCTION build_bins as 'hivemall.ftvec.binning.BuildBinsUDAF' USING JAR '
 DROP FUNCTION IF EXISTS feature_binning;
 CREATE FUNCTION feature_binning as 'hivemall.ftvec.binning.FeatureBinningUDF' USING JAR '${hivemall_jar}';
 
+-- to_libsvm_format: formats array<string> features (with an optional numeric target) as a LIBSVM line
+DROP FUNCTION IF EXISTS to_libsvm_format;
+CREATE FUNCTION to_libsvm_format as 'hivemall.ftvec.conv.ToLibSVMFormatUDF' USING JAR '${hivemall_jar}';
+
 --------------------------
 -- feature transformers --
 --------------------------
diff --git a/resources/ddl/define-all.hive b/resources/ddl/define-all.hive
index e6f7c0b..0495113 100644
--- a/resources/ddl/define-all.hive
+++ b/resources/ddl/define-all.hive
@@ -284,6 +284,9 @@ create temporary function build_bins as 'hivemall.ftvec.binning.BuildBinsUDAF';
 drop temporary function if exists feature_binning;
 create temporary function feature_binning as 'hivemall.ftvec.binning.FeatureBinningUDF';
 
+-- to_libsvm_format: formats array<string> features (with an optional numeric target) as a LIBSVM line
+drop temporary function if exists to_libsvm_format;
+create temporary function to_libsvm_format as 'hivemall.ftvec.conv.ToLibSVMFormatUDF';
+
 --------------------------
 -- feature transformers --
 --------------------------
@@ -883,4 +886,3 @@ log(10, n_docs / max2(1,df_t)) + 1.0;
 
 create temporary macro tfidf(tf FLOAT, df_t DOUBLE, n_docs DOUBLE)
 tf * (log(10, n_docs / max2(1,df_t)) + 1.0);
-
diff --git a/resources/ddl/define-all.spark b/resources/ddl/define-all.spark
index e3ff216..feadbbf 100644
--- a/resources/ddl/define-all.spark
+++ b/resources/ddl/define-all.spark
@@ -287,6 +287,9 @@ sqlContext.sql("CREATE TEMPORARY FUNCTION build_bins AS 'hivemall.ftvec.binning.
 sqlContext.sql("DROP TEMPORARY FUNCTION IF EXISTS feature_binning")
 sqlContext.sql("CREATE TEMPORARY FUNCTION feature_binning AS 'hivemall.ftvec.binning.FeatureBinningUDF'")
 
+// to_libsvm_format: formats array<string> features (with an optional numeric target) as a LIBSVM line
+sqlContext.sql("DROP TEMPORARY FUNCTION IF EXISTS to_libsvm_format")
+sqlContext.sql("CREATE TEMPORARY FUNCTION to_libsvm_format AS 'hivemall.ftvec.conv.ToLibSVMFormatUDF'")
+
 /**
  * feature transformers
  */