You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@hivemall.apache.org by my...@apache.org on 2019/12/12 08:32:34 UTC
[incubator-hivemall] branch master updated: [HIVEMALL-288]
mf_predict throws SemanticException No matching method with (array<double>,
array<double>, int)
This is an automated email from the ASF dual-hosted git repository.
myui pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-hivemall.git
The following commit(s) were added to refs/heads/master by this push:
new f8a2b06 [HIVEMALL-288] mf_predict throws SemanticException No matching method with (array<double>, array<double>, int)
f8a2b06 is described below
commit f8a2b06de1d2c33d1ef1753b3ec8a42a48e6537d
Author: Makoto Yui <my...@apache.org>
AuthorDate: Thu Dec 12 17:32:27 2019 +0900
[HIVEMALL-288] mf_predict throws SemanticException No matching method with (array<double>, array<double>, int)
## What changes were proposed in this pull request?
`mf_predict` throws SemanticException No matching method with (array<double>, array<double>, int)
## What type of PR is it?
Bug Fix
## What is the Jira issue?
https://issues.apache.org/jira/browse/HIVEMALL-288
## How was this patch tested?
manual tests on EMR
```sql
select
-- 3 arguments
mf_predict(array(cast(1.0 as float),cast(2.0 as float),cast(3.0 as float)), array(cast(1.0 as float),cast(2.0 as float),cast(3.0 as float)), 1),
mf_predict(array(1.0,2.0,3.0), array(1.0,2.0,3.0), 1),
mf_predict(array(cast(1.0 as DOUBLE),cast(2.0 as DOUBLE),cast(3.0 as DOUBLE)), array(cast(1.0 as DOUBLE),cast(2.0 as DOUBLE),cast(3.0 as DOUBLE)), 1),
-- 2 arguments
mf_predict(array(1.0,2.0,3.0), array(1.0,2.0,3.0)),
-- 4 arguments
mf_predict(array(1.0,2.0,3.0), array(1.0,2.0,3.0), 0, 0),
-- 5 arguments
mf_predict(array(1.0,2.0,3.0), array(1.0,2.0,3.0), 0, 0, 1);
```
## Checklist
(Please remove this section if not needed; check `x` for YES, blank for NO)
- [x] Did you apply source code formatter, i.e., `./bin/format_code.sh`, for your commit?
- [x] Did you run system tests on Hive (or Spark)?
Author: Makoto Yui <my...@apache.org>
Closes #224 from myui/HIVEMALL-288.
---
.../hivemall/factorization/mf/MFPredictionUDF.java | 204 ++++++++++++---------
.../main/java/hivemall/utils/hadoop/HiveUtils.java | 20 ++
2 files changed, 142 insertions(+), 82 deletions(-)
diff --git a/core/src/main/java/hivemall/factorization/mf/MFPredictionUDF.java b/core/src/main/java/hivemall/factorization/mf/MFPredictionUDF.java
index c91e0eb..c73e96f 100644
--- a/core/src/main/java/hivemall/factorization/mf/MFPredictionUDF.java
+++ b/core/src/main/java/hivemall/factorization/mf/MFPredictionUDF.java
@@ -18,121 +18,161 @@
*/
package hivemall.factorization.mf;
-import java.util.List;
+import hivemall.utils.hadoop.HiveUtils;
+import hivemall.utils.lang.Preconditions;
-import javax.annotation.Nonnull;
import javax.annotation.Nullable;
+import org.apache.commons.lang.StringUtils;
import org.apache.hadoop.hive.ql.exec.Description;
-import org.apache.hadoop.hive.ql.exec.UDF;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.udf.UDFType;
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDF;
import org.apache.hadoop.hive.serde2.io.DoubleWritable;
-import org.apache.hadoop.io.FloatWritable;
+import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
@Description(name = "mf_predict",
- value = "_FUNC_(List<Float> Pu, List<Float> Qi[, double Bu, double Bi[, double mu]]) - Returns the prediction value")
+ value = "_FUNC_(array<double> Pu, array<double> Qi[, double Bu, double Bi[, double mu]]) - Returns the prediction value")
@UDFType(deterministic = true, stateful = false)
-public final class MFPredictionUDF extends UDF {
+public final class MFPredictionUDF extends GenericUDF {
- @Nonnull
- public DoubleWritable evaluate(@Nullable List<FloatWritable> Pu,
- @Nullable List<FloatWritable> Qi) throws HiveException {
- return evaluate(Pu, Qi, null);
- }
+ private ListObjectInspector puOI, qiOI;
+ private PrimitiveObjectInspector puElemOI, qiElemOI;
- @Nonnull
- public DoubleWritable evaluate(@Nullable List<FloatWritable> Pu,
- @Nullable List<FloatWritable> Qi, @Nullable DoubleWritable mu) throws HiveException {
- final double muValue = (mu == null) ? 0.d : mu.get();
- if (Pu == null || Qi == null) {
- return new DoubleWritable(muValue);
- }
+ @Nullable
+ private PrimitiveObjectInspector buOI, biOI, muOI;
- final int PuSize = Pu.size();
- final int QiSize = Qi.size();
- // workaround for TD
- if (PuSize == 0) {
- return new DoubleWritable(muValue);
- } else if (QiSize == 0) {
- return new DoubleWritable(muValue);
+ private DoubleWritable result;
+
+ @Override
+ public ObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException {
+ if (argOIs.length < 2 || argOIs.length > 5) {
+ throw new UDFArgumentException("mf_predict takes 2~5 arguments: " + argOIs.length);
}
- if (QiSize != PuSize) {
- throw new HiveException("|Pu| " + PuSize + " was not equal to |Qi| " + QiSize);
+ this.puOI = HiveUtils.asListOI(argOIs, 0);
+ this.puElemOI = HiveUtils.asFloatingPointOI(puOI.getListElementObjectInspector());
+ this.qiOI = HiveUtils.asListOI(argOIs, 1);
+ this.qiElemOI = HiveUtils.asFloatingPointOI(qiOI.getListElementObjectInspector());
+
+ switch (argOIs.length) {
+ case 3:
+ this.muOI = HiveUtils.asNumberOI(argOIs, 2);
+ break;
+ case 4:
+ this.buOI = HiveUtils.asNumberOI(argOIs, 2);
+ this.biOI = HiveUtils.asNumberOI(argOIs, 3);
+ break;
+ case 5:
+ this.buOI = HiveUtils.asNumberOI(argOIs, 2);
+ this.biOI = HiveUtils.asNumberOI(argOIs, 3);
+ this.muOI = HiveUtils.asNumberOI(argOIs, 4);
+ break;
+ default:
+ break;
}
- double ret = muValue;
- for (int k = 0; k < PuSize; k++) {
- FloatWritable Pu_k = Pu.get(k);
- if (Pu_k == null) {
- continue;
+ this.result = new DoubleWritable();
+ return PrimitiveObjectInspectorFactory.writableDoubleObjectInspector;
+ }
+
+ @Override
+ public Object evaluate(DeferredObject[] args) throws HiveException {
+ Preconditions.checkArgument(args.length >= 2 && args.length <= 5, args.length);
+
+ @Nullable
+ double[] pu = HiveUtils.asDoubleArray(args[0].get(), puOI, puElemOI);
+ @Nullable
+ double[] qi = HiveUtils.asDoubleArray(args[1].get(), qiOI, qiElemOI);
+
+ double mu = 0.d, bu = 0.d, bi = 0.d;
+ switch (args.length) {
+ case 3: {
+ Object arg2 = args[2].get();
+ if (arg2 != null) {
+ mu = PrimitiveObjectInspectorUtils.getDouble(arg2, muOI);
+ }
+ break;
+ }
+ case 4: {
+ Object arg2 = args[2].get();
+ if (arg2 != null) {
+ bu = PrimitiveObjectInspectorUtils.getDouble(arg2, buOI);
+ }
+ Object arg3 = args[3].get();
+ if (arg3 != null) {
+ bi = PrimitiveObjectInspectorUtils.getDouble(arg3, biOI);
+ }
+ break;
}
- FloatWritable Qi_k = Qi.get(k);
- if (Qi_k == null) {
- continue;
+ case 5: {
+ Object arg2 = args[2].get();
+ if (arg2 != null) {
+ bu = PrimitiveObjectInspectorUtils.getDouble(arg2, buOI);
+ }
+ Object arg3 = args[3].get();
+ if (arg3 != null) {
+ bi = PrimitiveObjectInspectorUtils.getDouble(arg3, biOI);
+ }
+ Object arg4 = args[4].get();
+ if (arg4 != null) {
+ mu = PrimitiveObjectInspectorUtils.getDouble(arg4, muOI);
+ }
+ break;
}
- ret += Pu_k.get() * Qi_k.get();
+ default:
+ break;
}
- return new DoubleWritable(ret);
- }
- @Nonnull
- public DoubleWritable evaluate(@Nullable List<FloatWritable> Pu,
- @Nullable List<FloatWritable> Qi, @Nullable DoubleWritable Bu,
- @Nullable DoubleWritable Bi) throws HiveException {
- return evaluate(Pu, Qi, Bu, Bi, null);
+ double predicted = mfPredict(pu, qi, bu, bi, mu);
+ result.set(predicted);
+ return result;
}
- @Nonnull
- public DoubleWritable evaluate(@Nullable List<FloatWritable> Pu,
- @Nullable List<FloatWritable> Qi, @Nullable DoubleWritable Bu,
- @Nullable DoubleWritable Bi, @Nullable DoubleWritable mu) throws HiveException {
- final double muValue = (mu == null) ? 0.d : mu.get();
- if (Pu == null && Qi == null) {
- return new DoubleWritable(muValue);
- }
- final double BiValue = (Bi == null) ? 0.d : Bi.get();
- final double BuValue = (Bu == null) ? 0.d : Bu.get();
+ private static double mfPredict(@Nullable final double[] Pu, @Nullable final double[] Qi,
+ final double Bu, final double Bi, final double mu) throws UDFArgumentException {
if (Pu == null) {
- double ret = muValue + BiValue;
- return new DoubleWritable(ret);
+ if (Qi == null) {
+ return mu;
+ } else {
+ return mu + Bi;
+ }
} else if (Qi == null) {
- return new DoubleWritable(muValue);
+ return mu + Bu;
}
-
- final int PuSize = Pu.size();
- final int QiSize = Qi.size();
- // workaround for TD
- if (PuSize == 0) {
- if (QiSize == 0) {
- return new DoubleWritable(muValue);
+ // workaround for TD
+ if (Pu.length == 0) {
+ if (Qi.length == 0) {
+ return mu;
} else {
- double ret = muValue + BiValue;
- return new DoubleWritable(ret);
+ return mu + Bi;
}
- } else if (QiSize == 0) {
- double ret = muValue + BuValue;
- return new DoubleWritable(ret);
+ } else if (Qi.length == 0) {
+ return mu + Bu;
}
- if (QiSize != PuSize) {
- throw new HiveException("|Pu| " + PuSize + " was not equal to |Qi| " + QiSize);
+ if (Pu.length != Qi.length) {
+ throw new UDFArgumentException(
+ "|Pu| " + Pu.length + " was not equal to |Qi| " + Qi.length);
}
- double ret = muValue + BuValue + BiValue;
- for (int k = 0; k < PuSize; k++) {
- FloatWritable Pu_k = Pu.get(k);
- if (Pu_k == null) {
- continue;
- }
- FloatWritable Qi_k = Qi.get(k);
- if (Qi_k == null) {
- continue;
- }
- ret += Pu_k.get() * Qi_k.get();
+ double ret = mu + Bu + Bi;
+ for (int k = 0, size = Pu.length; k < size; k++) {
+ double pu_k = Pu[k];
+ double qi_k = Qi[k];
+ ret += pu_k * qi_k;
}
- return new DoubleWritable(ret);
+ return ret;
+ }
+
+ @Override
+ public String getDisplayString(String[] args) {
+ return "mf_predict(" + StringUtils.join(args, ',') + ')';
}
}
diff --git a/core/src/main/java/hivemall/utils/hadoop/HiveUtils.java b/core/src/main/java/hivemall/utils/hadoop/HiveUtils.java
index 38b37a4..293d236 100644
--- a/core/src/main/java/hivemall/utils/hadoop/HiveUtils.java
+++ b/core/src/main/java/hivemall/utils/hadoop/HiveUtils.java
@@ -1302,6 +1302,26 @@ public final class HiveUtils {
}
@Nonnull
+ public static PrimitiveObjectInspector asNumberOI(@Nonnull final ObjectInspector[] argOIs,
+ final int argIndex) throws UDFArgumentException {
+ final PrimitiveObjectInspector oi = asPrimitiveObjectInspector(argOIs, argIndex);
+ switch (oi.getPrimitiveCategory()) {
+ case BYTE:
+ case SHORT:
+ case INT:
+ case LONG:
+ case FLOAT:
+ case DOUBLE:
+ case DECIMAL:
+ break;
+ default:
+ throw new UDFArgumentTypeException(argIndex,
+ "Only numeric argument is accepted but " + oi.getTypeName() + " is passed.");
+ }
+ return oi;
+ }
+
+ @Nonnull
public static PrimitiveObjectInspector asNumberOI(@Nonnull final ObjectInspector argOI)
throws UDFArgumentTypeException {
if (argOI.getCategory() != Category.PRIMITIVE) {