You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@hivemall.apache.org by my...@apache.org on 2017/09/28 03:17:29 UTC
[3/3] incubator-hivemall git commit: Close #117, Close #111: [HIVEMALL-17] Support SLIM neighborhood-learning recommendation algorithm
Close #117, Close #111: [HIVEMALL-17] Support SLIM neighborhood-learning recommendation algorithm
Project: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/commit/995b9a88
Tree: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/tree/995b9a88
Diff: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/diff/995b9a88
Branch: refs/heads/master
Commit: 995b9a885f6538138935dbf0fe9aae051ec47f9e
Parents: c2b9578
Author: Kento NOZAWA <k_...@klis.tsukuba.ac.jp>
Authored: Thu Sep 28 12:16:17 2017 +0900
Committer: Makoto Yui <my...@apache.org>
Committed: Thu Sep 28 12:16:45 2017 +0900
----------------------------------------------------------------------
.../main/java/hivemall/evaluation/AUCUDAF.java | 17 +-
.../evaluation/BinaryResponsesMeasures.java | 37 +-
.../evaluation/GradedResponsesMeasures.java | 16 +-
.../java/hivemall/evaluation/HitRateUDAF.java | 262 +++++++
.../main/java/hivemall/evaluation/MAPUDAF.java | 19 +-
.../main/java/hivemall/evaluation/MRRUDAF.java | 19 +-
.../main/java/hivemall/evaluation/NDCGUDAF.java | 17 +-
.../java/hivemall/evaluation/PrecisionUDAF.java | 24 +-
.../java/hivemall/evaluation/RecallUDAF.java | 19 +-
.../hivemall/math/matrix/sparse/CSCMatrix.java | 2 +
.../hivemall/math/matrix/sparse/CSRMatrix.java | 4 +-
.../math/matrix/sparse/DoKFloatMatrix.java | 368 +++++++++
.../hivemall/math/matrix/sparse/DoKMatrix.java | 34 +-
.../hivemall/math/vector/VectorProcedure.java | 6 +
.../hivemall/mf/BPRMatrixFactorizationUDTF.java | 3 +-
.../mf/OnlineMatrixFactorizationUDTF.java | 7 +-
.../main/java/hivemall/recommend/SlimUDTF.java | 759 +++++++++++++++++++
.../maps/Int2DoubleOpenHashTable.java | 427 +++++++++++
.../maps/Int2FloatOpenHashTable.java | 71 +-
.../collections/maps/Int2IntOpenHashTable.java | 5 +-
.../collections/maps/IntOpenHashTable.java | 5 +-
.../maps/Long2DoubleOpenHashTable.java | 3 +
.../maps/Long2FloatOpenHashTable.java | 23 +-
.../collections/maps/Long2IntOpenHashTable.java | 3 +
.../utils/collections/maps/OpenHashTable.java | 5 +-
.../utils/lang/mutable/MutableObject.java | 83 ++
.../java/hivemall/utils/math/MathUtils.java | 2 +-
.../evaluation/BinaryResponsesMeasuresTest.java | 18 +-
.../evaluation/GradedResponsesMeasuresTest.java | 6 +-
.../hivemall/math/matrix/MatrixBuilderTest.java | 1 -
.../math/matrix/sparse/DoKFloatMatrixTest.java | 43 ++
.../java/hivemall/recommend/SlimUDTFTest.java | 99 +++
docs/gitbook/SUMMARY.md | 1 +
docs/gitbook/recommend/item_based_cf.md | 8 +-
docs/gitbook/recommend/movielens_cf.md | 3 +-
docs/gitbook/recommend/movielens_cv.md | 2 +-
docs/gitbook/recommend/movielens_fm.md | 4 +-
docs/gitbook/recommend/movielens_slim.md | 589 ++++++++++++++
resources/ddl/define-all-as-permanent.hive | 10 +
resources/ddl/define-all.hive | 10 +
resources/ddl/define-all.spark | 10 +
resources/ddl/define-udfs.td.hql | 2 +
42 files changed, 2916 insertions(+), 130 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/995b9a88/core/src/main/java/hivemall/evaluation/AUCUDAF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/evaluation/AUCUDAF.java b/core/src/main/java/hivemall/evaluation/AUCUDAF.java
index 7cbdb52..508e36a 100644
--- a/core/src/main/java/hivemall/evaluation/AUCUDAF.java
+++ b/core/src/main/java/hivemall/evaluation/AUCUDAF.java
@@ -52,7 +52,6 @@ import org.apache.hadoop.hive.serde2.objectinspector.StandardMapObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.StructField;
import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
-import org.apache.hadoop.hive.serde2.objectinspector.primitive.WritableIntObjectInspector;
import org.apache.hadoop.hive.serde2.typeinfo.ListTypeInfo;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
import org.apache.hadoop.io.LongWritable;
@@ -430,7 +429,7 @@ public final class AUCUDAF extends AbstractGenericUDAFResolver {
private ListObjectInspector recommendListOI;
private ListObjectInspector truthListOI;
- private WritableIntObjectInspector recommendSizeOI;
+ private PrimitiveObjectInspector recommendSizeOI;
private StructObjectInspector internalMergeOI;
private StructField countField;
@@ -448,7 +447,7 @@ public final class AUCUDAF extends AbstractGenericUDAFResolver {
this.recommendListOI = (ListObjectInspector) parameters[0];
this.truthListOI = (ListObjectInspector) parameters[1];
if (parameters.length == 3) {
- this.recommendSizeOI = (WritableIntObjectInspector) parameters[2];
+ this.recommendSizeOI = HiveUtils.asIntegerOI(parameters[2]);
}
} else {// from partial aggregation
StructObjectInspector soi = (StructObjectInspector) parameters[0];
@@ -507,12 +506,12 @@ public final class AUCUDAF extends AbstractGenericUDAFResolver {
int recommendSize = recommendList.size();
if (parameters.length == 3) {
- recommendSize = recommendSizeOI.get(parameters[2]);
- }
- if (recommendSize < 0 || recommendSize > recommendList.size()) {
- throw new UDFArgumentException(
- "The third argument `int recommendSize` must be in [0, " + recommendList.size()
- + "]");
+ recommendSize = PrimitiveObjectInspectorUtils.getInt(parameters[2], recommendSizeOI);
+ if (recommendSize < 0) {
+ throw new UDFArgumentException(
+ "The third argument `int recommendSize` must be in greather than or equals to 0: "
+ + recommendSize);
+ }
}
myAggr.iterate(recommendList, truthList, recommendSize);
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/995b9a88/core/src/main/java/hivemall/evaluation/BinaryResponsesMeasures.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/evaluation/BinaryResponsesMeasures.java b/core/src/main/java/hivemall/evaluation/BinaryResponsesMeasures.java
index 7c21849..c3b4f6a 100644
--- a/core/src/main/java/hivemall/evaluation/BinaryResponsesMeasures.java
+++ b/core/src/main/java/hivemall/evaluation/BinaryResponsesMeasures.java
@@ -45,7 +45,7 @@ public final class BinaryResponsesMeasures {
*/
public static double nDCG(@Nonnull final List<?> rankedList,
@Nonnull final List<?> groundTruth, @Nonnegative final int recommendSize) {
- Preconditions.checkArgument(recommendSize > 0);
+ Preconditions.checkArgument(recommendSize >= 0);
double dcg = 0.d;
@@ -92,6 +92,8 @@ public final class BinaryResponsesMeasures {
*/
public static double Precision(@Nonnull final List<?> rankedList,
@Nonnull final List<?> groundTruth, @Nonnegative final int recommendSize) {
+ Preconditions.checkArgument(recommendSize >= 0);
+
if (rankedList.isEmpty()) {
if (groundTruth.isEmpty()) {
return 1.d;
@@ -99,8 +101,6 @@ public final class BinaryResponsesMeasures {
return 0.d;
}
- Preconditions.checkArgument(recommendSize > 0); // can be zero when groundTruth is empty
-
int nTruePositive = 0;
final int k = Math.min(rankedList.size(), recommendSize);
for (int i = 0; i < k; i++) {
@@ -135,6 +135,29 @@ public final class BinaryResponsesMeasures {
}
/**
+ * Computes Hit@`recommendSize`
+ *
+ * @param rankedList a list of ranked item IDs (first item is highest-ranked)
+ * @param groundTruth a collection of positive/correct item IDs
+ * @param recommendSize top-`recommendSize` items in `rankedList` are recommended
+ * @return 1.0 if any of the top-`recommendSize` items in `rankedList` is contained in `groundTruth`, 0.0 otherwise
+ */
+ public static double Hit(@Nonnull final List<?> rankedList, @Nonnull final List<?> groundTruth,
+ @Nonnegative final int recommendSize) {
+ Preconditions.checkArgument(recommendSize >= 0);
+
+ final int k = Math.min(rankedList.size(), recommendSize);
+ for (int i = 0; i < k; i++) {
+ Object item_id = rankedList.get(i);
+ if (groundTruth.contains(item_id)) {
+ return 1.d;
+ }
+ }
+
+ return 0.d;
+ }
+
+ /**
* Counts the number of true positives
*
* @param rankedList a list of ranked item IDs (first item is highest-ranked)
@@ -144,7 +167,7 @@ public final class BinaryResponsesMeasures {
*/
public static int TruePositives(final List<?> rankedList, final List<?> groundTruth,
@Nonnegative final int recommendSize) {
- Preconditions.checkArgument(recommendSize > 0);
+ Preconditions.checkArgument(recommendSize >= 0);
int nTruePositive = 0;
@@ -170,7 +193,7 @@ public final class BinaryResponsesMeasures {
*/
public static double ReciprocalRank(@Nonnull final List<?> rankedList,
@Nonnull final List<?> groundTruth, @Nonnegative final int recommendSize) {
- Preconditions.checkArgument(recommendSize > 0);
+ Preconditions.checkArgument(recommendSize >= 0);
final int k = Math.min(rankedList.size(), recommendSize);
for (int i = 0; i < k; i++) {
@@ -193,7 +216,7 @@ public final class BinaryResponsesMeasures {
*/
public static double AveragePrecision(@Nonnull final List<?> rankedList,
@Nonnull final List<?> groundTruth, @Nonnegative final int recommendSize) {
- Preconditions.checkArgument(recommendSize > 0);
+ Preconditions.checkArgument(recommendSize >= 0);
if (groundTruth.isEmpty()) {
if (rankedList.isEmpty()) {
@@ -231,7 +254,7 @@ public final class BinaryResponsesMeasures {
*/
public static double AUC(@Nonnull final List<?> rankedList, @Nonnull final List<?> groundTruth,
@Nonnegative final int recommendSize) {
- Preconditions.checkArgument(recommendSize > 0);
+ Preconditions.checkArgument(recommendSize >= 0);
int nTruePositive = 0, nCorrectPairs = 0;
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/995b9a88/core/src/main/java/hivemall/evaluation/GradedResponsesMeasures.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/evaluation/GradedResponsesMeasures.java b/core/src/main/java/hivemall/evaluation/GradedResponsesMeasures.java
index 688ba53..5bbbb7e 100644
--- a/core/src/main/java/hivemall/evaluation/GradedResponsesMeasures.java
+++ b/core/src/main/java/hivemall/evaluation/GradedResponsesMeasures.java
@@ -18,8 +18,12 @@
*/
package hivemall.evaluation;
+import hivemall.utils.lang.Preconditions;
+import hivemall.utils.math.MathUtils;
+
import java.util.List;
+import javax.annotation.Nonnegative;
import javax.annotation.Nonnull;
/**
@@ -32,7 +36,7 @@ public final class GradedResponsesMeasures {
private GradedResponsesMeasures() {}
public static double nDCG(@Nonnull final List<Double> recommendTopRelScoreList,
- @Nonnull final List<Double> truthTopRelScoreList, @Nonnull final int recommendSize) {
+ @Nonnull final List<Double> truthTopRelScoreList, @Nonnegative final int recommendSize) {
double dcg = DCG(recommendTopRelScoreList, recommendSize);
double idcg = DCG(truthTopRelScoreList, recommendSize);
return dcg / idcg;
@@ -45,11 +49,15 @@ public final class GradedResponsesMeasures {
* @param recommendSize the number of positive items
* @return DCG
*/
- public static double DCG(final List<Double> topRelScoreList, final int recommendSize) {
+ public static double DCG(@Nonnull final List<Double> topRelScoreList,
+ @Nonnegative final int recommendSize) {
+ Preconditions.checkArgument(recommendSize >= 0);
+
double dcg = 0.d;
- for (int i = 0; i < recommendSize; i++) {
+ final int k = Math.min(topRelScoreList.size(), recommendSize);
+ for (int i = 0; i < k; i++) {
double relScore = topRelScoreList.get(i);
- dcg += ((Math.pow(2, relScore) - 1) * Math.log(2)) / Math.log(i + 2);
+ dcg += ((Math.pow(2, relScore) - 1) * MathUtils.LOG2) / Math.log(i + 2);
}
return dcg;
}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/995b9a88/core/src/main/java/hivemall/evaluation/HitRateUDAF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/evaluation/HitRateUDAF.java b/core/src/main/java/hivemall/evaluation/HitRateUDAF.java
new file mode 100644
index 0000000..6df6087
--- /dev/null
+++ b/core/src/main/java/hivemall/evaluation/HitRateUDAF.java
@@ -0,0 +1,262 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*
+* Licensed to the Apache Software Foundation (ASF) under one
+* or more contributor license agreements. See the NOTICE file
+* distributed with this work for additional information
+* regarding copyright ownership. The ASF licenses this file
+* to you under the Apache License, Version 2.0 (the
+* "License"); you may not use this file except in compliance
+* with the License. You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing,
+* software distributed under the License is distributed on an
+* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+* KIND, either express or implied. See the License for the
+* specific language governing permissions and limitations
+* under the License.
+*/
+package hivemall.evaluation;
+
+import hivemall.utils.hadoop.HiveUtils;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+
+import javax.annotation.Nonnegative;
+import javax.annotation.Nonnull;
+
+import org.apache.hadoop.hive.ql.exec.Description;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
+import org.apache.hadoop.hive.ql.metadata.HiveException;
+import org.apache.hadoop.hive.ql.parse.SemanticException;
+import org.apache.hadoop.hive.ql.udf.generic.AbstractGenericUDAFResolver;
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
+import org.apache.hadoop.hive.serde2.io.DoubleWritable;
+import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
+import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.StructField;
+import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
+import org.apache.hadoop.hive.serde2.typeinfo.ListTypeInfo;
+import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
+import org.apache.hadoop.io.LongWritable;
+
+@Description(
+ name = "hitrate",
+ value = "_FUNC_(array rankItems, array correctItems [, const int recommendSize = rankItems.size])"
+ + " - Returns HitRate")
+public final class HitRateUDAF extends AbstractGenericUDAFResolver {
+
+ // prevent instantiation
+ private HitRateUDAF() {}
+
+ @Override
+ public GenericUDAFEvaluator getEvaluator(@Nonnull TypeInfo[] typeInfo) throws SemanticException {
+ if (typeInfo.length != 2 && typeInfo.length != 3) {
+ throw new UDFArgumentTypeException(typeInfo.length - 1,
+ "_FUNC_ takes two or three arguments");
+ }
+
+ ListTypeInfo arg1type = HiveUtils.asListTypeInfo(typeInfo[0]);
+ if (!HiveUtils.isPrimitiveTypeInfo(arg1type.getListElementTypeInfo())) {
+ throw new UDFArgumentTypeException(0,
+ "The first argument `array rankItems` is invalid form: " + typeInfo[0]);
+ }
+ ListTypeInfo arg2type = HiveUtils.asListTypeInfo(typeInfo[1]);
+ if (!HiveUtils.isPrimitiveTypeInfo(arg2type.getListElementTypeInfo())) {
+ throw new UDFArgumentTypeException(1,
+ "The second argument `array correctItems` is invalid form: " + typeInfo[1]);
+ }
+
+ return new HitRateUDAF.Evaluator();
+ }
+
+ public static class Evaluator extends GenericUDAFEvaluator {
+
+ private ListObjectInspector recommendListOI;
+ private ListObjectInspector truthListOI;
+ private PrimitiveObjectInspector recommendSizeOI;
+
+ private StructObjectInspector internalMergeOI;
+ private StructField countField;
+ private StructField sumField;
+
+ public Evaluator() {}
+
+ @Override
+ public ObjectInspector init(Mode mode, ObjectInspector[] parameters) throws HiveException {
+ assert (parameters.length == 2 || parameters.length == 3) : parameters.length;
+ super.init(mode, parameters);
+
+ // initialize input
+ if (mode == Mode.PARTIAL1 || mode == Mode.COMPLETE) {// from original data
+ this.recommendListOI = (ListObjectInspector) parameters[0];
+ this.truthListOI = (ListObjectInspector) parameters[1];
+ if (parameters.length == 3) {
+ this.recommendSizeOI = HiveUtils.asIntegerOI(parameters[2]);
+ }
+ } else {// from partial aggregation
+ StructObjectInspector soi = (StructObjectInspector) parameters[0];
+ this.internalMergeOI = soi;
+ this.countField = soi.getStructFieldRef("count");
+ this.sumField = soi.getStructFieldRef("sum");
+ }
+
+ // initialize output
+ final ObjectInspector outputOI;
+ if (mode == Mode.PARTIAL1 || mode == Mode.PARTIAL2) {// terminatePartial
+ outputOI = internalMergeOI();
+ } else {// terminate
+ outputOI = PrimitiveObjectInspectorFactory.writableDoubleObjectInspector;
+ }
+ return outputOI;
+ }
+
+ private static StructObjectInspector internalMergeOI() {
+ ArrayList<String> fieldNames = new ArrayList<String>();
+ ArrayList<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>();
+
+ fieldNames.add("sum");
+ fieldOIs.add(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector);
+ fieldNames.add("count");
+ fieldOIs.add(PrimitiveObjectInspectorFactory.writableLongObjectInspector);
+
+ return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
+ }
+
+ @Override
+ public HitRateAggregationBuffer getNewAggregationBuffer() throws HiveException {
+ HitRateAggregationBuffer myAggr = new HitRateAggregationBuffer();
+ reset(myAggr);
+ return myAggr;
+ }
+
+ @Override
+ public void reset(@SuppressWarnings("deprecation") AggregationBuffer agg)
+ throws HiveException {
+ HitRateAggregationBuffer myAggr = (HitRateAggregationBuffer) agg;
+ myAggr.reset();
+ }
+
+ @Override
+ public void iterate(@SuppressWarnings("deprecation") AggregationBuffer agg,
+ Object[] parameters) throws HiveException {
+ HitRateAggregationBuffer myAggr = (HitRateAggregationBuffer) agg;
+
+ List<?> recommendList = recommendListOI.getList(parameters[0]);
+ if (recommendList == null) {
+ recommendList = Collections.emptyList();
+ }
+ List<?> truthList = truthListOI.getList(parameters[1]);
+ if (truthList == null) {
+ return;
+ }
+
+ int recommendSize = recommendList.size();
+ if (parameters.length == 3) {
+ recommendSize = PrimitiveObjectInspectorUtils.getInt(parameters[2], recommendSizeOI);
+ if (recommendSize < 0) {
+ throw new UDFArgumentException(
+ "The third argument `int recommendSize` must be in greather than or equals to 0: "
+ + recommendSize);
+ }
+ }
+
+ myAggr.iterate(recommendList, truthList, recommendSize);
+ }
+
+ @Override
+ public Object terminatePartial(@SuppressWarnings("deprecation") AggregationBuffer agg)
+ throws HiveException {
+ HitRateAggregationBuffer myAggr = (HitRateAggregationBuffer) agg;
+
+ Object[] partialResult = new Object[2];
+ partialResult[0] = new DoubleWritable(myAggr.sum);
+ partialResult[1] = new LongWritable(myAggr.count);
+ return partialResult;
+ }
+
+ @Override
+ public void merge(@SuppressWarnings("deprecation") AggregationBuffer agg, Object partial)
+ throws HiveException {
+ if (partial == null) {
+ return;
+ }
+
+ Object sumObj = internalMergeOI.getStructFieldData(partial, sumField);
+ Object countObj = internalMergeOI.getStructFieldData(partial, countField);
+ double sum = PrimitiveObjectInspectorFactory.writableDoubleObjectInspector.get(sumObj);
+ long count = PrimitiveObjectInspectorFactory.writableLongObjectInspector.get(countObj);
+
+ HitRateAggregationBuffer myAggr = (HitRateAggregationBuffer) agg;
+ myAggr.merge(sum, count);
+ }
+
+ @Override
+ public DoubleWritable terminate(@SuppressWarnings("deprecation") AggregationBuffer agg)
+ throws HiveException {
+ HitRateAggregationBuffer myAggr = (HitRateAggregationBuffer) agg;
+ double result = myAggr.get();
+ return new DoubleWritable(result);
+ }
+
+ }
+
+ public static final class HitRateAggregationBuffer extends
+ GenericUDAFEvaluator.AbstractAggregationBuffer {
+
+ private double sum;
+ private long count;
+
+ public HitRateAggregationBuffer() {
+ super();
+ }
+
+ void reset() {
+ this.sum = 0.d;
+ this.count = 0;
+ }
+
+ void merge(double o_sum, long o_count) {
+ this.sum += o_sum;
+ this.count += o_count;
+ }
+
+ double get() {
+ if (count == 0) {
+ return 0.d;
+ }
+ return sum / count;
+ }
+
+ void iterate(@Nonnull List<?> recommendList, @Nonnull List<?> truthList,
+ @Nonnegative int recommendSize) {
+ this.sum += BinaryResponsesMeasures.Hit(recommendList, truthList, recommendSize);
+ this.count++;
+ }
+ }
+}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/995b9a88/core/src/main/java/hivemall/evaluation/MAPUDAF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/evaluation/MAPUDAF.java b/core/src/main/java/hivemall/evaluation/MAPUDAF.java
index 3878684..45e64cb 100644
--- a/core/src/main/java/hivemall/evaluation/MAPUDAF.java
+++ b/core/src/main/java/hivemall/evaluation/MAPUDAF.java
@@ -38,10 +38,11 @@ import org.apache.hadoop.hive.serde2.io.DoubleWritable;
import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
+import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.StructField;
import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
-import org.apache.hadoop.hive.serde2.objectinspector.primitive.WritableIntObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
import org.apache.hadoop.hive.serde2.typeinfo.ListTypeInfo;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
import org.apache.hadoop.io.LongWritable;
@@ -80,7 +81,7 @@ public final class MAPUDAF extends AbstractGenericUDAFResolver {
private ListObjectInspector recommendListOI;
private ListObjectInspector truthListOI;
- private WritableIntObjectInspector recommendSizeOI;
+ private PrimitiveObjectInspector recommendSizeOI;
private StructObjectInspector internalMergeOI;
private StructField countField;
@@ -98,7 +99,7 @@ public final class MAPUDAF extends AbstractGenericUDAFResolver {
this.recommendListOI = (ListObjectInspector) parameters[0];
this.truthListOI = (ListObjectInspector) parameters[1];
if (parameters.length == 3) {
- this.recommendSizeOI = (WritableIntObjectInspector) parameters[2];
+ this.recommendSizeOI = HiveUtils.asIntegerOI(parameters[2]);
}
} else {// from partial aggregation
StructObjectInspector soi = (StructObjectInspector) parameters[0];
@@ -159,12 +160,12 @@ public final class MAPUDAF extends AbstractGenericUDAFResolver {
int recommendSize = recommendList.size();
if (parameters.length == 3) {
- recommendSize = recommendSizeOI.get(parameters[2]);
- }
- if (recommendSize < 0 || recommendSize > recommendList.size()) {
- throw new UDFArgumentException(
- "The third argument `int recommendSize` must be in [0, " + recommendList.size()
- + "]");
+ recommendSize = PrimitiveObjectInspectorUtils.getInt(parameters[2], recommendSizeOI);
+ if (recommendSize < 0) {
+ throw new UDFArgumentException(
+ "The third argument `int recommendSize` must be in greather than or equals to 0: "
+ + recommendSize);
+ }
}
myAggr.iterate(recommendList, truthList, recommendSize);
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/995b9a88/core/src/main/java/hivemall/evaluation/MRRUDAF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/evaluation/MRRUDAF.java b/core/src/main/java/hivemall/evaluation/MRRUDAF.java
index f5aba3b..98b8c3d 100644
--- a/core/src/main/java/hivemall/evaluation/MRRUDAF.java
+++ b/core/src/main/java/hivemall/evaluation/MRRUDAF.java
@@ -38,10 +38,11 @@ import org.apache.hadoop.hive.serde2.io.DoubleWritable;
import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
+import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.StructField;
import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
-import org.apache.hadoop.hive.serde2.objectinspector.primitive.WritableIntObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
import org.apache.hadoop.hive.serde2.typeinfo.ListTypeInfo;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
import org.apache.hadoop.io.LongWritable;
@@ -80,7 +81,7 @@ public final class MRRUDAF extends AbstractGenericUDAFResolver {
private ListObjectInspector recommendListOI;
private ListObjectInspector truthListOI;
- private WritableIntObjectInspector recommendSizeOI;
+ private PrimitiveObjectInspector recommendSizeOI;
private StructObjectInspector internalMergeOI;
private StructField countField;
@@ -98,7 +99,7 @@ public final class MRRUDAF extends AbstractGenericUDAFResolver {
this.recommendListOI = (ListObjectInspector) parameters[0];
this.truthListOI = (ListObjectInspector) parameters[1];
if (parameters.length == 3) {
- this.recommendSizeOI = (WritableIntObjectInspector) parameters[2];
+ this.recommendSizeOI = HiveUtils.asIntegerOI(parameters[2]);
}
} else {// from partial aggregation
StructObjectInspector soi = (StructObjectInspector) parameters[0];
@@ -159,12 +160,12 @@ public final class MRRUDAF extends AbstractGenericUDAFResolver {
int recommendSize = recommendList.size();
if (parameters.length == 3) {
- recommendSize = recommendSizeOI.get(parameters[2]);
- }
- if (recommendSize < 0 || recommendSize > recommendList.size()) {
- throw new UDFArgumentException(
- "The third argument `int recommendSize` must be in [0, " + recommendList.size()
- + "]");
+ recommendSize = PrimitiveObjectInspectorUtils.getInt(parameters[2], recommendSizeOI);
+ if (recommendSize < 0) {
+ throw new UDFArgumentException(
+ "The third argument `int recommendSize` must be in greather than or equals to 0: "
+ + recommendSize);
+ }
}
myAggr.iterate(recommendList, truthList, recommendSize);
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/995b9a88/core/src/main/java/hivemall/evaluation/NDCGUDAF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/evaluation/NDCGUDAF.java b/core/src/main/java/hivemall/evaluation/NDCGUDAF.java
index f1ba832..4e4fde6 100644
--- a/core/src/main/java/hivemall/evaluation/NDCGUDAF.java
+++ b/core/src/main/java/hivemall/evaluation/NDCGUDAF.java
@@ -45,7 +45,6 @@ import org.apache.hadoop.hive.serde2.objectinspector.StructField;
import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
-import org.apache.hadoop.hive.serde2.objectinspector.primitive.WritableIntObjectInspector;
import org.apache.hadoop.hive.serde2.typeinfo.ListTypeInfo;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
import org.apache.hadoop.io.LongWritable;
@@ -85,7 +84,7 @@ public final class NDCGUDAF extends AbstractGenericUDAFResolver {
private ListObjectInspector recommendListOI;
private ListObjectInspector truthListOI;
- private WritableIntObjectInspector recommendSizeOI;
+ private PrimitiveObjectInspector recommendSizeOI;
private StructObjectInspector internalMergeOI;
private StructField countField;
@@ -103,7 +102,7 @@ public final class NDCGUDAF extends AbstractGenericUDAFResolver {
this.recommendListOI = (ListObjectInspector) parameters[0];
this.truthListOI = (ListObjectInspector) parameters[1];
if (parameters.length == 3) {
- this.recommendSizeOI = (WritableIntObjectInspector) parameters[2];
+ this.recommendSizeOI = HiveUtils.asIntegerOI(parameters[2]);
}
} else {// from partial aggregation
StructObjectInspector soi = (StructObjectInspector) parameters[0];
@@ -164,12 +163,12 @@ public final class NDCGUDAF extends AbstractGenericUDAFResolver {
int recommendSize = recommendList.size();
if (parameters.length == 3) {
- recommendSize = recommendSizeOI.get(parameters[2]);
- }
- if (recommendSize < 0 || recommendSize > recommendList.size()) {
- throw new UDFArgumentException(
- "The third argument `int recommendSize` must be in [0, " + recommendList.size()
- + "]");
+ recommendSize = PrimitiveObjectInspectorUtils.getInt(parameters[2], recommendSizeOI);
+ if (recommendSize < 0) {
+ throw new UDFArgumentException(
+ "The third argument `int recommendSize` must be in greather than or equals to 0: "
+ + recommendSize);
+ }
}
boolean isBinary = !HiveUtils.isStructOI(recommendListOI.getListElementObjectInspector());
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/995b9a88/core/src/main/java/hivemall/evaluation/PrecisionUDAF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/evaluation/PrecisionUDAF.java b/core/src/main/java/hivemall/evaluation/PrecisionUDAF.java
index 93af519..de8a876 100644
--- a/core/src/main/java/hivemall/evaluation/PrecisionUDAF.java
+++ b/core/src/main/java/hivemall/evaluation/PrecisionUDAF.java
@@ -38,10 +38,11 @@ import org.apache.hadoop.hive.serde2.io.DoubleWritable;
import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
+import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.StructField;
import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
-import org.apache.hadoop.hive.serde2.objectinspector.primitive.WritableIntObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
import org.apache.hadoop.hive.serde2.typeinfo.ListTypeInfo;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
import org.apache.hadoop.io.LongWritable;
@@ -80,7 +81,7 @@ public final class PrecisionUDAF extends AbstractGenericUDAFResolver {
private ListObjectInspector recommendListOI;
private ListObjectInspector truthListOI;
- private WritableIntObjectInspector recommendSizeOI;
+ private PrimitiveObjectInspector recommendSizeOI;
private StructObjectInspector internalMergeOI;
private StructField countField;
@@ -98,7 +99,7 @@ public final class PrecisionUDAF extends AbstractGenericUDAFResolver {
this.recommendListOI = (ListObjectInspector) parameters[0];
this.truthListOI = (ListObjectInspector) parameters[1];
if (parameters.length == 3) {
- this.recommendSizeOI = (WritableIntObjectInspector) parameters[2];
+ this.recommendSizeOI = HiveUtils.asIntegerOI(parameters[2]);
}
} else {// from partial aggregation
StructObjectInspector soi = (StructObjectInspector) parameters[0];
@@ -117,9 +118,10 @@ public final class PrecisionUDAF extends AbstractGenericUDAFResolver {
return outputOI;
}
+ @Nonnull
private static StructObjectInspector internalMergeOI() {
- ArrayList<String> fieldNames = new ArrayList<String>();
- ArrayList<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>();
+ List<String> fieldNames = new ArrayList<>();
+ List<ObjectInspector> fieldOIs = new ArrayList<>();
fieldNames.add("sum");
fieldOIs.add(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector);
@@ -159,12 +161,12 @@ public final class PrecisionUDAF extends AbstractGenericUDAFResolver {
int recommendSize = recommendList.size();
if (parameters.length == 3) {
- recommendSize = recommendSizeOI.get(parameters[2]);
- }
- if (recommendSize < 0 || recommendSize > recommendList.size()) {
- throw new UDFArgumentException(
- "The third argument `int recommendSize` must be in [0, " + recommendList.size()
- + "]");
+ recommendSize = PrimitiveObjectInspectorUtils.getInt(parameters[2], recommendSizeOI);
+ if (recommendSize < 0) {
+ throw new UDFArgumentException(
+ "The third argument `int recommendSize` must be in greather than or equals to 0: "
+ + recommendSize);
+ }
}
myAggr.iterate(recommendList, truthList, recommendSize);
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/995b9a88/core/src/main/java/hivemall/evaluation/RecallUDAF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/evaluation/RecallUDAF.java b/core/src/main/java/hivemall/evaluation/RecallUDAF.java
index fed9f71..30b1712 100644
--- a/core/src/main/java/hivemall/evaluation/RecallUDAF.java
+++ b/core/src/main/java/hivemall/evaluation/RecallUDAF.java
@@ -38,10 +38,11 @@ import org.apache.hadoop.hive.serde2.io.DoubleWritable;
import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
+import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.StructField;
import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
-import org.apache.hadoop.hive.serde2.objectinspector.primitive.WritableIntObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
import org.apache.hadoop.hive.serde2.typeinfo.ListTypeInfo;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
import org.apache.hadoop.io.LongWritable;
@@ -80,7 +81,7 @@ public final class RecallUDAF extends AbstractGenericUDAFResolver {
private ListObjectInspector recommendListOI;
private ListObjectInspector truthListOI;
- private WritableIntObjectInspector recommendSizeOI;
+ private PrimitiveObjectInspector recommendSizeOI;
private StructObjectInspector internalMergeOI;
private StructField countField;
@@ -98,7 +99,7 @@ public final class RecallUDAF extends AbstractGenericUDAFResolver {
this.recommendListOI = (ListObjectInspector) parameters[0];
this.truthListOI = (ListObjectInspector) parameters[1];
if (parameters.length == 3) {
- this.recommendSizeOI = (WritableIntObjectInspector) parameters[2];
+ this.recommendSizeOI = HiveUtils.asIntegerOI(parameters[2]);
}
} else {// from partial aggregation
StructObjectInspector soi = (StructObjectInspector) parameters[0];
@@ -159,12 +160,12 @@ public final class RecallUDAF extends AbstractGenericUDAFResolver {
int recommendSize = recommendList.size();
if (parameters.length == 3) {
- recommendSize = recommendSizeOI.get(parameters[2]);
- }
- if (recommendSize < 0 || recommendSize > recommendList.size()) {
- throw new UDFArgumentException(
- "The third argument `int recommendSize` must be in [0, " + recommendList.size()
- + "]");
+ recommendSize = PrimitiveObjectInspectorUtils.getInt(parameters[2], recommendSizeOI);
+ if (recommendSize < 0) {
+ throw new UDFArgumentException(
+ "The third argument `int recommendSize` must be in greather than or equals to 0: "
+ + recommendSize);
+ }
}
myAggr.iterate(recommendList, truthList, recommendSize);
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/995b9a88/core/src/main/java/hivemall/math/matrix/sparse/CSCMatrix.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/math/matrix/sparse/CSCMatrix.java b/core/src/main/java/hivemall/math/matrix/sparse/CSCMatrix.java
index d2232b2..f8eb02f 100644
--- a/core/src/main/java/hivemall/math/matrix/sparse/CSCMatrix.java
+++ b/core/src/main/java/hivemall/math/matrix/sparse/CSCMatrix.java
@@ -31,6 +31,8 @@ import javax.annotation.Nonnegative;
import javax.annotation.Nonnull;
/**
+ * Compressed Sparse Column matrix optimized for column major access.
+ *
* @link http://netlib.org/linalg/html_templates/node92.html#SECTION00931200000000000000
*/
public final class CSCMatrix extends ColumnMajorMatrix {
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/995b9a88/core/src/main/java/hivemall/math/matrix/sparse/CSRMatrix.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/math/matrix/sparse/CSRMatrix.java b/core/src/main/java/hivemall/math/matrix/sparse/CSRMatrix.java
index dd89521..805bbd1 100644
--- a/core/src/main/java/hivemall/math/matrix/sparse/CSRMatrix.java
+++ b/core/src/main/java/hivemall/math/matrix/sparse/CSRMatrix.java
@@ -29,8 +29,8 @@ import javax.annotation.Nonnegative;
import javax.annotation.Nonnull;
/**
- * Read-only CSR double Matrix.
- *
+ * Compressed Sparse Row Matrix optimized for row major access.
+ *
* @link http://netlib.org/linalg/html_templates/node91.html#SECTION00931100000000000000
* @link http://www.cs.colostate.edu/~mcrob/toolbox/c++/sparseMatrix/sparse_matrix_compression.html
*/
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/995b9a88/core/src/main/java/hivemall/math/matrix/sparse/DoKFloatMatrix.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/math/matrix/sparse/DoKFloatMatrix.java b/core/src/main/java/hivemall/math/matrix/sparse/DoKFloatMatrix.java
new file mode 100644
index 0000000..16b4b64
--- /dev/null
+++ b/core/src/main/java/hivemall/math/matrix/sparse/DoKFloatMatrix.java
@@ -0,0 +1,368 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.math.matrix.sparse;
+
+import hivemall.annotations.Experimental;
+import hivemall.math.matrix.AbstractMatrix;
+import hivemall.math.matrix.ColumnMajorMatrix;
+import hivemall.math.matrix.RowMajorMatrix;
+import hivemall.math.matrix.builders.DoKMatrixBuilder;
+import hivemall.math.vector.Vector;
+import hivemall.math.vector.VectorProcedure;
+import hivemall.utils.collections.maps.Long2FloatOpenHashTable;
+import hivemall.utils.collections.maps.Long2FloatOpenHashTable.IMapIterator;
+import hivemall.utils.lang.Preconditions;
+import hivemall.utils.lang.Primitives;
+
+import javax.annotation.Nonnegative;
+import javax.annotation.Nonnull;
+
+/**
+ * Dictionary Of Keys based sparse matrix.
+ *
+ * This is an efficient structure for constructing a sparse matrix incrementally.
+ */
+@Experimental
+public final class DoKFloatMatrix extends AbstractMatrix {
+
+ @Nonnull
+ private final Long2FloatOpenHashTable elements;
+ @Nonnegative
+ private int numRows;
+ @Nonnegative
+ private int numColumns;
+ @Nonnegative
+ private int nnz;
+
+ public DoKFloatMatrix() {
+ this(0, 0);
+ }
+
+ public DoKFloatMatrix(@Nonnegative int numRows, @Nonnegative int numCols) {
+ this(numRows, numCols, 0.05f);
+ }
+
+ public DoKFloatMatrix(@Nonnegative int numRows, @Nonnegative int numCols,
+ @Nonnegative float sparsity) {
+ super();
+ Preconditions.checkArgument(sparsity >= 0.f && sparsity <= 1.f, "Invalid Sparsity value: "
+ + sparsity);
+ int initialCapacity = Math.max(16384, Math.round(numRows * numCols * sparsity));
+ this.elements = new Long2FloatOpenHashTable(initialCapacity);
+ elements.defaultReturnValue(0.f);
+ this.numRows = numRows;
+ this.numColumns = numCols;
+ this.nnz = 0;
+ }
+
+ public DoKFloatMatrix(@Nonnegative int initSize) {
+ super();
+ int initialCapacity = Math.max(initSize, 16384);
+ this.elements = new Long2FloatOpenHashTable(initialCapacity);
+ elements.defaultReturnValue(0.f);
+ this.numRows = 0;
+ this.numColumns = 0;
+ this.nnz = 0;
+ }
+
+ @Override
+ public boolean isSparse() {
+ return true;
+ }
+
+ @Override
+ public boolean isRowMajorMatrix() {
+ return false;
+ }
+
+ @Override
+ public boolean isColumnMajorMatrix() {
+ return false;
+ }
+
+ @Override
+ public boolean readOnly() {
+ return false;
+ }
+
+ @Override
+ public boolean swappable() {
+ return true;
+ }
+
+ @Override
+ public int nnz() {
+ return nnz;
+ }
+
+ @Override
+ public int numRows() {
+ return numRows;
+ }
+
+ @Override
+ public int numColumns() {
+ return numColumns;
+ }
+
+ @Override
+ public int numColumns(@Nonnegative final int row) {
+ int count = 0;
+ for (int j = 0; j < numColumns; j++) {
+ long index = index(row, j);
+ if (elements.containsKey(index)) {
+ count++;
+ }
+ }
+ return count;
+ }
+
+ @Override
+ public double[] getRow(@Nonnegative final int index) {
+ double[] dst = row();
+ return getRow(index, dst);
+ }
+
+ @Override
+ public double[] getRow(@Nonnegative final int row, @Nonnull final double[] dst) {
+ checkRowIndex(row, numRows);
+
+ final int end = Math.min(dst.length, numColumns);
+ for (int col = 0; col < end; col++) {
+ long k = index(row, col);
+ float v = elements.get(k);
+ dst[col] = v;
+ }
+
+ return dst;
+ }
+
+ @Override
+ public void getRow(@Nonnegative final int index, @Nonnull final Vector row) {
+ checkRowIndex(index, numRows);
+ row.clear();
+
+ for (int col = 0; col < numColumns; col++) {
+ long k = index(index, col);
+ final float v = elements.get(k, 0.f);
+ if (v != 0.f) {
+ row.set(col, v);
+ }
+ }
+ }
+
+ @Override
+ public double get(@Nonnegative final int row, @Nonnegative final int col,
+ final double defaultValue) {
+ return get(row, col, (float) defaultValue);
+ }
+
+ public float get(@Nonnegative final int row, @Nonnegative final int col,
+ final float defaultValue) {
+ long index = index(row, col);
+ return elements.get(index, defaultValue);
+ }
+
+ @Override
+ public void set(@Nonnegative final int row, @Nonnegative final int col, final double value) {
+ set(row, col, (float) value);
+ }
+
+ public void set(@Nonnegative final int row, @Nonnegative final int col, final float value) {
+ checkIndex(row, col);
+
+ final long index = index(row, col);
+ if (value == 0.f && elements.containsKey(index) == false) {
+ return;
+ }
+
+ if (elements.put(index, value, 0.f) == 0.f) {
+ nnz++;
+ this.numRows = Math.max(numRows, row + 1);
+ this.numColumns = Math.max(numColumns, col + 1);
+ }
+ }
+
+ @Override
+ public double getAndSet(@Nonnegative final int row, @Nonnegative final int col,
+ final double value) {
+ return getAndSet(row, col, (float) value);
+ }
+
+ public float getAndSet(@Nonnegative final int row, @Nonnegative final int col, final float value) {
+ checkIndex(row, col);
+
+ final long index = index(row, col);
+ if (value == 0.f && elements.containsKey(index) == false) {
+ return 0.f;
+ }
+
+ final float old = elements.put(index, value, 0.f);
+ if (old == 0.f) {
+ nnz++;
+ this.numRows = Math.max(numRows, row + 1);
+ this.numColumns = Math.max(numColumns, col + 1);
+ }
+ return old;
+ }
+
+ @Override
+ public void swap(@Nonnegative final int row1, @Nonnegative final int row2) {
+ checkRowIndex(row1, numRows);
+ checkRowIndex(row2, numRows);
+
+ for (int j = 0; j < numColumns; j++) {
+ final long i1 = index(row1, j);
+ final long i2 = index(row2, j);
+
+ final int k1 = elements._findKey(i1);
+ final int k2 = elements._findKey(i2);
+
+ if (k1 >= 0) {
+ if (k2 >= 0) {
+ float v1 = elements._get(k1);
+ float v2 = elements._set(k2, v1);
+ elements._set(k1, v2);
+ } else {// k1>=0 and k2<0
+ float v1 = elements._remove(k1);
+ elements.put(i2, v1);
+ }
+ } else if (k2 >= 0) {// k2>=0 and k1 < 0
+ float v2 = elements._remove(k2);
+ elements.put(i1, v2);
+ } else {//k1<0 and k2<0
+ continue;
+ }
+ }
+ }
+
+ @Override
+ public void eachInRow(@Nonnegative final int row, @Nonnull final VectorProcedure procedure,
+ final boolean nullOutput) {
+ checkRowIndex(row, numRows);
+
+ for (int col = 0; col < numColumns; col++) {
+ long i = index(row, col);
+ final int key = elements._findKey(i);
+ if (key < 0) {
+ if (nullOutput) {
+ procedure.apply(col, 0.d);
+ }
+ } else {
+ float v = elements._get(key);
+ procedure.apply(col, v);
+ }
+ }
+ }
+
+ @Override
+ public void eachNonZeroInRow(@Nonnegative final int row,
+ @Nonnull final VectorProcedure procedure) {
+ checkRowIndex(row, numRows);
+
+ for (int col = 0; col < numColumns; col++) {
+ long i = index(row, col);
+ final float v = elements.get(i, 0.f);
+ if (v != 0.f) {
+ procedure.apply(col, v);
+ }
+ }
+ }
+
+ @Override
+ public void eachColumnIndexInRow(int row, VectorProcedure procedure) {
+ checkRowIndex(row, numRows);
+
+ for (int col = 0; col < numColumns; col++) {
+ long i = index(row, col);
+ final int key = elements._findKey(i);
+ if (key != -1) {
+ procedure.apply(col);
+ }
+ }
+ }
+
+ @Override
+ public void eachInColumn(@Nonnegative final int col, @Nonnull final VectorProcedure procedure,
+ final boolean nullOutput) {
+ checkColIndex(col, numColumns);
+
+ for (int row = 0; row < numRows; row++) {
+ long i = index(row, col);
+ final int key = elements._findKey(i);
+ if (key < 0) {
+ if (nullOutput) {
+ procedure.apply(row, 0.d);
+ }
+ } else {
+ float v = elements._get(key);
+ procedure.apply(row, v);
+ }
+ }
+ }
+
+ @Override
+ public void eachNonZeroInColumn(@Nonnegative final int col,
+ @Nonnull final VectorProcedure procedure) {
+ checkColIndex(col, numColumns);
+
+ for (int row = 0; row < numRows; row++) {
+ long i = index(row, col);
+ final float v = elements.get(i, 0.f);
+ if (v != 0.f) {
+ procedure.apply(row, v);
+ }
+ }
+ }
+
+ public void eachNonZeroCell(@Nonnull final VectorProcedure procedure) {
+ if (nnz == 0) {
+ return;
+ }
+ final IMapIterator itor = elements.entries();
+ while (itor.next() != -1) {
+ long k = itor.getKey();
+ int row = Primitives.getHigh(k);
+ int col = Primitives.getLow(k);
+ float value = itor.getValue();
+ procedure.apply(row, col, value);
+ }
+ }
+
+ @Override
+ public RowMajorMatrix toRowMajorMatrix() {
+ throw new UnsupportedOperationException("Not yet supported");
+ }
+
+ @Override
+ public ColumnMajorMatrix toColumnMajorMatrix() {
+ throw new UnsupportedOperationException("Not yet supported");
+ }
+
+ @Override
+ public DoKMatrixBuilder builder() {
+ return new DoKMatrixBuilder(elements.size());
+ }
+
+ @Nonnegative
+ private static long index(@Nonnegative final int row, @Nonnegative final int col) {
+ return Primitives.toLong(row, col);
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/995b9a88/core/src/main/java/hivemall/math/matrix/sparse/DoKMatrix.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/math/matrix/sparse/DoKMatrix.java b/core/src/main/java/hivemall/math/matrix/sparse/DoKMatrix.java
index bcfd152..054d62a 100644
--- a/core/src/main/java/hivemall/math/matrix/sparse/DoKMatrix.java
+++ b/core/src/main/java/hivemall/math/matrix/sparse/DoKMatrix.java
@@ -26,12 +26,18 @@ import hivemall.math.matrix.builders.DoKMatrixBuilder;
import hivemall.math.vector.Vector;
import hivemall.math.vector.VectorProcedure;
import hivemall.utils.collections.maps.Long2DoubleOpenHashTable;
+import hivemall.utils.collections.maps.Long2DoubleOpenHashTable.IMapIterator;
import hivemall.utils.lang.Preconditions;
import hivemall.utils.lang.Primitives;
import javax.annotation.Nonnegative;
import javax.annotation.Nonnull;
+/**
+ * Dictionary Of Keys based sparse matrix.
+ *
+ * This is an efficient structure for constructing a sparse matrix incrementally.
+ */
@Experimental
public final class DoKMatrix extends AbstractMatrix {
@@ -163,8 +169,6 @@ public final class DoKMatrix extends AbstractMatrix {
@Override
public double get(@Nonnegative final int row, @Nonnegative final int col,
final double defaultValue) {
- checkIndex(row, col, numRows, numColumns);
-
long index = index(row, col);
return elements.get(index, defaultValue);
}
@@ -173,11 +177,11 @@ public final class DoKMatrix extends AbstractMatrix {
public void set(@Nonnegative final int row, @Nonnegative final int col, final double value) {
checkIndex(row, col);
- if (value == 0.d) {
+ final long index = index(row, col);
+ if (value == 0.d && elements.containsKey(index) == false) {
return;
}
- long index = index(row, col);
if (elements.put(index, value, 0.d) == 0.d) {
nnz++;
this.numRows = Math.max(numRows, row + 1);
@@ -190,8 +194,12 @@ public final class DoKMatrix extends AbstractMatrix {
final double value) {
checkIndex(row, col);
- long index = index(row, col);
- double old = elements.put(index, value, 0.d);
+ final long index = index(row, col);
+ if (value == 0.d && elements.containsKey(index) == false) {
+ return 0.d;
+ }
+
+ final double old = elements.put(index, value, 0.d);
if (old == 0.d) {
nnz++;
this.numRows = Math.max(numRows, row + 1);
@@ -309,6 +317,20 @@ public final class DoKMatrix extends AbstractMatrix {
}
}
+ public void eachNonZeroCell(@Nonnull final VectorProcedure procedure) {
+ if (nnz == 0) {
+ return;
+ }
+ final IMapIterator itor = elements.entries();
+ while (itor.next() != -1) {
+ long k = itor.getKey();
+ int row = Primitives.getHigh(k);
+ int col = Primitives.getLow(k);
+ double value = itor.getValue();
+ procedure.apply(row, col, value);
+ }
+ }
+
@Override
public RowMajorMatrix toRowMajorMatrix() {
throw new UnsupportedOperationException("Not yet supported");
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/995b9a88/core/src/main/java/hivemall/math/vector/VectorProcedure.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/math/vector/VectorProcedure.java b/core/src/main/java/hivemall/math/vector/VectorProcedure.java
index 266c531..3f3c390 100644
--- a/core/src/main/java/hivemall/math/vector/VectorProcedure.java
+++ b/core/src/main/java/hivemall/math/vector/VectorProcedure.java
@@ -24,6 +24,12 @@ public abstract class VectorProcedure {
public VectorProcedure() {}
+ public void apply(@Nonnegative int i, @Nonnegative int j, float value) {
+ apply(i, j, (double) value);
+ }
+
+ public void apply(@Nonnegative int i, @Nonnegative int j, double value) {}
+
public void apply(@Nonnegative int i, double value) {}
public void apply(@Nonnegative int i, int value) {}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/995b9a88/core/src/main/java/hivemall/mf/BPRMatrixFactorizationUDTF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/mf/BPRMatrixFactorizationUDTF.java b/core/src/main/java/hivemall/mf/BPRMatrixFactorizationUDTF.java
index 141b261..0f9b5fd 100644
--- a/core/src/main/java/hivemall/mf/BPRMatrixFactorizationUDTF.java
+++ b/core/src/main/java/hivemall/mf/BPRMatrixFactorizationUDTF.java
@@ -512,9 +512,8 @@ public final class BPRMatrixFactorizationUDTF extends UDTFWithOptions implements
// write training examples in buffer to a temporary file
if (inputBuf.position() > 0) {
writeBuffer(inputBuf, fileIO, lastWritePos);
- } else if (lastWritePos == 0) {
- return; // no training example
}
+
try {
fileIO.flush();
} catch (IOException e) {
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/995b9a88/core/src/main/java/hivemall/mf/OnlineMatrixFactorizationUDTF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/mf/OnlineMatrixFactorizationUDTF.java b/core/src/main/java/hivemall/mf/OnlineMatrixFactorizationUDTF.java
index 66ec60d..ee549c5 100644
--- a/core/src/main/java/hivemall/mf/OnlineMatrixFactorizationUDTF.java
+++ b/core/src/main/java/hivemall/mf/OnlineMatrixFactorizationUDTF.java
@@ -148,7 +148,7 @@ public abstract class OnlineMatrixFactorizationUDTF extends UDTFWithOptions impl
this.iterations = Primitives.parseInt(cl.getOptionValue("iterations"), 1);
if (iterations < 1) {
throw new UDFArgumentException(
- "'-iterations' must be greater than or equals to 1: " + iterations);
+ "'-iterations' must be greater than or equal to 1: " + iterations);
}
conversionCheck = !cl.hasOption("disable_cvtest");
convergenceRate = Primitives.parseDouble(cl.getOptionValue("cv_rate"), convergenceRate);
@@ -239,7 +239,7 @@ public abstract class OnlineMatrixFactorizationUDTF extends UDTFWithOptions impl
}
int item = PrimitiveObjectInspectorUtils.getInt(args[1], itemOI);
if (item < 0) {
- throw new HiveException("Illegal item index: " + user);
+ throw new HiveException("Illegal item index: " + item);
}
double rating = PrimitiveObjectInspectorUtils.getDouble(args[2], ratingOI);
@@ -505,9 +505,8 @@ public abstract class OnlineMatrixFactorizationUDTF extends UDTFWithOptions impl
// write training examples in buffer to a temporary file
if (inputBuf.position() > 0) {
writeBuffer(inputBuf, fileIO, lastWritePos);
- } else if (lastWritePos == 0) {
- return; // no training example
}
+
try {
fileIO.flush();
} catch (IOException e) {