You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@hivemall.apache.org by my...@apache.org on 2017/07/26 05:48:59 UTC
[1/8] incubator-hivemall git commit: Implement -ffm option in `feature_pairs`
Repository: incubator-hivemall
Updated Branches:
refs/heads/HIVEMALL-24-2 [created] f70e7c52e
Implement -ffm option in `feature_pairs`
Project: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/commit/591e3b0f
Tree: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/tree/591e3b0f
Diff: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/diff/591e3b0f
Branch: refs/heads/HIVEMALL-24-2
Commit: 591e3b0f255e6523167157ed2e68c9499b4ca2cd
Parents: 1cccf66
Author: Takuya Kitazawa <k....@gmail.com>
Authored: Fri Mar 3 17:48:12 2017 +0900
Committer: Takuya Kitazawa <k....@gmail.com>
Committed: Fri Mar 3 17:48:12 2017 +0900
----------------------------------------------------------------------
.../ftvec/pairing/FeaturePairsUDTF.java | 134 +++++++++++++++----
1 file changed, 109 insertions(+), 25 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/591e3b0f/core/src/main/java/hivemall/ftvec/pairing/FeaturePairsUDTF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/ftvec/pairing/FeaturePairsUDTF.java b/core/src/main/java/hivemall/ftvec/pairing/FeaturePairsUDTF.java
index 6aebd64..d814b40 100644
--- a/core/src/main/java/hivemall/ftvec/pairing/FeaturePairsUDTF.java
+++ b/core/src/main/java/hivemall/ftvec/pairing/FeaturePairsUDTF.java
@@ -20,6 +20,7 @@ package hivemall.ftvec.pairing;
import hivemall.UDTFWithOptions;
import hivemall.model.FeatureValue;
+import hivemall.fm.Feature;
import hivemall.utils.hadoop.HiveUtils;
import hivemall.utils.hashing.HashFunction;
import hivemall.utils.lang.Preconditions;
@@ -29,6 +30,7 @@ import java.util.List;
import javax.annotation.Nonnull;
+import hivemall.utils.lang.Primitives;
import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.Options;
import org.apache.hadoop.hive.ql.exec.Description;
@@ -50,6 +52,8 @@ public final class FeaturePairsUDTF extends UDTFWithOptions {
private Type _type;
private RowProcessor _proc;
+ private int _numFields;
+ private int _numFeatures;
public FeaturePairsUDTF() {}
@@ -60,6 +64,9 @@ public final class FeaturePairsUDTF extends UDTFWithOptions {
"Generate feature pairs for Kernel-Expansion Passive Aggressive [default:true]");
opts.addOption("ffm", false,
"Generate feature pairs for Field-aware Factorization Machines [default:false]");
+ opts.addOption("feature_hashing", true,
+ "The number of bits for feature hashing in range [18,31] [default:21]");
+ opts.addOption("num_fields", true, "The number of fields [default:1024]");
return opts;
}
@@ -70,13 +77,17 @@ public final class FeaturePairsUDTF extends UDTFWithOptions {
String args = HiveUtils.getConstString(argOIs[1]);
cl = parseOptions(args);
- Preconditions.checkArgument(cl.getOptions().length == 1, UDFArgumentException.class,
- "Only one option can be specified: " + cl.getArgList());
+ Preconditions.checkArgument(cl.getOptions().length <= 3, UDFArgumentException.class,
+ "Too many options were specified: " + cl.getArgList());
if (cl.hasOption("kpa")) {
this._type = Type.kpa;
} else if (cl.hasOption("ffm")) {
this._type = Type.ffm;
+ int featureBits = Primitives.parseInt(
+ cl.getOptionValue("feature_hashing"), Feature.DEFAULT_FEATURE_BITS);
+ this._numFeatures = 1 << featureBits;
+ this._numFields = Primitives.parseInt(cl.getOptionValue("num_fields"), Feature.DEFAULT_NUM_FIELDS);
} else {
throw new UDFArgumentException("Unsupported option: " + cl.getArgList().get(0));
}
@@ -113,8 +124,16 @@ public final class FeaturePairsUDTF extends UDTFWithOptions {
break;
}
case ffm: {
- throw new UDFArgumentException("-ffm is not supported yet");
- //break;
+ this._proc = new FFMProcessor(fvOI);
+ fieldNames.add("i"); // <ei, jField> index
+ fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);
+ fieldNames.add("j"); // <ej, iField> index
+ fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);
+ fieldNames.add("xi");
+ fieldOIs.add(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector);
+ fieldNames.add("xj");
+ fieldOIs.add(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector);
+ break;
}
default:
throw new UDFArgumentException("Illegal condition: " + _type);
@@ -144,26 +163,7 @@ public final class FeaturePairsUDTF extends UDTFWithOptions {
this.fvOI = fvOI;
}
- void process(@Nonnull Object arg) throws HiveException {
- final int size = fvOI.getListLength(arg);
- if (size == 0) {
- return;
- }
-
- final List<FeatureValue> features = new ArrayList<FeatureValue>(size);
- for (int i = 0; i < size; i++) {
- Object f = fvOI.getListElement(arg, i);
- if (f == null) {
- continue;
- }
- FeatureValue fv = FeatureValue.parse(f, true);
- features.add(fv);
- }
-
- process(features);
- }
-
- abstract void process(@Nonnull List<FeatureValue> features) throws HiveException;
+ abstract void process(@Nonnull Object arg) throws HiveException;
}
@@ -186,7 +186,22 @@ public final class FeaturePairsUDTF extends UDTFWithOptions {
}
@Override
- void process(@Nonnull List<FeatureValue> features) throws HiveException {
+ void process(@Nonnull Object arg) throws HiveException {
+ final int size = fvOI.getListLength(arg);
+ if (size == 0) {
+ return;
+ }
+
+ final List<FeatureValue> features = new ArrayList<FeatureValue>(size);
+ for (int i = 0; i < size; i++) {
+ Object f = fvOI.getListElement(arg, i);
+ if (f == null) {
+ continue;
+ }
+ FeatureValue fv = FeatureValue.parse(f, true);
+ features.add(fv);
+ }
+
forward[0] = f0;
f0.set(0);
forward[1] = null;
@@ -222,6 +237,75 @@ public final class FeaturePairsUDTF extends UDTFWithOptions {
}
}
+ final class FFMProcessor extends RowProcessor {
+
+ @Nonnull
+ private final IntWritable f0, f1;
+ @Nonnull
+ private final DoubleWritable f2, f3;
+ @Nonnull
+ private final Writable[] forward;
+
+ FFMProcessor(@Nonnull ListObjectInspector fvOI) {
+ super(fvOI);
+ this.f0 = new IntWritable();
+ this.f1 = new IntWritable();
+ this.f2 = new DoubleWritable();
+ this.f3 = new DoubleWritable();
+ this.forward = new Writable[] {f0, null, null, null};
+ }
+
+ @Override
+ void process(@Nonnull Object arg) throws HiveException {
+ final int size = fvOI.getListLength(arg);
+ if (size == 0) {
+ return;
+ }
+
+ final Feature[] features = Feature.parseFFMFeatures(arg, fvOI, null, _numFeatures, _numFields);
+
+ // W0
+ forward[0] = f0;
+ f0.set(0);
+ forward[1] = null;
+ forward[2] = null;
+ forward[3] = null;
+ forward(forward);
+
+ forward[2] = f2;
+ for (int i = 0, len = features.length; i < len; i++) {
+ Feature ei = features[i];
+ double xi = ei.getValue();
+ int iField = ei.getField();
+
+ // Wi
+ forward[0] = f0;
+ f0.set(i);
+ forward[1] = null;
+ f2.set(xi);
+ forward[3] = null;
+ forward(forward);
+
+ forward[1] = f1;
+ forward[3] = f3;
+ for (int j = i + 1; j < len; j++) {
+ Feature ej = features[j];
+ double xj = ej.getValue();
+ int jField = ej.getField();
+
+ int ifj = Feature.toIntFeature(ei, jField, _numFields);
+ int jfi = Feature.toIntFeature(ej, iField, _numFields);
+
+ // Vifj, Vjfi
+ f0.set(ifj);
+ f1.set(jfi);
+ // `f2` is consistently set to `xi`
+ f3.set(xj);
+ forward(forward);
+ }
+ }
+ }
+ }
@Override
public void close() throws HiveException {
[3/8] incubator-hivemall git commit: Merge branch 'ffm-predict-udaf' of https://github.com/takuti/incubator-hivemall into HIVEMALL-24-2
Posted by my...@apache.org.
Merge branch 'ffm-predict-udaf' of https://github.com/takuti/incubator-hivemall into HIVEMALL-24-2
Project: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/commit/82fa5810
Tree: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/tree/82fa5810
Diff: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/diff/82fa5810
Branch: refs/heads/HIVEMALL-24-2
Commit: 82fa5810c459c89c56cb83b06c62a044e34e2dee
Parents: 7205de1 343f704
Author: Makoto Yui <my...@apache.org>
Authored: Fri Jul 21 15:38:30 2017 +0900
Committer: Makoto Yui <my...@apache.org>
Committed: Fri Jul 21 15:38:30 2017 +0900
----------------------------------------------------------------------
.../java/hivemall/fm/FFMPredictGenericUDAF.java | 272 +++++++++++++++++++
.../ftvec/pairing/FeaturePairsUDTF.java | 134 +++++++--
2 files changed, 381 insertions(+), 25 deletions(-)
----------------------------------------------------------------------
[6/8] incubator-hivemall git commit: Changed FFM prediction model as a scalable format
Posted by my...@apache.org.
Changed FFM prediction model as a scalable format
Project: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/commit/550bb4e6
Tree: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/tree/550bb4e6
Diff: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/diff/550bb4e6
Branch: refs/heads/HIVEMALL-24-2
Commit: 550bb4e6f0f69cfd25506c34caf67e3c014d5750
Parents: 36a6ca2
Author: Makoto Yui <my...@apache.org>
Authored: Tue Jul 25 20:03:29 2017 +0900
Committer: Makoto Yui <my...@apache.org>
Committed: Tue Jul 25 20:03:29 2017 +0900
----------------------------------------------------------------------
core/src/main/java/hivemall/fm/Entry.java | 2 +-
.../java/hivemall/fm/FFMPredictGenericUDAF.java | 21 +-
.../main/java/hivemall/fm/FFMPredictUDF.java | 187 ----------
.../java/hivemall/fm/FFMPredictionModel.java | 349 -------------------
.../hivemall/fm/FFMStringFeatureMapModel.java | 53 ++-
.../fm/FieldAwareFactorizationMachineUDTF.java | 82 +++--
.../hivemall/fm/FFMPredictionModelTest.java | 65 ----
resources/ddl/define-all-as-permanent.hive | 2 +-
resources/ddl/define-all.hive | 2 +-
resources/ddl/define-all.spark | 2 +-
resources/ddl/define-udfs.td.hql | 2 +
11 files changed, 111 insertions(+), 656 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/550bb4e6/core/src/main/java/hivemall/fm/Entry.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/fm/Entry.java b/core/src/main/java/hivemall/fm/Entry.java
index 1882f85..209112c 100644
--- a/core/src/main/java/hivemall/fm/Entry.java
+++ b/core/src/main/java/hivemall/fm/Entry.java
@@ -58,7 +58,7 @@ class Entry {
return _offset;
}
- void setOffset(long offset) {
+ void setOffset(final long offset) {
this._offset = offset;
}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/550bb4e6/core/src/main/java/hivemall/fm/FFMPredictGenericUDAF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/fm/FFMPredictGenericUDAF.java b/core/src/main/java/hivemall/fm/FFMPredictGenericUDAF.java
index 91d1b6b..a37a1b8 100644
--- a/core/src/main/java/hivemall/fm/FFMPredictGenericUDAF.java
+++ b/core/src/main/java/hivemall/fm/FFMPredictGenericUDAF.java
@@ -19,7 +19,12 @@
package hivemall.fm;
import hivemall.utils.hadoop.HiveUtils;
-import hivemall.utils.hadoop.WritableUtils;
+
+import java.util.ArrayList;
+import java.util.List;
+
+import javax.annotation.Nonnull;
+
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.exec.UDFArgumentLengthException;
@@ -30,20 +35,18 @@ import org.apache.hadoop.hive.ql.udf.generic.AbstractGenericUDAFResolver;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator.AbstractAggregationBuffer;
import org.apache.hadoop.hive.serde2.io.DoubleWritable;
-import org.apache.hadoop.hive.serde2.lazybinary.LazyBinaryArray;
-import org.apache.hadoop.hive.serde2.objectinspector.*;
+import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector.Category;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
+import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.StructField;
+import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
-import org.apache.hadoop.hive.serde2.objectinspector.primitive.WritableDoubleObjectInspector;
import org.apache.hadoop.hive.serde2.typeinfo.ListTypeInfo;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
-import javax.annotation.Nonnull;
-import javax.annotation.Nullable;
-import java.util.ArrayList;
-import java.util.List;
-
@Description(
name = "ffm_predict",
value = "_FUNC_(Float Wi, Float Wj, array<float> Vifj, array<float> Vjfi, float Xi, float Xj)"
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/550bb4e6/core/src/main/java/hivemall/fm/FFMPredictUDF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/fm/FFMPredictUDF.java b/core/src/main/java/hivemall/fm/FFMPredictUDF.java
deleted file mode 100644
index 48745d9..0000000
--- a/core/src/main/java/hivemall/fm/FFMPredictUDF.java
+++ /dev/null
@@ -1,187 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package hivemall.fm;
-
-import hivemall.annotations.Experimental;
-import hivemall.utils.hadoop.HiveUtils;
-import hivemall.utils.lang.NumberUtils;
-
-import java.io.IOException;
-import java.util.Arrays;
-
-import javax.annotation.Nonnull;
-import javax.annotation.Nullable;
-
-import org.apache.hadoop.hive.ql.exec.Description;
-import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
-import org.apache.hadoop.hive.ql.metadata.HiveException;
-import org.apache.hadoop.hive.ql.udf.UDFType;
-import org.apache.hadoop.hive.ql.udf.generic.GenericUDF;
-import org.apache.hadoop.hive.serde2.io.DoubleWritable;
-import org.apache.hadoop.hive.serde2.lazybinary.LazyBinaryArray;
-import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
-import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
-import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
-import org.apache.hadoop.hive.serde2.objectinspector.primitive.StringObjectInspector;
-import org.apache.hadoop.io.Text;
-
-/**
- * @since v0.5-rc.1
- */
-@Description(name = "ffm_predict",
- value = "_FUNC_(string modelId, string model, array<string> features)"
- + " returns a prediction result in double from a Field-aware Factorization Machine")
-@UDFType(deterministic = true, stateful = false)
-@Experimental
-public final class FFMPredictUDF extends GenericUDF {
-
- private StringObjectInspector _modelIdOI;
- private StringObjectInspector _modelOI;
- private ListObjectInspector _featureListOI;
-
- private DoubleWritable _result;
- @Nullable
- private String _cachedModeId;
- @Nullable
- private FFMPredictionModel _cachedModel;
- @Nullable
- private Feature[] _probes;
-
- public FFMPredictUDF() {}
-
- @Override
- public ObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException {
- if (argOIs.length != 3) {
- throw new UDFArgumentException("_FUNC_ takes 3 arguments");
- }
- this._modelIdOI = HiveUtils.asStringOI(argOIs[0]);
- this._modelOI = HiveUtils.asStringOI(argOIs[1]);
- this._featureListOI = HiveUtils.asListOI(argOIs[2]);
-
- this._result = new DoubleWritable();
- return PrimitiveObjectInspectorFactory.writableDoubleObjectInspector;
- }
-
- @Override
- public Object evaluate(DeferredObject[] args) throws HiveException {
- String modelId = _modelIdOI.getPrimitiveJavaObject(args[0].get());
- if (modelId == null) {
- throw new HiveException("modelId is not set");
- }
-
- final FFMPredictionModel model;
- if (modelId.equals(_cachedModeId)) {
- model = this._cachedModel;
- } else {
- Text serModel = _modelOI.getPrimitiveWritableObject(args[1].get());
- if (serModel == null) {
- throw new HiveException("Model is null for model ID: " + modelId);
- }
- byte[] b = serModel.getBytes();
- final int length = serModel.getLength();
- try {
- model = FFMPredictionModel.deserialize(b, length);
- b = null;
- } catch (ClassNotFoundException e) {
- throw new HiveException(e);
- } catch (IOException e) {
- throw new HiveException(e);
- }
- this._cachedModeId = modelId;
- this._cachedModel = model;
- }
-
- int numFeatures = model.getNumFeatures();
- int numFields = model.getNumFields();
-
- Object arg2 = args[2].get();
- // [workaround]
- // java.lang.ClassCastException: org.apache.hadoop.hive.serde2.lazybinary.LazyBinaryArray
- // cannot be cast to [Ljava.lang.Object;
- if (arg2 instanceof LazyBinaryArray) {
- arg2 = ((LazyBinaryArray) arg2).getList();
- }
- Feature[] x = Feature.parseFFMFeatures(arg2, _featureListOI, _probes, numFeatures,
- numFields);
- if (x == null || x.length == 0) {
- return null; // return NULL if there are no features
- }
- this._probes = x;
-
- double predicted = predict(x, model);
- _result.set(predicted);
- return _result;
- }
-
- private static double predict(@Nonnull final Feature[] x,
- @Nonnull final FFMPredictionModel model) throws HiveException {
- // w0
- double ret = model.getW0();
- // W
- for (Feature e : x) {
- double xi = e.getValue();
- float wi = model.getW(e);
- double wx = wi * xi;
- ret += wx;
- }
- // V
- final int factors = model.getNumFactors();
- final float[] vij = new float[factors];
- final float[] vji = new float[factors];
- for (int i = 0; i < x.length; ++i) {
- final Feature ei = x[i];
- final double xi = ei.getValue();
- final int iField = ei.getField();
- for (int j = i + 1; j < x.length; ++j) {
- final Feature ej = x[j];
- final double xj = ej.getValue();
- final int jField = ej.getField();
- if (!model.getV(ei, jField, vij)) {
- continue;
- }
- if (!model.getV(ej, iField, vji)) {
- continue;
- }
- for (int f = 0; f < factors; f++) {
- float vijf = vij[f];
- float vjif = vji[f];
- ret += vijf * vjif * xi * xj;
- }
- }
- }
- if (!NumberUtils.isFinite(ret)) {
- throw new HiveException("Detected " + ret + " in ffm_predict");
- }
- return ret;
- }
-
- @Override
- public void close() throws IOException {
- super.close();
- // clean up to help GC
- this._cachedModel = null;
- this._probes = null;
- }
-
- @Override
- public String getDisplayString(String[] args) {
- return "ffm_predict(" + Arrays.toString(args) + ")";
- }
-
-}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/550bb4e6/core/src/main/java/hivemall/fm/FFMPredictionModel.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/fm/FFMPredictionModel.java b/core/src/main/java/hivemall/fm/FFMPredictionModel.java
deleted file mode 100644
index befbec9..0000000
--- a/core/src/main/java/hivemall/fm/FFMPredictionModel.java
+++ /dev/null
@@ -1,349 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package hivemall.fm;
-
-import hivemall.utils.buffer.HeapBuffer;
-import hivemall.utils.codec.VariableByteCodec;
-import hivemall.utils.codec.ZigZagLEB128Codec;
-import hivemall.utils.collections.maps.Int2LongOpenHashTable;
-import hivemall.utils.collections.maps.IntOpenHashTable;
-import hivemall.utils.io.CompressionStreamFactory.CompressionAlgorithm;
-import hivemall.utils.io.IOUtils;
-import hivemall.utils.lang.ArrayUtils;
-import hivemall.utils.lang.HalfFloat;
-import hivemall.utils.lang.ObjectUtils;
-
-import java.io.DataInput;
-import java.io.DataOutput;
-import java.io.Externalizable;
-import java.io.IOException;
-import java.io.ObjectInput;
-import java.io.ObjectOutput;
-import java.util.Arrays;
-
-import javax.annotation.Nonnull;
-import javax.annotation.Nullable;
-
-import org.apache.commons.logging.Log;
-import org.apache.commons.logging.LogFactory;
-
-public final class FFMPredictionModel implements Externalizable {
- private static final Log LOG = LogFactory.getLog(FFMPredictionModel.class);
-
- private static final byte HALF_FLOAT_ENTRY = 1;
- private static final byte W_ONLY_HALF_FLOAT_ENTRY = 2;
- private static final byte FLOAT_ENTRY = 3;
- private static final byte W_ONLY_FLOAT_ENTRY = 4;
-
- /**
- * maps feature to feature weight pointer
- */
- private Int2LongOpenHashTable _map;
- private HeapBuffer _buf;
-
- private double _w0;
- private int _factors;
- private int _numFeatures;
- private int _numFields;
-
- public FFMPredictionModel() {}// for Externalizable
-
- public FFMPredictionModel(@Nonnull Int2LongOpenHashTable map, @Nonnull HeapBuffer buf,
- double w0, int factor, int numFeatures, int numFields) {
- this._map = map;
- this._buf = buf;
- this._w0 = w0;
- this._factors = factor;
- this._numFeatures = numFeatures;
- this._numFields = numFields;
- }
-
- public int getNumFactors() {
- return _factors;
- }
-
- public double getW0() {
- return _w0;
- }
-
- public int getNumFeatures() {
- return _numFeatures;
- }
-
- public int getNumFields() {
- return _numFields;
- }
-
- public int getActualNumFeatures() {
- return _map.size();
- }
-
- public long approxBytesConsumed() {
- int size = _map.size();
-
- // [map] size * (|state| + |key| + |entry|)
- long bytes = size * (1L + 4L + 4L + (4L * _factors));
- int rest = _map.capacity() - size;
- if (rest > 0) {
- bytes += rest * 1L;
- }
- // w0, factors, numFeatures, numFields, used, size
- bytes += (8 + 4 + 4 + 4 + 4 + 4);
- return bytes;
- }
-
- @Nullable
- private Entry getEntry(final int key) {
- final long ptr = _map.get(key);
- if (ptr == -1L) {
- return null;
- }
- return new Entry(_buf, _factors, ptr);
- }
-
- public float getW(@Nonnull final Feature x) {
- int j = x.getFeatureIndex();
-
- Entry entry = getEntry(j);
- if (entry == null) {
- return 0.f;
- }
- return entry.getW();
- }
-
- /**
- * @return true if V exists
- */
- public boolean getV(@Nonnull final Feature x, @Nonnull final int yField, @Nonnull float[] dst) {
- int j = Feature.toIntFeature(x, yField, _numFields);
-
- Entry entry = getEntry(j);
- if (entry == null) {
- return false;
- }
-
- entry.getV(dst);
- if (ArrayUtils.equals(dst, 0.f)) {
- return false; // treat as null
- }
- return true;
- }
-
- @Override
- public void writeExternal(@Nonnull ObjectOutput out) throws IOException {
- out.writeDouble(_w0);
- final int factors = _factors;
- out.writeInt(factors);
- out.writeInt(_numFeatures);
- out.writeInt(_numFields);
-
- int used = _map.size();
- out.writeInt(used);
-
- final int[] keys = _map.getKeys();
- final int size = keys.length;
- out.writeInt(size);
-
- final byte[] states = _map.getStates();
- writeStates(states, out);
-
- final long[] values = _map.getValues();
-
- final HeapBuffer buf = _buf;
- final Entry e = new Entry(buf, factors);
- final float[] Vf = new float[factors];
- for (int i = 0; i < size; i++) {
- if (states[i] != IntOpenHashTable.FULL) {
- continue;
- }
- ZigZagLEB128Codec.writeSignedInt(keys[i], out);
- e.setOffset(values[i]);
- writeEntry(e, factors, Vf, out);
- }
-
- // help GC
- this._map = null;
- this._buf = null;
- }
-
- private static void writeEntry(@Nonnull final Entry e, final int factors,
- @Nonnull final float[] Vf, @Nonnull final DataOutput out) throws IOException {
- final float W = e.getW();
- e.getV(Vf);
-
- if (ArrayUtils.almostEquals(Vf, 0.f)) {
- if (HalfFloat.isRepresentable(W)) {
- out.writeByte(W_ONLY_HALF_FLOAT_ENTRY);
- out.writeShort(HalfFloat.floatToHalfFloat(W));
- } else {
- out.writeByte(W_ONLY_FLOAT_ENTRY);
- out.writeFloat(W);
- }
- } else if (isRepresentableAsHalfFloat(W, Vf)) {
- out.writeByte(HALF_FLOAT_ENTRY);
- out.writeShort(HalfFloat.floatToHalfFloat(W));
- for (int i = 0; i < factors; i++) {
- out.writeShort(HalfFloat.floatToHalfFloat(Vf[i]));
- }
- } else {
- out.writeByte(FLOAT_ENTRY);
- out.writeFloat(W);
- IOUtils.writeFloats(Vf, factors, out);
- }
- }
-
- private static boolean isRepresentableAsHalfFloat(final float W, @Nonnull final float[] Vf) {
- if (!HalfFloat.isRepresentable(W)) {
- return false;
- }
- for (float V : Vf) {
- if (!HalfFloat.isRepresentable(V)) {
- return false;
- }
- }
- return true;
- }
-
- @Nonnull
- static void writeStates(@Nonnull final byte[] status, @Nonnull final DataOutput out)
- throws IOException {
- // write empty states's indexes differentially
- final int size = status.length;
- int cardinarity = 0;
- for (int i = 0; i < size; i++) {
- if (status[i] != IntOpenHashTable.FULL) {
- cardinarity++;
- }
- }
- out.writeInt(cardinarity);
- if (cardinarity == 0) {
- return;
- }
- int prev = 0;
- for (int i = 0; i < size; i++) {
- if (status[i] != IntOpenHashTable.FULL) {
- int diff = i - prev;
- assert (diff >= 0);
- VariableByteCodec.encodeUnsignedInt(diff, out);
- prev = i;
- }
- }
- }
-
- @Override
- public void readExternal(@Nonnull final ObjectInput in) throws IOException,
- ClassNotFoundException {
- this._w0 = in.readDouble();
- final int factors = in.readInt();
- this._factors = factors;
- this._numFeatures = in.readInt();
- this._numFields = in.readInt();
-
- final int used = in.readInt();
- final int size = in.readInt();
- final int[] keys = new int[size];
- final long[] values = new long[size];
- final byte[] states = new byte[size];
- readStates(in, states);
-
- final int entrySize = Entry.sizeOf(factors);
- int numChunks = (entrySize * used) / HeapBuffer.DEFAULT_CHUNK_BYTES + 1;
- final HeapBuffer buf = new HeapBuffer(HeapBuffer.DEFAULT_CHUNK_SIZE, numChunks);
- final Entry e = new Entry(buf, factors);
- final float[] Vf = new float[factors];
- for (int i = 0; i < size; i++) {
- if (states[i] != IntOpenHashTable.FULL) {
- continue;
- }
- keys[i] = ZigZagLEB128Codec.readSignedInt(in);
- long ptr = buf.allocate(entrySize);
- e.setOffset(ptr);
- readEntry(in, factors, Vf, e);
- values[i] = ptr;
- }
-
- this._map = new Int2LongOpenHashTable(keys, values, states, used);
- this._buf = buf;
- }
-
- @Nonnull
- private static void readEntry(@Nonnull final DataInput in, final int factors,
- @Nonnull final float[] Vf, @Nonnull Entry dst) throws IOException {
- final byte type = in.readByte();
- switch (type) {
- case HALF_FLOAT_ENTRY: {
- float W = HalfFloat.halfFloatToFloat(in.readShort());
- dst.setW(W);
- for (int i = 0; i < factors; i++) {
- Vf[i] = HalfFloat.halfFloatToFloat(in.readShort());
- }
- dst.setV(Vf);
- break;
- }
- case W_ONLY_HALF_FLOAT_ENTRY: {
- float W = HalfFloat.halfFloatToFloat(in.readShort());
- dst.setW(W);
- break;
- }
- case FLOAT_ENTRY: {
- float W = in.readFloat();
- dst.setW(W);
- IOUtils.readFloats(in, Vf);
- dst.setV(Vf);
- break;
- }
- case W_ONLY_FLOAT_ENTRY: {
- float W = in.readFloat();
- dst.setW(W);
- break;
- }
- default:
- throw new IOException("Unexpected Entry type: " + type);
- }
- }
-
- @Nonnull
- static void readStates(@Nonnull final DataInput in, @Nonnull final byte[] status)
- throws IOException {
- // read non-empty states differentially
- final int cardinarity = in.readInt();
- Arrays.fill(status, IntOpenHashTable.FULL);
- int prev = 0;
- for (int j = 0; j < cardinarity; j++) {
- int i = VariableByteCodec.decodeUnsignedInt(in) + prev;
- status[i] = IntOpenHashTable.FREE;
- prev = i;
- }
- }
-
- public byte[] serialize() throws IOException {
- LOG.info("FFMPredictionModel#serialize(): " + _buf.toString());
- return ObjectUtils.toCompressedBytes(this, CompressionAlgorithm.lzma2, true);
- }
-
- public static FFMPredictionModel deserialize(@Nonnull final byte[] serializedObj, final int len)
- throws ClassNotFoundException, IOException {
- FFMPredictionModel model = new FFMPredictionModel();
- ObjectUtils.readCompressedObject(serializedObj, len, model, CompressionAlgorithm.lzma2,
- true);
- LOG.info("FFMPredictionModel#deserialize(): " + model._buf.toString());
- return model;
- }
-
-}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/550bb4e6/core/src/main/java/hivemall/fm/FFMStringFeatureMapModel.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/fm/FFMStringFeatureMapModel.java b/core/src/main/java/hivemall/fm/FFMStringFeatureMapModel.java
index 4f445fa..2264063 100644
--- a/core/src/main/java/hivemall/fm/FFMStringFeatureMapModel.java
+++ b/core/src/main/java/hivemall/fm/FFMStringFeatureMapModel.java
@@ -23,6 +23,7 @@ import hivemall.fm.Entry.FTRLEntry;
import hivemall.fm.FMHyperParameters.FFMHyperParameters;
import hivemall.utils.buffer.HeapBuffer;
import hivemall.utils.collections.maps.Int2LongOpenHashTable;
+import hivemall.utils.collections.maps.Int2LongOpenHashTable.IMapIterator;
import hivemall.utils.lang.NumberUtils;
import hivemall.utils.math.MathUtils;
@@ -39,7 +40,7 @@ public final class FFMStringFeatureMapModel extends FieldAwareFactorizationMachi
private final HeapBuffer _buf;
// hyperparams
- private final int _numFeatures;
+ // private final int _numFeatures;
private final int _numFields;
// FTEL
@@ -55,7 +56,7 @@ public final class FFMStringFeatureMapModel extends FieldAwareFactorizationMachi
this._w0 = 0.f;
this._map = new Int2LongOpenHashTable(DEFAULT_MAPSIZE);
this._buf = new HeapBuffer(HeapBuffer.DEFAULT_CHUNK_SIZE);
- this._numFeatures = params.numFeatures;
+ // this._numFeatures = params.numFeatures;
this._numFields = params.numFields;
this._alpha = params.alphaFTRL;
this._beta = params.betaFTRL;
@@ -64,11 +65,6 @@ public final class FFMStringFeatureMapModel extends FieldAwareFactorizationMachi
this._entrySize = entrySize(_factor, _useFTRL, _useAdaGrad);
}
- @Nonnull
- FFMPredictionModel toPredictionModel() {
- return new FFMPredictionModel(_map, _buf, _w0, _factor, _numFeatures, _numFields);
- }
-
@Override
public int getSize() {
return _map.size();
@@ -271,6 +267,49 @@ public final class FFMStringFeatureMapModel extends FieldAwareFactorizationMachi
}
}
+ @Nonnull
+ EntryIterator entries() {
+ return new EntryIterator(this);
+ }
+
+ static final class EntryIterator {
+
+ @Nonnull
+ private final IMapIterator dictItor;
+ @Nonnull
+ private final Entry entryProbe;
+
+ EntryIterator(@Nonnull FFMStringFeatureMapModel model) {
+ this.dictItor = model._map.entries();
+ this.entryProbe = new Entry(model._buf, model._factor);
+ }
+
+ @Nonnull
+ Entry getEntryProbe() {
+ return entryProbe;
+ }
+
+ boolean hasNext() {
+ return dictItor.hasNext();
+ }
+
+ int next() {
+ return dictItor.next();
+ }
+
+ int getEntryIndex() {
+ return dictItor.getKey();
+ }
+
+ @Nonnull
+ void getEntry(@Nonnull final Entry probe) {
+ long offset = dictItor.getValue();
+ probe.setOffset(offset);
+ }
+
+ }
+
+
private static int entrySize(int factors, boolean ftrl, boolean adagrad) {
if (ftrl) {
return FTRLEntry.sizeOf(factors);
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/550bb4e6/core/src/main/java/hivemall/fm/FieldAwareFactorizationMachineUDTF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/fm/FieldAwareFactorizationMachineUDTF.java b/core/src/main/java/hivemall/fm/FieldAwareFactorizationMachineUDTF.java
index 67dbf87..d1c9e73 100644
--- a/core/src/main/java/hivemall/fm/FieldAwareFactorizationMachineUDTF.java
+++ b/core/src/main/java/hivemall/fm/FieldAwareFactorizationMachineUDTF.java
@@ -18,25 +18,23 @@
*/
package hivemall.fm;
+import hivemall.fm.FFMStringFeatureMapModel.EntryIterator;
import hivemall.fm.FMHyperParameters.FFMHyperParameters;
import hivemall.utils.collections.arrays.DoubleArray3D;
import hivemall.utils.collections.lists.IntArrayList;
import hivemall.utils.hadoop.HadoopUtils;
-import hivemall.utils.hadoop.Text3;
-import hivemall.utils.lang.NumberUtils;
+import hivemall.utils.hadoop.HiveUtils;
import hivemall.utils.math.MathUtils;
-import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.ArrayList;
+import java.util.Arrays;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.Options;
-import org.apache.commons.logging.Log;
-import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
@@ -44,6 +42,8 @@ import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
+import org.apache.hadoop.io.FloatWritable;
+import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.Text;
/**
@@ -56,7 +56,6 @@ import org.apache.hadoop.io.Text;
name = "train_ffm",
value = "_FUNC_(array<string> x, double y [, const string options]) - Returns a prediction model")
public final class FieldAwareFactorizationMachineUDTF extends FactorizationMachineUDTF {
- private static final Log LOG = LogFactory.getLog(FieldAwareFactorizationMachineUDTF.class);
// ----------------------------------------
// Learning hyper-parameters/options
@@ -150,8 +149,14 @@ public final class FieldAwareFactorizationMachineUDTF extends FactorizationMachi
fieldNames.add("model_id");
fieldOIs.add(PrimitiveObjectInspectorFactory.writableStringObjectInspector);
- fieldNames.add("model");
- fieldOIs.add(PrimitiveObjectInspectorFactory.writableStringObjectInspector);
+ fieldNames.add("i");
+ fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);
+
+ fieldNames.add("Wi");
+ fieldOIs.add(PrimitiveObjectInspectorFactory.writableFloatObjectInspector);
+
+ fieldNames.add("Vi");
+ fieldOIs.add(ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableFloatObjectInspector));
return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
}
@@ -267,39 +272,46 @@ public final class FieldAwareFactorizationMachineUDTF extends FactorizationMachi
this._fieldList = null;
this._sumVfX = null;
- Text modelId = new Text();
- String taskId = HadoopUtils.getUniqueTaskIdString();
- modelId.set(taskId);
+ final int factors = _factors;
+ final IntWritable idx = new IntWritable();
+ final FloatWritable Wi = new FloatWritable(0.f);
+ final FloatWritable[] Vi = HiveUtils.newFloatArray(factors, 0.f);
+
+ final Object[] forwardObjs = new Object[4];
+ String modelId = HadoopUtils.getUniqueTaskIdString();
+ forwardObjs[0] = new Text(modelId);
+ forwardObjs[1] = idx;
+ forwardObjs[2] = Wi;
+ forwardObjs[3] = null; // Vi
+
+ // W0
+ idx.set(0);
+ Wi.set(_ffmModel.getW0());
+ forward(forwardObjs);
- FFMPredictionModel predModel = _ffmModel.toPredictionModel();
- this._ffmModel = null; // help GC
+ forwardObjs[3] = Arrays.asList(Vi);
- if (LOG.isInfoEnabled()) {
- LOG.info("Serializing a model '" + modelId + "'... Configured # features: "
- + _numFeatures + ", Configured # fields: " + _numFields
- + ", Actual # features: " + predModel.getActualNumFeatures()
- + ", Estimated uncompressed bytes: "
- + NumberUtils.prettySize(predModel.approxBytesConsumed()));
- }
+ final EntryIterator itor = _ffmModel.entries();
+ final Entry entry = itor.getEntryProbe();
+ final float[] Vf = new float[factors];
+ while (itor.next() != -1) {
+ // set i
+ int i = itor.getEntryIndex();
+ idx.set(i);
- byte[] serialized;
- try {
- serialized = predModel.serialize();
- predModel = null;
- } catch (IOException e) {
- throw new HiveException("Failed to serialize a model", e);
- }
+ itor.getEntry(entry);
- if (LOG.isInfoEnabled()) {
- LOG.info("Forwarding a serialized/compressed model '" + modelId + "' of size: "
- + NumberUtils.prettySize(serialized.length));
- }
+ // set Wi
+ Wi.set(entry.getW());
- Text modelObj = new Text3(serialized);
- serialized = null;
- Object[] forwardObjs = new Object[] {modelId, modelObj};
+ // set Vif
+ entry.getV(Vf);
+ for (int f = 0; f < factors; f++) {
+ Vi[f].set(Vf[f]);
+ }
- forward(forwardObjs);
+ forward(forwardObjs);
+ }
}
}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/550bb4e6/core/src/test/java/hivemall/fm/FFMPredictionModelTest.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/hivemall/fm/FFMPredictionModelTest.java b/core/src/test/java/hivemall/fm/FFMPredictionModelTest.java
deleted file mode 100644
index 076387f..0000000
--- a/core/src/test/java/hivemall/fm/FFMPredictionModelTest.java
+++ /dev/null
@@ -1,65 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package hivemall.fm;
-
-import hivemall.utils.buffer.HeapBuffer;
-import hivemall.utils.collections.maps.Int2LongOpenHashTable;
-
-import java.io.IOException;
-
-import org.junit.Assert;
-import org.junit.Test;
-
-public class FFMPredictionModelTest {
-
- @Test
- public void testSerialize() throws IOException, ClassNotFoundException {
- final int factors = 3;
- final int entrySize = Entry.sizeOf(factors);
-
- HeapBuffer buf = new HeapBuffer(HeapBuffer.DEFAULT_CHUNK_SIZE);
- Int2LongOpenHashTable map = Int2LongOpenHashTable.newInstance();
-
- Entry e1 = new Entry(buf, factors, buf.allocate(entrySize));
- e1.setW(1f);
- e1.setV(new float[] {1f, -1f, -1f});
-
- Entry e2 = new Entry(buf, factors, buf.allocate(entrySize));
- e2.setW(2f);
- e2.setV(new float[] {1f, 2f, -1f});
-
- Entry e3 = new Entry(buf, factors, buf.allocate(entrySize));
- e3.setW(3f);
- e3.setV(new float[] {1f, 2f, 3f});
-
- map.put(1, e1.getOffset());
- map.put(2, e2.getOffset());
- map.put(3, e3.getOffset());
-
- FFMPredictionModel expected = new FFMPredictionModel(map, buf, 0.d, 3,
- Feature.DEFAULT_NUM_FEATURES, Feature.DEFAULT_NUM_FIELDS);
- byte[] b = expected.serialize();
-
- FFMPredictionModel actual = FFMPredictionModel.deserialize(b, b.length);
- Assert.assertEquals(3, actual.getNumFactors());
- Assert.assertEquals(Feature.DEFAULT_NUM_FEATURES, actual.getNumFeatures());
- Assert.assertEquals(Feature.DEFAULT_NUM_FIELDS, actual.getNumFields());
- }
-
-}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/550bb4e6/resources/ddl/define-all-as-permanent.hive
----------------------------------------------------------------------
diff --git a/resources/ddl/define-all-as-permanent.hive b/resources/ddl/define-all-as-permanent.hive
index feb1a08..f3065c0 100644
--- a/resources/ddl/define-all-as-permanent.hive
+++ b/resources/ddl/define-all-as-permanent.hive
@@ -620,7 +620,7 @@ DROP FUNCTION IF EXISTS train_ffm;
CREATE FUNCTION train_ffm as 'hivemall.fm.FieldAwareFactorizationMachineUDTF' USING JAR '${hivemall_jar}';
DROP FUNCTION IF EXISTS ffm_predict;
-CREATE FUNCTION ffm_predict as 'hivemall.fm.FFMPredictUDF' USING JAR '${hivemall_jar}';
+CREATE FUNCTION ffm_predict as 'hivemall.fm.FFMPredictGenericUDAF' USING JAR '${hivemall_jar}';
---------------------------
-- Anomaly Detection ------
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/550bb4e6/resources/ddl/define-all.hive
----------------------------------------------------------------------
diff --git a/resources/ddl/define-all.hive b/resources/ddl/define-all.hive
index 310f9f4..305afc9 100644
--- a/resources/ddl/define-all.hive
+++ b/resources/ddl/define-all.hive
@@ -612,7 +612,7 @@ drop temporary function if exists train_ffm;
create temporary function train_ffm as 'hivemall.fm.FieldAwareFactorizationMachineUDTF';
drop temporary function if exists ffm_predict;
-create temporary function ffm_predict as 'hivemall.fm.FFMPredictUDF';
+create temporary function ffm_predict as 'hivemall.fm.FFMPredictGenericUDAF';
---------------------------
-- Anomaly Detection ------
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/550bb4e6/resources/ddl/define-all.spark
----------------------------------------------------------------------
diff --git a/resources/ddl/define-all.spark b/resources/ddl/define-all.spark
index 42b235b..4e18b8f 100644
--- a/resources/ddl/define-all.spark
+++ b/resources/ddl/define-all.spark
@@ -596,7 +596,7 @@ sqlContext.sql("DROP TEMPORARY FUNCTION IF EXISTS train_ffm")
sqlContext.sql("CREATE TEMPORARY FUNCTION train_ffm AS 'hivemall.fm.FieldAwareFactorizationMachineUDTF'")
sqlContext.sql("DROP TEMPORARY FUNCTION IF EXISTS ffm_predict")
-sqlContext.sql("CREATE TEMPORARY FUNCTION ffm_predict AS 'hivemall.fm.FFMPredictUDF'")
+sqlContext.sql("CREATE TEMPORARY FUNCTION ffm_predict AS 'hivemall.fm.FFMPredictGenericUDAF'")
/**
* Anomaly Detection
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/550bb4e6/resources/ddl/define-udfs.td.hql
----------------------------------------------------------------------
diff --git a/resources/ddl/define-udfs.td.hql b/resources/ddl/define-udfs.td.hql
index dd694e3..bc5e3db 100644
--- a/resources/ddl/define-udfs.td.hql
+++ b/resources/ddl/define-udfs.td.hql
@@ -174,6 +174,8 @@ create temporary function dimsum_mapper as 'hivemall.knn.similarity.DIMSUMMapper
create temporary function train_classifier as 'hivemall.classifier.GeneralClassifierUDTF';
create temporary function train_regressor as 'hivemall.regression.GeneralRegressorUDTF';
create temporary function tree_export as 'hivemall.smile.tools.TreeExportUDF';
+create temporary function train_ffm as 'hivemall.fm.FieldAwareFactorizationMachineUDTF';
+create temporary function ffm_predict as 'hivemall.fm.FFMPredictGenericUDAF';
-- NLP features
create temporary function tokenize_ja as 'hivemall.nlp.tokenizer.KuromojiUDF';
[5/8] incubator-hivemall git commit: Removed unnecessary code
Posted by my...@apache.org.
Removed unnecessary code
Project: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/commit/36a6ca27
Tree: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/tree/36a6ca27
Diff: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/diff/36a6ca27
Branch: refs/heads/HIVEMALL-24-2
Commit: 36a6ca27a48de0e0a89bef352d0e2c23179091c8
Parents: deea39a
Author: Makoto Yui <my...@apache.org>
Authored: Mon Jul 24 16:58:33 2017 +0900
Committer: Makoto Yui <my...@apache.org>
Committed: Mon Jul 24 16:58:33 2017 +0900
----------------------------------------------------------------------
core/src/main/java/hivemall/ftvec/pairing/FeaturePairsUDTF.java | 2 --
1 file changed, 2 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/36a6ca27/core/src/main/java/hivemall/ftvec/pairing/FeaturePairsUDTF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/ftvec/pairing/FeaturePairsUDTF.java b/core/src/main/java/hivemall/ftvec/pairing/FeaturePairsUDTF.java
index d814b40..10e7614 100644
--- a/core/src/main/java/hivemall/ftvec/pairing/FeaturePairsUDTF.java
+++ b/core/src/main/java/hivemall/ftvec/pairing/FeaturePairsUDTF.java
@@ -265,7 +265,6 @@ public final class FeaturePairsUDTF extends UDTFWithOptions {
final Feature[] features = Feature.parseFFMFeatures(arg, fvOI, null, _numFeatures, _numFields);
// W0
- forward[0] = f0;
f0.set(0);
forward[1] = null;
forward[2] = null;
@@ -279,7 +278,6 @@ public final class FeaturePairsUDTF extends UDTFWithOptions {
int iField = ei.getField();
// Wi
- forward[0] = f0;
f0.set(i);
forward[1] = null;
f2.set(xi);
[4/8] incubator-hivemall git commit: safe cast for Short
Posted by my...@apache.org.
safe cast for Short
Project: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/commit/deea39a2
Tree: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/tree/deea39a2
Diff: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/diff/deea39a2
Branch: refs/heads/HIVEMALL-24-2
Commit: deea39a25653d0d85993ca50236003a0db13c202
Parents: 82fa581
Author: Makoto Yui <my...@apache.org>
Authored: Mon Jul 24 15:22:16 2017 +0900
Committer: Makoto Yui <my...@apache.org>
Committed: Mon Jul 24 15:22:16 2017 +0900
----------------------------------------------------------------------
core/src/main/java/hivemall/fm/Feature.java | 9 +++++----
.../src/main/java/hivemall/utils/lang/Primitives.java | 14 +++++++-------
2 files changed, 12 insertions(+), 11 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/deea39a2/core/src/main/java/hivemall/fm/Feature.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/fm/Feature.java b/core/src/main/java/hivemall/fm/Feature.java
index 2966a02..f77d3fd 100644
--- a/core/src/main/java/hivemall/fm/Feature.java
+++ b/core/src/main/java/hivemall/fm/Feature.java
@@ -18,6 +18,7 @@
*/
package hivemall.fm;
+import static hivemall.utils.lang.Primitives.castToShort;
import hivemall.utils.hashing.MurmurHash3;
import hivemall.utils.lang.NumberUtils;
@@ -219,7 +220,7 @@ public abstract class Feature {
} else {
index = MurmurHash3.murmurhash3(lead, numFields);
}
- short field = (short) index;
+ short field = castToShort(index);
double value = parseFeatureValue(rest);
return new IntFeature(index, field, value);
}
@@ -237,7 +238,7 @@ public abstract class Feature {
} else {
// +NUM_FIELD to avoid conflict to quantitative features
index = MurmurHash3.murmurhash3(indexStr, numFeatures) + numFields;
- field = (short) MurmurHash3.murmurhash3(lead, numFields);
+ field = castToShort(MurmurHash3.murmurhash3(lead, numFields));
}
String valueStr = rest.substring(pos2 + 1);
double value = parseFeatureValue(valueStr);
@@ -296,7 +297,7 @@ public abstract class Feature {
} else {
index = MurmurHash3.murmurhash3(lead, numFields);
}
- short field = (short) index;
+ short field = castToShort(index);
probe.setField(field);
probe.setFeatureIndex(index);
probe.value = parseFeatureValue(rest);
@@ -316,7 +317,7 @@ public abstract class Feature {
} else {
// +NUM_FIELD to avoid conflict to quantitative features
index = MurmurHash3.murmurhash3(indexStr, numFeatures) + numFields;
- field = (short) MurmurHash3.murmurhash3(lead, numFields);
+ field = castToShort(MurmurHash3.murmurhash3(lead, numFields));
}
probe.setField(field);
probe.setFeatureIndex(index);
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/deea39a2/core/src/main/java/hivemall/utils/lang/Primitives.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/utils/lang/Primitives.java b/core/src/main/java/hivemall/utils/lang/Primitives.java
index 2ec012c..a6e945f 100644
--- a/core/src/main/java/hivemall/utils/lang/Primitives.java
+++ b/core/src/main/java/hivemall/utils/lang/Primitives.java
@@ -92,16 +92,16 @@ public final class Primitives {
b[off] = (byte) (val >>> 8);
}
- public static int toIntExact(final long longValue) {
- final int casted = (int) longValue;
- if (casted != longValue) {
- throw new ArithmeticException("integer overflow: " + longValue);
+ public static int castToInt(final long value) {
+ final int result = (int) value;
+ if (result != value) {
+ throw new IllegalArgumentException("Out of range: " + value);
}
- return casted;
+ return result;
}
- public static int castToInt(final long value) {
- final int result = (int) value;
+ public static short castToShort(final int value) {
+ final short result = (short) value;
if (result != value) {
throw new IllegalArgumentException("Out of range: " + value);
}
[7/8] incubator-hivemall git commit: Commented out Oracle JDK 7 from
CI because it's EoL
Posted by my...@apache.org.
Commented out Oracle JDK 7 from CI because it's EoL
Project: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/commit/357323a5
Tree: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/tree/357323a5
Diff: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/diff/357323a5
Branch: refs/heads/HIVEMALL-24-2
Commit: 357323a522c594830f6cdfeff4d5c115c6f28dd9
Parents: 550bb4e
Author: Makoto Yui <my...@apache.org>
Authored: Wed Jul 26 01:01:07 2017 +0900
Committer: Makoto Yui <my...@apache.org>
Committed: Wed Jul 26 01:01:07 2017 +0900
----------------------------------------------------------------------
.travis.yml | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/357323a5/.travis.yml
----------------------------------------------------------------------
diff --git a/.travis.yml b/.travis.yml
index 323e36a..96f8f4e 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -19,7 +19,7 @@ env:
language: java
jdk:
- openjdk7
- - oraclejdk7
+# - oraclejdk7
- oraclejdk8
branches:
[8/8] incubator-hivemall git commit: Fixed XX:PermSize to
XX:MetaspaceSize for Java 8
Posted by my...@apache.org.
Fixed XX:PermSize to XX:MetaspaceSize for Java 8
Project: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/commit/f70e7c52
Tree: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/tree/f70e7c52
Diff: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/diff/f70e7c52
Branch: refs/heads/HIVEMALL-24-2
Commit: f70e7c52ec1ee88034c71338bf9fc614197d70f8
Parents: 357323a
Author: Makoto Yui <my...@apache.org>
Authored: Wed Jul 26 14:48:43 2017 +0900
Committer: Makoto Yui <my...@apache.org>
Committed: Wed Jul 26 14:48:43 2017 +0900
----------------------------------------------------------------------
pom.xml | 18 ++++++++++++++++++
spark/spark-2.0/pom.xml | 10 ++--------
spark/spark-2.1/pom.xml | 10 ++--------
spark/spark-common/pom.xml | 8 +-------
4 files changed, 23 insertions(+), 23 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/f70e7c52/pom.xml
----------------------------------------------------------------------
diff --git a/pom.xml b/pom.xml
index 8b6aa5e..a7dc301 100644
--- a/pom.xml
+++ b/pom.xml
@@ -288,6 +288,24 @@
<spark.binary.version>2.0</spark.binary.version>
</properties>
</profile>
+ <profile>
+ <id>java7</id>
+ <properties>
+ <spark.test.jvm.opts>-ea -Xms1024m -Xmx2048m -XX:PermSize=128m -XX:MaxPermSize=512m -XX:ReservedCodeCacheSize=128m</spark.test.jvm.opts>
+ </properties>
+ <activation>
+ <jdk>[,1.8)</jdk> <!-- version < 1.8 -->
+ </activation>
+ </profile>
+ <profile>
+ <id>java8</id>
+ <properties>
+ <spark.test.jvm.opts>-ea -Xms1024m -Xmx2048m -XX:MetaspaceSize=128m -XX:MaxMetaspaceSize=512m -XX:ReservedCodeCacheSize=128m</spark.test.jvm.opts>
+ </properties>
+ <activation>
+ <jdk>[1.8,)</jdk> <!-- version >= 1.8 -->
+ </activation>
+ </profile>
<profile>
<id>compile-xgboost</id>
<build>
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/f70e7c52/spark/spark-2.0/pom.xml
----------------------------------------------------------------------
diff --git a/spark/spark-2.0/pom.xml b/spark/spark-2.0/pom.xml
index 123c424..74e9348 100644
--- a/spark/spark-2.0/pom.xml
+++ b/spark/spark-2.0/pom.xml
@@ -32,9 +32,6 @@
<packaging>jar</packaging>
<properties>
- <PermGen>64m</PermGen>
- <MaxPermGen>512m</MaxPermGen>
- <CodeCacheSize>512m</CodeCacheSize>
<main.basedir>${project.parent.basedir}</main.basedir>
</properties>
@@ -164,11 +161,8 @@
<!-- <arg>-feature</arg> -->
</args>
<jvmArgs>
- <jvmArg>-Xms1024m</jvmArg>
+ <jvmArg>-Xms512m</jvmArg>
<jvmArg>-Xmx1024m</jvmArg>
- <jvmArg>-XX:PermSize=${PermGen}</jvmArg>
- <jvmArg>-XX:MaxPermSize=${MaxPermGen}</jvmArg>
- <jvmArg>-XX:ReservedCodeCacheSize=${CodeCacheSize}</jvmArg>
</jvmArgs>
</configuration>
</plugin>
@@ -233,7 +227,7 @@
<reportsDirectory>${project.build.directory}/surefire-reports</reportsDirectory>
<junitxml>.</junitxml>
<filereports>SparkTestSuite.txt</filereports>
- <argLine>-ea -Xmx2g -XX:MaxPermSize=${MaxPermGen} -XX:ReservedCodeCacheSize=${CodeCacheSize}</argLine>
+ <argLine>${spark.test.jvm.opts}</argLine>
<stderr />
<environmentVariables>
<SPARK_PREPEND_CLASSES>1</SPARK_PREPEND_CLASSES>
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/f70e7c52/spark/spark-2.1/pom.xml
----------------------------------------------------------------------
diff --git a/spark/spark-2.1/pom.xml b/spark/spark-2.1/pom.xml
index 22d3e12..d7ab81a 100644
--- a/spark/spark-2.1/pom.xml
+++ b/spark/spark-2.1/pom.xml
@@ -32,9 +32,6 @@
<packaging>jar</packaging>
<properties>
- <PermGen>64m</PermGen>
- <MaxPermGen>512m</MaxPermGen>
- <CodeCacheSize>512m</CodeCacheSize>
<main.basedir>${project.parent.basedir}</main.basedir>
</properties>
@@ -164,11 +161,8 @@
<!-- <arg>-feature</arg> -->
</args>
<jvmArgs>
- <jvmArg>-Xms1024m</jvmArg>
+ <jvmArg>-Xms512m</jvmArg>
<jvmArg>-Xmx1024m</jvmArg>
- <jvmArg>-XX:PermSize=${PermGen}</jvmArg>
- <jvmArg>-XX:MaxPermSize=${MaxPermGen}</jvmArg>
- <jvmArg>-XX:ReservedCodeCacheSize=${CodeCacheSize}</jvmArg>
</jvmArgs>
</configuration>
</plugin>
@@ -233,7 +227,7 @@
<reportsDirectory>${project.build.directory}/surefire-reports</reportsDirectory>
<junitxml>.</junitxml>
<filereports>SparkTestSuite.txt</filereports>
- <argLine>-ea -Xmx2g -XX:MaxPermSize=${MaxPermGen} -XX:ReservedCodeCacheSize=${CodeCacheSize}</argLine>
+ <argLine>${spark.test.jvm.opts}</argLine>
<stderr />
<environmentVariables>
<SPARK_PREPEND_CLASSES>1</SPARK_PREPEND_CLASSES>
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/f70e7c52/spark/spark-common/pom.xml
----------------------------------------------------------------------
diff --git a/spark/spark-common/pom.xml b/spark/spark-common/pom.xml
index e8e8ff4..3153a75 100644
--- a/spark/spark-common/pom.xml
+++ b/spark/spark-common/pom.xml
@@ -32,9 +32,6 @@
<packaging>jar</packaging>
<properties>
- <PermGen>64m</PermGen>
- <MaxPermGen>1024m</MaxPermGen>
- <CodeCacheSize>512m</CodeCacheSize>
<main.basedir>${project.parent.basedir}</main.basedir>
</properties>
@@ -138,11 +135,8 @@
<!-- <arg>-feature</arg> -->
</args>
<jvmArgs>
- <jvmArg>-Xms1024m</jvmArg>
+ <jvmArg>-Xms512m</jvmArg>
<jvmArg>-Xmx1024m</jvmArg>
- <jvmArg>-XX:PermSize=${PermGen}</jvmArg>
- <jvmArg>-XX:MaxPermSize=${MaxPermGen}</jvmArg>
- <jvmArg>-XX:ReservedCodeCacheSize=${CodeCacheSize}</jvmArg>
</jvmArgs>
</configuration>
</plugin>
[2/8] incubator-hivemall git commit: Create FFMPredictGenericUDAF
Posted by my...@apache.org.
Create FFMPredictGenericUDAF
Project: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/commit/343f704d
Tree: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/tree/343f704d
Diff: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/diff/343f704d
Branch: refs/heads/HIVEMALL-24-2
Commit: 343f704dc4eeee25f195e298ee014368db1fef9e
Parents: 591e3b0
Author: Takuya Kitazawa <k....@gmail.com>
Authored: Fri Mar 3 17:48:26 2017 +0900
Committer: Takuya Kitazawa <k....@gmail.com>
Committed: Fri Mar 3 17:48:26 2017 +0900
----------------------------------------------------------------------
.../java/hivemall/fm/FFMPredictGenericUDAF.java | 272 +++++++++++++++++++
1 file changed, 272 insertions(+)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/343f704d/core/src/main/java/hivemall/fm/FFMPredictGenericUDAF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/fm/FFMPredictGenericUDAF.java b/core/src/main/java/hivemall/fm/FFMPredictGenericUDAF.java
new file mode 100644
index 0000000..91d1b6b
--- /dev/null
+++ b/core/src/main/java/hivemall/fm/FFMPredictGenericUDAF.java
@@ -0,0 +1,272 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.fm;
+
+import hivemall.utils.hadoop.HiveUtils;
+import hivemall.utils.hadoop.WritableUtils;
+import org.apache.hadoop.hive.ql.exec.Description;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentLengthException;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
+import org.apache.hadoop.hive.ql.metadata.HiveException;
+import org.apache.hadoop.hive.ql.parse.SemanticException;
+import org.apache.hadoop.hive.ql.udf.generic.AbstractGenericUDAFResolver;
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator.AbstractAggregationBuffer;
+import org.apache.hadoop.hive.serde2.io.DoubleWritable;
+import org.apache.hadoop.hive.serde2.lazybinary.LazyBinaryArray;
+import org.apache.hadoop.hive.serde2.objectinspector.*;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector.Category;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.WritableDoubleObjectInspector;
+import org.apache.hadoop.hive.serde2.typeinfo.ListTypeInfo;
+import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
+
+import javax.annotation.Nonnull;
+import javax.annotation.Nullable;
+import java.util.ArrayList;
+import java.util.List;
+
+@Description(
+ name = "ffm_predict",
+ value = "_FUNC_(Float Wi, Float Wj, array<float> Vifj, array<float> Vjfi, float Xi, float Xj)"
+ + " - Returns a prediction value in Double")
+public final class FFMPredictGenericUDAF extends AbstractGenericUDAFResolver {
+
+ private FFMPredictGenericUDAF() {}
+
+    /**
+     * Validates the argument types given to {@code ffm_predict} and creates its evaluator.
+     *
+     * <p>Exactly five arguments are expected, in this order: {@code Wi} (number), {@code Vifj}
+     * (list of numbers), {@code Vjfi} (list of numbers), {@code Xi} (number) and {@code Xj}
+     * (number). There is no separate {@code Wj} argument, although the usage string in
+     * {@code @Description} lists one.
+     *
+     * @param typeInfo type information of the arguments in the function call
+     * @return a new {@link Evaluator}
+     * @throws UDFArgumentLengthException if the number of arguments is not five
+     * @throws UDFArgumentTypeException if an argument is not of the expected type
+     */
+    @Override
+    public Evaluator getEvaluator(TypeInfo[] typeInfo) throws SemanticException {
+        if (typeInfo.length != 5) {
+            throw new UDFArgumentLengthException(
+                "Expected argument length is 5 but given argument length was " + typeInfo.length);
+        }
+        if (!HiveUtils.isNumberTypeInfo(typeInfo[0])) {
+            throw new UDFArgumentTypeException(0,
+                "Number type is expected for the first argument Wi: " + typeInfo[0].getTypeName());
+        }
+        if (typeInfo[1].getCategory() != Category.LIST) {
+            throw new UDFArgumentTypeException(1,
+                "List type is expected for the second argument Vifj: " + typeInfo[1].getTypeName());
+        }
+        if (typeInfo[2].getCategory() != Category.LIST) {
+            throw new UDFArgumentTypeException(2,
+                "List type is expected for the third argument Vjfi: " + typeInfo[2].getTypeName());
+        }
+        // Both list arguments must hold numeric elements.
+        final ListTypeInfo vifjTypeInfo = (ListTypeInfo) typeInfo[1];
+        if (!HiveUtils.isNumberTypeInfo(vifjTypeInfo.getListElementTypeInfo())) {
+            throw new UDFArgumentTypeException(1,
+                "Number type is expected for the element type of list Vifj: "
+                        + vifjTypeInfo.getTypeName());
+        }
+        final ListTypeInfo vjfiTypeInfo = (ListTypeInfo) typeInfo[2];
+        if (!HiveUtils.isNumberTypeInfo(vjfiTypeInfo.getListElementTypeInfo())) {
+            // Report the type of Vjfi itself, not the type of Vifj.
+            throw new UDFArgumentTypeException(2,
+                "Number type is expected for the element type of list Vjfi: "
+                        + vjfiTypeInfo.getTypeName());
+        }
+        if (!HiveUtils.isNumberTypeInfo(typeInfo[3])) {
+            throw new UDFArgumentTypeException(3,
+                "Number type is expected for the fourth argument Xi: " + typeInfo[3].getTypeName());
+        }
+        if (!HiveUtils.isNumberTypeInfo(typeInfo[4])) {
+            throw new UDFArgumentTypeException(4,
+                "Number type is expected for the fifth argument Xj: " + typeInfo[4].getTypeName());
+        }
+        return new Evaluator();
+    }
+
+    /**
+     * Evaluator of {@code ffm_predict}. Every input row contributes one term to a running sum:
+     * the global bias, a linear term {@code Wi * Xi}, or a pairwise term
+     * {@code <Vifj, Vjfi> * Xi * Xj}. Partial aggregations are exchanged as a struct with a
+     * single double field named {@code sum}.
+     */
+    public static class Evaluator extends GenericUDAFEvaluator {
+
+        // input OI
+        private PrimitiveObjectInspector wiOI;
+        private ListObjectInspector vijOI;
+        private ListObjectInspector vjiOI;
+        private PrimitiveObjectInspector vijElemOI;
+        private PrimitiveObjectInspector vjiElemOI;
+        private PrimitiveObjectInspector xiOI;
+        private PrimitiveObjectInspector xjOI;
+
+        // merge OI
+        private StructObjectInspector internalMergeOI;
+        private StructField sumField;
+
+        public Evaluator() {}
+
+        /**
+         * Sets up the input inspectors for the given mode and returns the output inspector.
+         *
+         * @param mode aggregation mode
+         * @param parameters inspectors of the five raw arguments (PARTIAL1/COMPLETE), or of the
+         *        partial aggregation struct in {@code parameters[0]} (PARTIAL2/FINAL)
+         * @return the partial struct inspector for PARTIAL1/PARTIAL2, otherwise a writable double
+         *         inspector
+         * @throws HiveException if an inspector is not of the expected kind
+         */
+        @Override
+        public ObjectInspector init(Mode mode, ObjectInspector[] parameters) throws HiveException {
+            super.init(mode, parameters);
+
+            // initialize input
+            if (mode == Mode.PARTIAL1 || mode == Mode.COMPLETE) {// from original data
+                // Only raw input carries the five function arguments; the merge modes below
+                // read a partial aggregation from parameters[0] instead.
+                assert (parameters.length == 5) : parameters.length;
+                this.wiOI = HiveUtils.asDoubleCompatibleOI(parameters[0]);
+                this.vijOI = HiveUtils.asListOI(parameters[1]);
+                this.vjiOI = HiveUtils.asListOI(parameters[2]);
+                this.vijElemOI = HiveUtils.asDoubleCompatibleOI(vijOI.getListElementObjectInspector());
+                this.vjiElemOI = HiveUtils.asDoubleCompatibleOI(vjiOI.getListElementObjectInspector());
+                this.xiOI = HiveUtils.asDoubleCompatibleOI(parameters[3]);
+                this.xjOI = HiveUtils.asDoubleCompatibleOI(parameters[4]);
+            } else {// from partial aggregation
+                StructObjectInspector soi = (StructObjectInspector) parameters[0];
+                this.internalMergeOI = soi;
+                this.sumField = soi.getStructFieldRef("sum");
+            }
+
+            // initialize output
+            final ObjectInspector outputOI;
+            if (mode == Mode.PARTIAL1 || mode == Mode.PARTIAL2) {// terminatePartial
+                outputOI = internalMergeOI();
+            } else {
+                outputOI = PrimitiveObjectInspectorFactory.writableDoubleObjectInspector;
+            }
+            return outputOI;
+        }
+
+        /** Builds the inspector of the partial result: a struct with one double field "sum". */
+        private static StructObjectInspector internalMergeOI() {
+            ArrayList<String> fieldNames = new ArrayList<String>();
+            ArrayList<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>();
+
+            fieldNames.add("sum");
+            fieldOIs.add(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector);
+
+            return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
+        }
+
+        @Override
+        public FFMPredictAggregationBuffer getNewAggregationBuffer() throws HiveException {
+            FFMPredictAggregationBuffer myAggr = new FFMPredictAggregationBuffer();
+            reset(myAggr);
+            return myAggr;
+        }
+
+        @Override
+        public void reset(@SuppressWarnings("deprecation") AggregationBuffer agg)
+                throws HiveException {
+            FFMPredictAggregationBuffer myAggr = (FFMPredictAggregationBuffer) agg;
+            myAggr.reset();
+        }
+
+        /**
+         * Adds the contribution of one input row to the buffer.
+         *
+         * <ul>
+         * <li>Xi and Xj are null: the row is the global bias and Wi is added as is.</li>
+         * <li>Only Xj is null: the row is a linear term and Wi * Xi is added.</li>
+         * <li>Xi and Xj are non-null: the row is a pairwise term and
+         * &lt;Vifj, Vjfi&gt; * Xi * Xj is added.</li>
+         * </ul>
+         *
+         * @param agg aggregation buffer to update
+         * @param parameters the five arguments (Wi, Vifj, Vjfi, Xi, Xj) of the current row
+         * @throws UDFArgumentException if a value required by the pairwise term is null
+         * @throws HiveException if Vifj and Vjfi differ in length
+         */
+        @Override
+        public void iterate(@SuppressWarnings("deprecation") AggregationBuffer agg,
+                Object[] parameters) throws HiveException {
+            // NOTE(review): a row with a null Wi is skipped entirely, including pairwise rows
+            // whose term does not use Wi - confirm this is intended.
+            if (parameters[0] == null) {
+                return;
+            }
+            final FFMPredictAggregationBuffer myAggr = (FFMPredictAggregationBuffer) agg;
+
+            final double wi = PrimitiveObjectInspectorUtils.getDouble(parameters[0], wiOI);
+            final Object xiObj = parameters[3];
+            final Object xjObj = parameters[4];
+
+            if (xjObj == null) {
+                if (xiObj == null) {// Xi and Xj are null => global bias `w0`
+                    myAggr.addW0(wi);
+                } else {// Only Xi is nonnull => linear combination `wi` * `xi`
+                    double xi = PrimitiveObjectInspectorUtils.getDouble(xiObj, xiOI);
+                    myAggr.addWiXi(wi, xi);
+                }
+                return;
+            }
+
+            // Xj is nonnull => <Vifj, Vjfi> Xi Xj
+            if (parameters[1] == null || parameters[2] == null) {
+                throw new UDFArgumentException("The second and third arguments (Vij, Vji) must not be null");
+            }
+            if (xiObj == null) {
+                // Fail with a descriptive error instead of dereferencing a null Xi below.
+                throw new UDFArgumentException(
+                    "The fourth argument Xi must not be null when the fifth argument Xj is given");
+            }
+
+            final List<?> vijRaw = vijOI.getList(parameters[1]);
+            final List<?> vjiRaw = vjiOI.getList(parameters[2]);
+            if (vijRaw.size() != vjiRaw.size()) {
+                throw new HiveException("Mismatch in the number of factors");
+            }
+
+            // The list elements are in the representation of their element inspector, which is
+            // not necessarily java.lang.Float, so decode them through that inspector.
+            final List<Float> vij = toFloatList(vijRaw, vijElemOI);
+            final List<Float> vji = toFloatList(vjiRaw, vjiElemOI);
+
+            double xi = PrimitiveObjectInspectorUtils.getDouble(xiObj, xiOI);
+            double xj = PrimitiveObjectInspectorUtils.getDouble(xjObj, xjOI);
+
+            myAggr.addViVjXiXj(vij, vji, xi, xj);
+        }
+
+        /** Decodes every element of a list through the given element inspector into a Float. */
+        @Nonnull
+        private static List<Float> toFloatList(@Nonnull final List<?> values,
+                @Nonnull final PrimitiveObjectInspector elemOI) throws UDFArgumentException {
+            final List<Float> floats = new ArrayList<Float>(values.size());
+            for (Object value : values) {
+                if (value == null) {
+                    throw new UDFArgumentException(
+                        "Elements of the second and third arguments (Vij, Vji) must not be null");
+                }
+                floats.add(Float.valueOf(PrimitiveObjectInspectorUtils.getFloat(value, elemOI)));
+            }
+            return floats;
+        }
+
+        /**
+         * Returns the partial result as a one-element struct holding the current sum, matching
+         * the "sum" field that {@code merge} reads back.
+         */
+        @Override
+        public Object terminatePartial(@SuppressWarnings("deprecation") AggregationBuffer agg)
+                throws HiveException {
+            FFMPredictAggregationBuffer myAggr = (FFMPredictAggregationBuffer) agg;
+
+            final Object[] partialResult = new Object[1];
+            partialResult[0] = new DoubleWritable(myAggr.get());
+            return partialResult;
+        }
+
+        /** Adds the "sum" field of a partial aggregation to the buffer; null partials are ignored. */
+        @Override
+        public void merge(@SuppressWarnings("deprecation") AggregationBuffer agg, Object partial)
+                throws HiveException {
+            if (partial == null) {
+                return;
+            }
+
+            Object sumObj = internalMergeOI.getStructFieldData(partial, sumField);
+            double sum = PrimitiveObjectInspectorFactory.writableDoubleObjectInspector.get(sumObj);
+
+            FFMPredictAggregationBuffer myAggr = (FFMPredictAggregationBuffer) agg;
+            myAggr.merge(sum);
+        }
+
+        /** Returns the accumulated sum as the final prediction value. */
+        @Override
+        public DoubleWritable terminate(@SuppressWarnings("deprecation") AggregationBuffer agg)
+                throws HiveException {
+            FFMPredictAggregationBuffer myAggr = (FFMPredictAggregationBuffer) agg;
+            double result = myAggr.get();
+            return new DoubleWritable(result);
+        }
+
+    }
+
+    /**
+     * Aggregation buffer of {@code ffm_predict}: a single running sum to which the bias, linear
+     * and pairwise terms are added.
+     */
+    public static class FFMPredictAggregationBuffer extends AbstractAggregationBuffer {
+
+        /** Running total of all terms added so far. */
+        double sum;
+
+        FFMPredictAggregationBuffer() {
+            super();
+        }
+
+        /** Clears the running total. */
+        void reset() {
+            sum = 0.d;
+        }
+
+        /** Adds the total of another partial aggregation. */
+        void merge(final double o_sum) {
+            sum = sum + o_sum;
+        }
+
+        /** Returns the current total. */
+        double get() {
+            return this.sum;
+        }
+
+        /** Adds the global bias term. */
+        void addW0(final double W0) {
+            sum = sum + W0;
+        }
+
+        /** Adds the linear term {@code Wi * Xi}. */
+        void addWiXi(final double Wi, final double Xi) {
+            final double linearTerm = Wi * Xi;
+            sum += linearTerm;
+        }
+
+        /**
+         * Adds the pairwise term {@code <Vifj, Vjfi> * Xi * Xj}.
+         *
+         * @param Vifj first factor vector; its size determines how many factors are read
+         * @param Vjfi second factor vector, read at the same indices as {@code Vifj}
+         * @param Xi value of the first feature
+         * @param Xj value of the second feature
+         */
+        void addViVjXiXj(@Nonnull final List<Float> Vifj, @Nonnull final List<Float> Vjfi,
+                final double Xi, final double Xj) {
+            // inner product <Vifj, Vjfi>; each product is formed in float and then widened
+            double innerProduct = 0.d;
+            int k = 0;
+            final int numFactors = Vifj.size();
+            while (k < numFactors) {
+                final float term = Vifj.get(k) * Vjfi.get(k);
+                innerProduct += term;
+                k++;
+            }
+
+            final double pairwiseTerm = innerProduct * Xi * Xj;
+            sum += pairwiseTerm;
+        }
+
+    }
+
+}