You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@hivemall.apache.org by my...@apache.org on 2019/06/20 10:30:17 UTC
[incubator-hivemall] 01/03: Added to_libsvm_format UDF
This is an automated email from the ASF dual-hosted git repository.
myui pushed a commit to branch libsvm
in repository https://gitbox.apache.org/repos/asf/incubator-hivemall.git
commit 4a8c64c036c77e2edd4ccb72c9f292bc8a8beb0a
Author: Makoto Yui <my...@apache.org>
AuthorDate: Thu Jun 20 19:04:48 2019 +0900
Added to_libsvm_format UDF
---
.../hivemall/ftvec/conv/ToLibSVMFormatUDF.java | 217 +++++++++++++++++++++
.../hivemall/ftvec/conv/ToLibSVMFormatUDFTest.java | 99 ++++++++++
2 files changed, 316 insertions(+)
diff --git a/core/src/main/java/hivemall/ftvec/conv/ToLibSVMFormatUDF.java b/core/src/main/java/hivemall/ftvec/conv/ToLibSVMFormatUDF.java
new file mode 100644
index 0000000..a1c85d9
--- /dev/null
+++ b/core/src/main/java/hivemall/ftvec/conv/ToLibSVMFormatUDF.java
@@ -0,0 +1,217 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.ftvec.conv;
+
+import hivemall.UDFWithOptions;
+import hivemall.utils.hadoop.HiveUtils;
+import hivemall.utils.hashing.MurmurHash3;
+import hivemall.utils.lang.NumberUtils;
+import hivemall.utils.lang.Primitives;
+import hivemall.utils.lang.StringUtils;
+import hivemall.utils.struct.Pair;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.Comparator;
+import java.util.List;
+
+import javax.annotation.Nonnegative;
+import javax.annotation.Nonnull;
+import javax.annotation.Nullable;
+
+import org.apache.commons.cli.CommandLine;
+import org.apache.commons.cli.Options;
+import org.apache.hadoop.hive.ql.exec.Description;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
+import org.apache.hadoop.hive.ql.metadata.HiveException;
+import org.apache.hadoop.hive.ql.udf.UDFType;
+import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
+
+/**
+ * Hive UDF that renders a feature vector, optionally preceded by a target label, as one line of
+ * LIBSVM-style text: {@code [label ]index:value index:value ...}.
+ *
+ * <p>
+ * Each element of the feature array is either {@code name} or {@code name:value}; when the value
+ * part is absent it defaults to {@code 1.0}. A name accepted by {@code NumberUtils.isDigits} is
+ * parsed as the integer index itself; any other name is hashed with MurmurHash3 into the range
+ * {@code [1, num_features]}. Output entries are sorted by index in ascending order.
+ *
+ * <p>
+ * Accepted call shapes are {@code (features)}, {@code (features, target)},
+ * {@code (features, options)} and {@code (features, target, options)}.
+ */
+@Description(name = "to_libsvm_format",
+        value = "_FUNC_(array<string> features [, double/integer target, const string options])"
+                + " - Returns a string representation of libsvm")
+@UDFType(deterministic = true, stateful = false)
+public final class ToLibSVMFormatUDF extends UDFWithOptions {
+
+    /** Inspector for the 1st argument (the feature array); assigned in {@code initialize}. */
+    private ListObjectInspector _featuresOI;
+    /** Inspector for the optional numeric target; stays {@code null} when no target is passed. */
+    @Nullable
+    private PrimitiveObjectInspector _targetOI = null;
+    /** Modulus applied to hashed feature names; can be overridden through the options string. */
+    private int _numFeatures = MurmurHash3.DEFAULT_NUM_FEATURES;
+    /** Output buffer that is cleared and reused on every {@code evaluate} call. */
+    private StringBuilder _buf;
+
+    /**
+     * Declares the command-line style options understood by this UDF.
+     *
+     * @return the option set containing {@code -features}/{@code -num_features}
+     */
+    @Override
+    protected Options getOptions() {
+        Options opts = new Options();
+        // NOTE(review): the help text pairs 16777217 with "2^24", but 2^24 is 16777216. Confirm
+        // against MurmurHash3.DEFAULT_NUM_FEATURES before rewording the message.
+        opts.addOption("features", "num_features", true,
+            "The number of features [default: 16777217 (2^24)]");
+        return opts;
+    }
+
+    /**
+     * Parses the options string and applies {@code num_features} when it is present.
+     *
+     * @param optionValue the raw options string given as a constant argument
+     * @return the parsed command line
+     * @throws UDFArgumentException if the options cannot be parsed or {@code num_features} is
+     *         not a positive number
+     */
+    @Override
+    protected CommandLine processOptions(@Nonnull String optionValue) throws UDFArgumentException {
+        CommandLine cl = parseOptions(optionValue);
+        final int numFeatures =
+                Primitives.parseInt(cl.getOptionValue("num_features"), _numFeatures);
+        // The value is used as a modulus when hashing feature names, so zero would raise an
+        // ArithmeticException per row and a negative value would yield out-of-range indices.
+        if (numFeatures <= 0) {
+            throw new UDFArgumentException(
+                "-num_features must be greater than 0: " + numFeatures);
+        }
+        this._numFeatures = numFeatures;
+        return cl;
+    }
+
+    /**
+     * Validates the argument list and records the inspectors needed by {@code evaluate}.
+     *
+     * @param argOIs inspectors for 1 to 3 arguments: the feature array, then an optional numeric
+     *        target and/or an optional constant options string
+     * @return the Java string object inspector describing the UDF result
+     * @throws UDFArgumentException if the number or the types of the arguments are not supported
+     */
+    @Override
+    public ObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException {
+        // Both bounds must hold at once. With "||" the condition is always true, so a call with
+        // no arguments or with more than three would slip through this guard.
+        assumeTrue(argOIs.length >= 1 && argOIs.length <= 3,
+            "to_libsvm_format UDF takes 1~3 arguments");
+
+        this._featuresOI = HiveUtils.asListOI(argOIs[0]);
+        if (argOIs.length == 2) {
+            // The 2nd argument is ambiguous: it is either the target or the options string.
+            ObjectInspector argOI1 = argOIs[1];
+            if (HiveUtils.isNumberOI(argOI1)) {
+                this._targetOI = HiveUtils.asNumberOI(argOI1);
+            } else if (HiveUtils.isConstString(argOI1)) { // no target
+                String opts = HiveUtils.getConstString(argOI1);
+                processOptions(opts);
+            } else {
+                throw new UDFArgumentException(
+                    "Unexpected argument type for 2nd argument: " + argOI1.getTypeName());
+            }
+        } else if (argOIs.length == 3) {
+            this._targetOI = HiveUtils.asNumberOI(argOIs[1]);
+            String opts = HiveUtils.getConstString(argOIs[2]);
+            processOptions(opts);
+        }
+
+        this._buf = new StringBuilder();
+
+        return PrimitiveObjectInspectorFactory.javaStringObjectInspector;
+    }
+
+    /**
+     * Builds the output line for one row.
+     *
+     * @param args the row values, in the order declared to {@code initialize}
+     * @return the formatted line, or {@code null} when the feature array itself is NULL
+     * @throws HiveException if a target was declared but its value is NULL, or if a feature
+     *         element cannot be parsed
+     */
+    @Nullable
+    @Override
+    public String evaluate(DeferredObject[] args) throws HiveException {
+        final StringBuilder buf = this._buf;
+        StringUtils.clear(buf);
+
+        Object arg0 = args[0].get();
+        if (arg0 == null) {
+            return null;
+        }
+
+        final int featureSize = _featuresOI.getListLength(arg0);
+        final List<Pair<Integer, Double>> features = new ArrayList<>(featureSize);
+        for (int i = 0; i < featureSize; i++) {
+            Object e = _featuresOI.getListElement(arg0, i);
+            if (e == null) {
+                continue; // NULL elements are skipped rather than rejected
+            }
+            features.add(parse(e.toString(), _numFeatures));
+        }
+        Collections.sort(features, comparator);
+
+        if (_targetOI != null) {
+            Object arg1 = args[1].get();
+            if (arg1 == null) {
+                throw new HiveException("Detected NULL for the 2nd argument");
+            }
+            // Integer-typed targets are written without a fractional part (e.g. "5"), all other
+            // numeric targets are written as doubles (e.g. "5.0").
+            if (HiveUtils.isIntegerOI(_targetOI)) {
+                int label = HiveUtils.getInt(arg1, _targetOI);
+                buf.append(label);
+            } else {
+                double label = HiveUtils.getDouble(arg1, _targetOI);
+                buf.append(label);
+            }
+            buf.append(' ');
+        }
+        for (int i = 0, size = features.size(); i < size; i++) {
+            if (i != 0) {
+                buf.append(' ');
+            }
+            Pair<Integer, Double> fv = features.get(i);
+            buf.append(fv.getKey().intValue());
+            buf.append(':');
+            buf.append(fv.getValue().doubleValue());
+        }
+
+        return buf.toString();
+    }
+
+    /**
+     * Splits a single feature string into its index and value.
+     *
+     * @param fv a feature in {@code name} or {@code name:value} form
+     * @param numFeatures the modulus used when the name has to be hashed
+     * @return the pair of feature index and feature value; the value is {@code 1.0} when
+     *         {@code fv} has no {@code ':'}
+     * @throws UDFArgumentException if {@code fv} contains more than one {@code ':'}, or if the
+     *         value or a numeric name cannot be parsed
+     */
+    @Nonnull
+    public static Pair<Integer, Double> parse(@Nonnull final String fv,
+            @Nonnegative final int numFeatures) throws UDFArgumentException {
+        final int headPos = fv.indexOf(':');
+        if (headPos == -1) {
+            // No explicit value: the whole string is the feature name.
+            return new Pair<>(toFeatureIndex(fv, fv, numFeatures), 1.d);
+        }
+
+        final int tailPos = fv.lastIndexOf(':');
+        if (headPos != tailPos) {
+            throw new UDFArgumentException("Unsupported feature format: " + fv);
+        }
+        final String name = fv.substring(0, headPos);
+        final String value = fv.substring(headPos + 1);
+
+        // The value is parsed before the name so that a malformed value is reported first.
+        final double d;
+        try {
+            d = Double.parseDouble(value);
+        } catch (NumberFormatException e) {
+            throw new UDFArgumentException("Invalid feature value: " + fv);
+        }
+        return new Pair<>(toFeatureIndex(name, fv, numFeatures), d);
+    }
+
+    /**
+     * Resolves a feature name to its index: names accepted by {@code NumberUtils.isDigits} are
+     * parsed as integers, everything else is hashed.
+     *
+     * @param name the feature name part
+     * @param fv the complete feature string, used only for the error message
+     * @param numFeatures the modulus used when the name has to be hashed
+     * @return the feature index
+     * @throws UDFArgumentException if a digit-only name does not fit into an {@code int}
+     */
+    private static int toFeatureIndex(@Nonnull final String name, @Nonnull final String fv,
+            final int numFeatures) throws UDFArgumentException {
+        if (!NumberUtils.isDigits(name)) {
+            return mhash(name, numFeatures);
+        }
+        try {
+            return Integer.parseInt(name);
+        } catch (NumberFormatException e) {
+            throw new UDFArgumentException("Invalid feature value: " + fv);
+        }
+    }
+
+    /**
+     * Hashes a feature name into {@code [1, numFeatures]} (for a positive {@code numFeatures}).
+     */
+    private static int mhash(@Nonnull final String word, final int numFeatures) {
+        int r = MurmurHash3.murmurhash3_x86_32(word, 0, word.length(), 0x9747b28c) % numFeatures;
+        // Java's % keeps the sign of the dividend, so shift negative remainders up.
+        if (r < 0) {
+            r += numFeatures;
+        }
+        return r + 1;
+    }
+
+    /** Orders parsed features by ascending feature index. */
+    private static final Comparator<Pair<Integer, Double>> comparator =
+            new Comparator<Pair<Integer, Double>>() {
+                @Override
+                public int compare(Pair<Integer, Double> l, Pair<Integer, Double> r) {
+                    return l.getKey().compareTo(r.getKey());
+                }
+            };
+
+    @Override
+    public String getDisplayString(String[] args) {
+        return "to_libsvm_format( " + StringUtils.join(args, ',') + " )";
+    }
+}
diff --git a/core/src/test/java/hivemall/ftvec/conv/ToLibSVMFormatUDFTest.java b/core/src/test/java/hivemall/ftvec/conv/ToLibSVMFormatUDFTest.java
new file mode 100644
index 0000000..6a59058
--- /dev/null
+++ b/core/src/test/java/hivemall/ftvec/conv/ToLibSVMFormatUDFTest.java
@@ -0,0 +1,99 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.ftvec.conv;
+
+import java.io.IOException;
+import java.util.Arrays;
+
+import org.apache.hadoop.hive.ql.metadata.HiveException;
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDF.DeferredJavaObject;
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDF.DeferredObject;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
+import org.junit.Assert;
+import org.junit.Test;
+
+/**
+ * Unit tests for {@link ToLibSVMFormatUDF}, covering the features-only form and the forms that
+ * carry an integer or a float target. Every case passes the options string "-features 10".
+ */
+public class ToLibSVMFormatUDFTest {
+
+    /** Inspector for the {@code array<string>} feature argument. */
+    private static ObjectInspector featuresOI() {
+        return ObjectInspectorFactory.getStandardListObjectInspector(
+            PrimitiveObjectInspectorFactory.javaStringObjectInspector);
+    }
+
+    /** Inspector for a constant string argument carrying the given options. */
+    private static ObjectInspector optionsOI(String options) {
+        return ObjectInspectorUtils.getConstantObjectInspector(
+            PrimitiveObjectInspectorFactory.javaStringObjectInspector, options);
+    }
+
+    /** Wraps each value into a deferred argument, preserving order. */
+    private static DeferredObject[] deferred(Object... values) {
+        final DeferredObject[] wrapped = new DeferredObject[values.length];
+        for (int i = 0; i < values.length; i++) {
+            wrapped[i] = new DeferredJavaObject(values[i]);
+        }
+        return wrapped;
+    }
+
+    @Test
+    public void testFeatureOnly() throws IOException, HiveException {
+        ToLibSVMFormatUDF udf = new ToLibSVMFormatUDF();
+        udf.initialize(new ObjectInspector[] {featuresOI(), optionsOI("-features 10")});
+
+        // Named features are hashed, then sorted by index.
+        String hashed = udf.evaluate(deferred(Arrays.asList("apple:3.4", "orange:2.1")));
+        Assert.assertEquals("3:2.1 7:3.4", hashed);
+
+        // Numeric names are used as indices directly and still sorted.
+        String numeric = udf.evaluate(deferred(Arrays.asList("7:3.4", "3:2.1")));
+        Assert.assertEquals("3:2.1 7:3.4", numeric);
+
+        udf.close();
+    }
+
+    @Test
+    public void testFeatureAndIntLabel() throws IOException, HiveException {
+        ToLibSVMFormatUDF udf = new ToLibSVMFormatUDF();
+        udf.initialize(new ObjectInspector[] {featuresOI(),
+                PrimitiveObjectInspectorFactory.javaIntObjectInspector,
+                optionsOI("-features 10")});
+
+        String actual = udf.evaluate(deferred(Arrays.asList("apple:3.4", "orange:2.1"), 5));
+        Assert.assertEquals("5 3:2.1 7:3.4", actual);
+
+        udf.close();
+    }
+
+    @Test
+    public void testFeatureAndFloatLabel() throws IOException, HiveException {
+        ToLibSVMFormatUDF udf = new ToLibSVMFormatUDF();
+        udf.initialize(new ObjectInspector[] {featuresOI(),
+                PrimitiveObjectInspectorFactory.javaFloatObjectInspector,
+                optionsOI("-features 10")});
+
+        String actual = udf.evaluate(deferred(Arrays.asList("7:3.4", "3:2.1"), 5.f));
+        Assert.assertEquals("5.0 3:2.1 7:3.4", actual);
+
+        udf.close();
+    }
+
+}