You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@hivemall.apache.org by my...@apache.org on 2017/09/15 09:52:50 UTC
incubator-hivemall git commit: Close #115: [HIVEMALL-124][BUGFIX]
Fixed bugs in BinaryResponseMeasure (nDCG, MRR, AP)
Repository: incubator-hivemall
Updated Branches:
refs/heads/master 06f2f8220 -> c2b95783c
Close #115: [HIVEMALL-124][BUGFIX] Fixed bugs in BinaryResponseMeasure (nDCG, MRR, AP)
Project: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/commit/c2b95783
Tree: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/tree/c2b95783
Diff: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/diff/c2b95783
Branch: refs/heads/master
Commit: c2b95783cf9d6fc1646a48ac928e96152eab98c6
Parents: 06f2f82
Author: Makoto Yui <my...@apache.org>
Authored: Fri Sep 15 18:52:33 2017 +0900
Committer: Makoto Yui <my...@apache.org>
Committed: Fri Sep 15 18:52:33 2017 +0900
----------------------------------------------------------------------
.../main/java/hivemall/HivemallConstants.java | 2 +
.../evaluation/BinaryResponsesMeasures.java | 122 ++++++++++++++-----
.../main/java/hivemall/evaluation/MAPUDAF.java | 2 +-
.../main/java/hivemall/evaluation/MRRUDAF.java | 2 +-
.../main/java/hivemall/evaluation/NDCGUDAF.java | 32 +++--
.../hivemall/tools/list/UDAFToOrderedList.java | 2 +-
.../java/hivemall/utils/hadoop/HiveUtils.java | 18 ++-
.../java/hivemall/utils/math/MathUtils.java | 5 +
.../evaluation/BinaryResponsesMeasuresTest.java | 101 +++++++++++++--
docs/gitbook/eval/rank.md | 33 ++---
10 files changed, 250 insertions(+), 69 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/c2b95783/core/src/main/java/hivemall/HivemallConstants.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/HivemallConstants.java b/core/src/main/java/hivemall/HivemallConstants.java
index 0eb9feb..67bb228 100644
--- a/core/src/main/java/hivemall/HivemallConstants.java
+++ b/core/src/main/java/hivemall/HivemallConstants.java
@@ -18,6 +18,7 @@
*/
package hivemall;
+
public final class HivemallConstants {
public static final String VERSION = "0.4.2-rc.2";
@@ -35,6 +36,7 @@ public final class HivemallConstants {
public static final String BIGINT_TYPE_NAME = "bigint";
public static final String FLOAT_TYPE_NAME = "float";
public static final String DOUBLE_TYPE_NAME = "double";
+ public static final String DECIMAL_TYPE_NAME = "decimal";
public static final String STRING_TYPE_NAME = "string";
public static final String DATE_TYPE_NAME = "date";
public static final String DATETIME_TYPE_NAME = "datetime";
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/c2b95783/core/src/main/java/hivemall/evaluation/BinaryResponsesMeasures.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/evaluation/BinaryResponsesMeasures.java b/core/src/main/java/hivemall/evaluation/BinaryResponsesMeasures.java
index 81cf075..7c21849 100644
--- a/core/src/main/java/hivemall/evaluation/BinaryResponsesMeasures.java
+++ b/core/src/main/java/hivemall/evaluation/BinaryResponsesMeasures.java
@@ -18,8 +18,12 @@
*/
package hivemall.evaluation;
+import hivemall.utils.lang.Preconditions;
+import hivemall.utils.math.MathUtils;
+
import java.util.List;
+import javax.annotation.Nonnegative;
import javax.annotation.Nonnull;
/**
@@ -40,19 +44,25 @@ public final class BinaryResponsesMeasures {
* @return nDCG
*/
public static double nDCG(@Nonnull final List<?> rankedList,
- @Nonnull final List<?> groundTruth, @Nonnull final int recommendSize) {
+ @Nonnull final List<?> groundTruth, @Nonnegative final int recommendSize) {
+ Preconditions.checkArgument(recommendSize > 0);
+
double dcg = 0.d;
- double idcg = IDCG(Math.min(recommendSize, groundTruth.size()));
- for (int i = 0, n = recommendSize; i < n; i++) {
+ final int k = Math.min(rankedList.size(), recommendSize);
+ for (int i = 0; i < k; i++) {
Object item_id = rankedList.get(i);
if (!groundTruth.contains(item_id)) {
continue;
}
int rank = i + 1;
- dcg += Math.log(2) / Math.log(rank + 1);
+ dcg += 1.d / MathUtils.log2(rank + 1);
}
+ final double idcg = IDCG(Math.min(groundTruth.size(), k));
+ if (idcg == 0.d) {
+ return 0.d;
+ }
return dcg / idcg;
}
@@ -62,10 +72,12 @@ public final class BinaryResponsesMeasures {
* @param n the number of positive items
* @return ideal DCG
*/
- public static double IDCG(final int n) {
+ public static double IDCG(@Nonnegative final int n) {
+ Preconditions.checkArgument(n >= 0);
+
double idcg = 0.d;
for (int i = 0; i < n; i++) {
- idcg += Math.log(2) / Math.log(i + 2);
+ idcg += 1.d / MathUtils.log2(i + 2);
}
return idcg;
}
@@ -79,8 +91,26 @@ public final class BinaryResponsesMeasures {
* @return Precision
*/
public static double Precision(@Nonnull final List<?> rankedList,
- @Nonnull final List<?> groundTruth, @Nonnull final int recommendSize) {
- return (double) countTruePositive(rankedList, groundTruth, recommendSize) / recommendSize;
+ @Nonnull final List<?> groundTruth, @Nonnegative final int recommendSize) {
+ if (rankedList.isEmpty()) {
+ if (groundTruth.isEmpty()) {
+ return 1.d;
+ }
+ return 0.d;
+ }
+
+ Preconditions.checkArgument(recommendSize > 0); // can be zero when groundTruth is empty
+
+ int nTruePositive = 0;
+ final int k = Math.min(rankedList.size(), recommendSize);
+ for (int i = 0; i < k; i++) {
+ Object item_id = rankedList.get(i);
+ if (groundTruth.contains(item_id)) {
+ nTruePositive++;
+ }
+ }
+
+ return ((double) nTruePositive) / k;
}
/**
@@ -92,8 +122,15 @@ public final class BinaryResponsesMeasures {
* @return Recall
*/
public static double Recall(@Nonnull final List<?> rankedList,
- @Nonnull final List<?> groundTruth, @Nonnull final int recommendSize) {
- return (double) countTruePositive(rankedList, groundTruth, recommendSize)
+ @Nonnull final List<?> groundTruth, @Nonnegative final int recommendSize) {
+ if (groundTruth.isEmpty()) {
+ if (rankedList.isEmpty()) {
+ return 1.d;
+ }
+ return 0.d;
+ }
+
+ return ((double) TruePositives(rankedList, groundTruth, recommendSize))
/ groundTruth.size();
}
@@ -105,11 +142,14 @@ public final class BinaryResponsesMeasures {
* @param recommendSize top-`recommendSize` items in `rankedList` are recommended
* @return number of true positives
*/
- public static int countTruePositive(final List<?> rankedList, final List<?> groundTruth,
- final int recommendSize) {
+ public static int TruePositives(final List<?> rankedList, final List<?> groundTruth,
+ @Nonnegative final int recommendSize) {
+ Preconditions.checkArgument(recommendSize > 0);
+
int nTruePositive = 0;
- for (int i = 0, n = recommendSize; i < n; i++) {
+ final int k = Math.min(rankedList.size(), recommendSize);
+ for (int i = 0; i < k; i++) {
Object item_id = rankedList.get(i);
if (groundTruth.contains(item_id)) {
nTruePositive++;
@@ -120,48 +160,65 @@ public final class BinaryResponsesMeasures {
}
/**
- * Computes Mean Reciprocal Rank (MRR)
+ * Computes Reciprocal Rank
*
* @param rankedList a list of ranked item IDs (first item is highest-ranked)
* @param groundTruth a collection of positive/correct item IDs
* @param recommendSize top-`recommendSize` items in `rankedList` are recommended
- * @return MRR
+ * @return Reciprocal Rank
+ * @see <a href="https://en.wikipedia.org/wiki/Mean_reciprocal_rank">Mean reciprocal rank (Wikipedia)</a>
*/
- public static double MRR(@Nonnull final List<?> rankedList, @Nonnull final List<?> groundTruth,
- @Nonnull final int recommendSize) {
- for (int i = 0, n = recommendSize; i < n; i++) {
+ public static double ReciprocalRank(@Nonnull final List<?> rankedList,
+ @Nonnull final List<?> groundTruth, @Nonnegative final int recommendSize) {
+ Preconditions.checkArgument(recommendSize > 0);
+
+ final int k = Math.min(rankedList.size(), recommendSize);
+ for (int i = 0; i < k; i++) {
Object item_id = rankedList.get(i);
if (groundTruth.contains(item_id)) {
- return 1.0 / (i + 1.0);
+ return 1.d / (i + 1);
}
}
- return 0.0;
+ return 0.d;
}
/**
- * Computes Mean Average Precision (MAP)
+ * Computes Average Precision (AP)
*
* @param rankedList a list of ranked item IDs (first item is highest-ranked)
* @param groundTruth a collection of positive/correct item IDs
* @param recommendSize top-`recommendSize` items in `rankedList` are recommended
- * @return MAP
+ * @return AveragePrecision
*/
- public static double MAP(@Nonnull final List<?> rankedList, @Nonnull final List<?> groundTruth,
- @Nonnull final int recommendSize) {
+ public static double AveragePrecision(@Nonnull final List<?> rankedList,
+ @Nonnull final List<?> groundTruth, @Nonnegative final int recommendSize) {
+ Preconditions.checkArgument(recommendSize > 0);
+
+ if (groundTruth.isEmpty()) {
+ if (rankedList.isEmpty()) {
+ return 1.d;
+ }
+ return 0.d;
+ }
+
int nTruePositive = 0;
- double sumPrecision = 0.0;
+ double sumPrecision = 0.d;
// accumulate precision@1 to @recommendSize
- for (int i = 0, n = recommendSize; i < n; i++) {
+ final int k = Math.min(rankedList.size(), recommendSize);
+ for (int i = 0; i < k; i++) {
Object item_id = rankedList.get(i);
if (groundTruth.contains(item_id)) {
nTruePositive++;
- sumPrecision += nTruePositive / (i + 1.0);
+ sumPrecision += nTruePositive / (i + 1.d);
}
}
- return sumPrecision / groundTruth.size();
+ if (nTruePositive == 0) {
+ return 0.d;
+ }
+ return sumPrecision / nTruePositive;
}
/**
@@ -173,11 +230,14 @@ public final class BinaryResponsesMeasures {
* @return AUC
*/
public static double AUC(@Nonnull final List<?> rankedList, @Nonnull final List<?> groundTruth,
- @Nonnull final int recommendSize) {
+ @Nonnegative final int recommendSize) {
+ Preconditions.checkArgument(recommendSize > 0);
+
int nTruePositive = 0, nCorrectPairs = 0;
// count # of pairs of items that are ranked in the correct order (i.e. TP > FP)
- for (int i = 0, n = recommendSize; i < n; i++) {
+ final int k = Math.min(rankedList.size(), recommendSize);
+ for (int i = 0; i < k; i++) {
Object item_id = rankedList.get(i);
if (groundTruth.contains(item_id)) {
// # of true positives which are ranked higher position than i-th recommended item
@@ -197,7 +257,7 @@ public final class BinaryResponsesMeasures {
}
// AUC can equivalently be calculated by counting the portion of correctly ordered pairs
- return (double) nCorrectPairs / nPairs;
+ return ((double) nCorrectPairs) / nPairs;
}
}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/c2b95783/core/src/main/java/hivemall/evaluation/MAPUDAF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/evaluation/MAPUDAF.java b/core/src/main/java/hivemall/evaluation/MAPUDAF.java
index cac6de5..3878684 100644
--- a/core/src/main/java/hivemall/evaluation/MAPUDAF.java
+++ b/core/src/main/java/hivemall/evaluation/MAPUDAF.java
@@ -235,7 +235,7 @@ public final class MAPUDAF extends AbstractGenericUDAFResolver {
void iterate(@Nonnull List<?> recommendList, @Nonnull List<?> truthList,
@Nonnull int recommendSize) {
- sum += BinaryResponsesMeasures.MAP(recommendList, truthList, recommendSize);
+ sum += BinaryResponsesMeasures.AveragePrecision(recommendList, truthList, recommendSize);
count++;
}
}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/c2b95783/core/src/main/java/hivemall/evaluation/MRRUDAF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/evaluation/MRRUDAF.java b/core/src/main/java/hivemall/evaluation/MRRUDAF.java
index 41a236d..f5aba3b 100644
--- a/core/src/main/java/hivemall/evaluation/MRRUDAF.java
+++ b/core/src/main/java/hivemall/evaluation/MRRUDAF.java
@@ -235,7 +235,7 @@ public final class MRRUDAF extends AbstractGenericUDAFResolver {
void iterate(@Nonnull List<?> recommendList, @Nonnull List<?> truthList,
@Nonnull int recommendSize) {
- sum += BinaryResponsesMeasures.MRR(recommendList, truthList, recommendSize);
+ sum += BinaryResponsesMeasures.ReciprocalRank(recommendList, truthList, recommendSize);
count++;
}
}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/c2b95783/core/src/main/java/hivemall/evaluation/NDCGUDAF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/evaluation/NDCGUDAF.java b/core/src/main/java/hivemall/evaluation/NDCGUDAF.java
index f50d27a..f1ba832 100644
--- a/core/src/main/java/hivemall/evaluation/NDCGUDAF.java
+++ b/core/src/main/java/hivemall/evaluation/NDCGUDAF.java
@@ -18,6 +18,8 @@
*/
package hivemall.evaluation;
+import static org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory.writableDoubleObjectInspector;
+import static org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory.writableLongObjectInspector;
import hivemall.utils.hadoop.HiveUtils;
import java.util.ArrayList;
@@ -38,10 +40,11 @@ import org.apache.hadoop.hive.serde2.io.DoubleWritable;
import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
+import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.StructField;
import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
-import org.apache.hadoop.hive.serde2.objectinspector.primitive.WritableDoubleObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.WritableIntObjectInspector;
import org.apache.hadoop.hive.serde2.typeinfo.ListTypeInfo;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
@@ -120,8 +123,8 @@ public final class NDCGUDAF extends AbstractGenericUDAFResolver {
}
private static StructObjectInspector internalMergeOI() {
- ArrayList<String> fieldNames = new ArrayList<String>();
- ArrayList<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>();
+ List<String> fieldNames = new ArrayList<>();
+ List<ObjectInspector> fieldOIs = new ArrayList<>();
fieldNames.add("sum");
fieldOIs.add(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector);
@@ -180,20 +183,31 @@ public final class NDCGUDAF extends AbstractGenericUDAFResolver {
StructObjectInspector sOI = (StructObjectInspector) recommendListOI.getListElementObjectInspector();
List<?> fieldRefList = sOI.getAllStructFieldRefs();
StructField relScoreField = (StructField) fieldRefList.get(0);
- WritableDoubleObjectInspector relScoreFieldOI = (WritableDoubleObjectInspector) relScoreField.getFieldObjectInspector();
+ PrimitiveObjectInspector relScoreFieldOI = HiveUtils.asDoubleCompatibleOI(relScoreField.getFieldObjectInspector());
for (int i = 0, n = recommendList.size(); i < n; i++) {
Object structObj = recommendList.get(i);
List<Object> fieldList = sOI.getStructFieldsDataAsList(structObj);
- double relScore = (double) relScoreFieldOI.get(fieldList.get(0));
+ Object field0 = fieldList.get(0);
+ if (field0 == null) {
+ throw new UDFArgumentException("Field 0 of a struct field is null: "
+ + fieldList);
+ }
+ double relScore = PrimitiveObjectInspectorUtils.getDouble(field0,
+ relScoreFieldOI);
recommendRelScoreList.add(relScore);
}
// Create a ordered list of relevance scores for truth items
List<Double> truthRelScoreList = new ArrayList<Double>();
- WritableDoubleObjectInspector truthRelScoreOI = (WritableDoubleObjectInspector) truthListOI.getListElementObjectInspector();
+ PrimitiveObjectInspector truthRelScoreOI = HiveUtils.asDoubleCompatibleOI(truthListOI.getListElementObjectInspector());
for (int i = 0, n = truthList.size(); i < n; i++) {
Object relScoreObj = truthList.get(i);
- double relScore = (double) truthRelScoreOI.get(relScoreObj);
+ if (relScoreObj == null) {
+ throw new UDFArgumentException("Found null in the ground truth: "
+ + truthList);
+ }
+ double relScore = PrimitiveObjectInspectorUtils.getDouble(relScoreObj,
+ truthRelScoreOI);
truthRelScoreList.add(relScore);
}
@@ -224,8 +238,8 @@ public final class NDCGUDAF extends AbstractGenericUDAFResolver {
Object sumObj = internalMergeOI.getStructFieldData(partial, sumField);
Object countObj = internalMergeOI.getStructFieldData(partial, countField);
- double sum = PrimitiveObjectInspectorFactory.writableDoubleObjectInspector.get(sumObj);
- long count = PrimitiveObjectInspectorFactory.writableLongObjectInspector.get(countObj);
+ double sum = writableDoubleObjectInspector.get(sumObj);
+ long count = writableLongObjectInspector.get(countObj);
NDCGAggregationBuffer myAggr = (NDCGAggregationBuffer) agg;
myAggr.merge(sum, count);
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/c2b95783/core/src/main/java/hivemall/tools/list/UDAFToOrderedList.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/tools/list/UDAFToOrderedList.java b/core/src/main/java/hivemall/tools/list/UDAFToOrderedList.java
index e88a16c..52c521c 100644
--- a/core/src/main/java/hivemall/tools/list/UDAFToOrderedList.java
+++ b/core/src/main/java/hivemall/tools/list/UDAFToOrderedList.java
@@ -207,7 +207,7 @@ public final class UDAFToOrderedList extends AbstractGenericUDAFResolver {
|| (argOIs.length == 3 && HiveUtils.isConstString(argOIs[2]));
if (sortByKey) {
- this.valueOI = HiveUtils.asPrimitiveObjectInspector(argOIs[0]);
+ this.valueOI = argOIs[0];
this.keyOI = HiveUtils.asPrimitiveObjectInspector(argOIs[1]);
} else {
// sort values by value itself
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/c2b95783/core/src/main/java/hivemall/utils/hadoop/HiveUtils.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/utils/hadoop/HiveUtils.java b/core/src/main/java/hivemall/utils/hadoop/HiveUtils.java
index 8fba349..b8b344c 100644
--- a/core/src/main/java/hivemall/utils/hadoop/HiveUtils.java
+++ b/core/src/main/java/hivemall/utils/hadoop/HiveUtils.java
@@ -21,6 +21,7 @@ package hivemall.utils.hadoop;
import static hivemall.HivemallConstants.BIGINT_TYPE_NAME;
import static hivemall.HivemallConstants.BINARY_TYPE_NAME;
import static hivemall.HivemallConstants.BOOLEAN_TYPE_NAME;
+import static hivemall.HivemallConstants.DECIMAL_TYPE_NAME;
import static hivemall.HivemallConstants.DOUBLE_TYPE_NAME;
import static hivemall.HivemallConstants.FLOAT_TYPE_NAME;
import static hivemall.HivemallConstants.INT_TYPE_NAME;
@@ -47,6 +48,7 @@ import org.apache.hadoop.hive.ql.udf.generic.GenericUDF.DeferredObject;
import org.apache.hadoop.hive.serde2.SerDeException;
import org.apache.hadoop.hive.serde2.io.ByteWritable;
import org.apache.hadoop.hive.serde2.io.DoubleWritable;
+import org.apache.hadoop.hive.serde2.io.HiveDecimalWritable;
import org.apache.hadoop.hive.serde2.io.ShortWritable;
import org.apache.hadoop.hive.serde2.lazy.ByteArrayRef;
import org.apache.hadoop.hive.serde2.lazy.LazyDouble;
@@ -265,6 +267,7 @@ public final class HiveUtils {
case LONG:
case FLOAT:
case DOUBLE:
+ case DECIMAL:
case BYTE:
//case TIMESTAMP:
return true;
@@ -357,6 +360,7 @@ public final class HiveUtils {
case LONG:
case FLOAT:
case DOUBLE:
+ case DECIMAL:
return true;
default:
return false;
@@ -404,6 +408,7 @@ public final class HiveUtils {
switch (((PrimitiveTypeInfo) typeInfo).getPrimitiveCategory()) {
case DOUBLE:
case FLOAT:
+ case DECIMAL:
return true;
default:
return false;
@@ -630,6 +635,9 @@ public final class HiveUtils {
} else if (TINYINT_TYPE_NAME.equals(typeName)) {
ByteWritable v = getConstValue(numberOI);
return v.get();
+ } else if (DECIMAL_TYPE_NAME.equals(typeName)) {
+ HiveDecimalWritable v = getConstValue(numberOI);
+ return v.getHiveDecimal().floatValue();
}
throw new UDFArgumentException("Unexpected argument type to cast as double: "
+ TypeInfoUtils.getTypeInfoFromObjectInspector(numberOI));
@@ -656,6 +664,9 @@ public final class HiveUtils {
} else if (TINYINT_TYPE_NAME.equals(typeName)) {
ByteWritable v = getConstValue(numberOI);
return v.get();
+ } else if (DECIMAL_TYPE_NAME.equals(typeName)) {
+ HiveDecimalWritable v = getConstValue(numberOI);
+ return v.getHiveDecimal().doubleValue();
}
throw new UDFArgumentException("Unexpected argument type to cast as double: "
+ TypeInfoUtils.getTypeInfoFromObjectInspector(numberOI));
@@ -923,10 +934,10 @@ public final class HiveUtils {
case LONG:
case FLOAT:
case DOUBLE:
+ case DECIMAL:
case BOOLEAN:
case BYTE:
case STRING:
- case DECIMAL:
break;
default:
throw new UDFArgumentTypeException(0, "Unxpected type '" + argOI.getTypeName()
@@ -951,9 +962,9 @@ public final class HiveUtils {
case BOOLEAN:
case FLOAT:
case DOUBLE:
+ case DECIMAL:
case STRING:
case TIMESTAMP:
- case DECIMAL:
break;
default:
throw new UDFArgumentTypeException(0, "Unxpected type '" + argOI.getTypeName()
@@ -998,6 +1009,7 @@ public final class HiveUtils {
case LONG:
case FLOAT:
case DOUBLE:
+ case DECIMAL:
case STRING:
case TIMESTAMP:
break;
@@ -1020,6 +1032,7 @@ public final class HiveUtils {
switch (oi.getPrimitiveCategory()) {
case FLOAT:
case DOUBLE:
+ case DECIMAL:
break;
default:
throw new UDFArgumentTypeException(0,
@@ -1044,6 +1057,7 @@ public final class HiveUtils {
case LONG:
case FLOAT:
case DOUBLE:
+ case DECIMAL:
break;
default:
throw new UDFArgumentTypeException(0,
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/c2b95783/core/src/main/java/hivemall/utils/math/MathUtils.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/utils/math/MathUtils.java b/core/src/main/java/hivemall/utils/math/MathUtils.java
index 6162adb..ee533dc 100644
--- a/core/src/main/java/hivemall/utils/math/MathUtils.java
+++ b/core/src/main/java/hivemall/utils/math/MathUtils.java
@@ -43,6 +43,7 @@ import javax.annotation.Nullable;
import org.apache.commons.math3.special.Gamma;
public final class MathUtils {
+ private static final double LOG2 = Math.log(2);
private MathUtils() {}
@@ -246,6 +247,10 @@ public final class MathUtils {
return Math.log(n) / Math.log(base);
}
+ public static double log2(final double n) {
+ return Math.log(n) / LOG2;
+ }
+
public static int floorDiv(final int x, final int y) {
int r = x / y;
// if the signs are different and modulo not zero, round down
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/c2b95783/core/src/test/java/hivemall/evaluation/BinaryResponsesMeasuresTest.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/hivemall/evaluation/BinaryResponsesMeasuresTest.java b/core/src/test/java/hivemall/evaluation/BinaryResponsesMeasuresTest.java
index 9f8a04e..5e8f253 100644
--- a/core/src/test/java/hivemall/evaluation/BinaryResponsesMeasuresTest.java
+++ b/core/src/test/java/hivemall/evaluation/BinaryResponsesMeasuresTest.java
@@ -18,8 +18,8 @@
*/
package hivemall.evaluation;
-import java.util.Collections;
import java.util.Arrays;
+import java.util.Collections;
import java.util.List;
import org.junit.Assert;
@@ -40,6 +40,18 @@ public class BinaryResponsesMeasuresTest {
}
@Test
+ public void testNDCG2() {
+ List<Integer> rankedList = Arrays.asList(3, 2, 1, 6);
+ List<Integer> groundTruth = Arrays.asList(1);
+
+ double actual = BinaryResponsesMeasures.nDCG(rankedList, groundTruth, 2);
+ Assert.assertEquals(0.d, actual, 0.0001d);
+
+ actual = BinaryResponsesMeasures.nDCG(rankedList, groundTruth, 3);
+ Assert.assertEquals(0.5d, actual, 0.0001d);
+ }
+
+ @Test
public void testRecall() {
List<Integer> rankedList = Arrays.asList(1, 3, 2, 6);
List<Integer> groundTruth = Arrays.asList(1, 2, 4);
@@ -52,6 +64,16 @@ public class BinaryResponsesMeasuresTest {
}
@Test
+ public void testRecallEmpty() {
+ Assert.assertEquals(1.d,
+ BinaryResponsesMeasures.Recall(Collections.emptyList(), Collections.emptyList(), 2),
+ 0.d);
+
+ Assert.assertEquals(0.d,
+ BinaryResponsesMeasures.Recall(Arrays.asList(1, 3, 2), Collections.emptyList(), 2), 0.d);
+ }
+
+ @Test
public void testPrecision() {
List<Integer> rankedList = Arrays.asList(1, 3, 2, 6);
List<Integer> groundTruth = Arrays.asList(1, 2, 4);
@@ -65,32 +87,91 @@ public class BinaryResponsesMeasuresTest {
}
@Test
- public void testMRR() {
+ public void testPrecisionEmpty() {
+ Assert.assertEquals(1.d,
+ BinaryResponsesMeasures.Precision(Collections.emptyList(), Collections.emptyList(), 2),
+ 0.d);
+
+ Assert.assertEquals(0.d,
+ BinaryResponsesMeasures.Precision(Arrays.asList(1, 3, 2), Collections.emptyList(), 2),
+ 0.d);
+ }
+
+ @Test
+ public void testRR() {
List<Integer> rankedList = Arrays.asList(1, 3, 2, 6);
List<Integer> groundTruth = Arrays.asList(1, 2, 4);
- double actual = BinaryResponsesMeasures.MRR(rankedList, groundTruth, rankedList.size());
+ double actual = BinaryResponsesMeasures.ReciprocalRank(rankedList, groundTruth,
+ rankedList.size());
Assert.assertEquals(1.0d, actual, 0.0001d);
Collections.reverse(rankedList);
- actual = BinaryResponsesMeasures.MRR(rankedList, groundTruth, rankedList.size());
+ actual = BinaryResponsesMeasures.ReciprocalRank(rankedList, groundTruth, rankedList.size());
Assert.assertEquals(0.5d, actual, 0.0001d);
- actual = BinaryResponsesMeasures.MRR(rankedList, groundTruth, 1);
+ actual = BinaryResponsesMeasures.ReciprocalRank(rankedList, groundTruth, 1);
Assert.assertEquals(0.0d, actual, 0.0001d);
}
@Test
- public void testMAP() {
+ public void testAP() {
List<Integer> rankedList = Arrays.asList(1, 3, 2, 6);
List<Integer> groundTruth = Arrays.asList(1, 2, 4);
- double actual = BinaryResponsesMeasures.MAP(rankedList, groundTruth, rankedList.size());
- Assert.assertEquals(0.5555555555555555d, actual, 0.0001d);
+ double actual = BinaryResponsesMeasures.AveragePrecision(rankedList, groundTruth,
+ rankedList.size());
+ Assert.assertEquals(1.0 / 2.0 * (1.0 / 1.0 + 2.0 / 3.0), actual, 0.0001d);
+
+ actual = BinaryResponsesMeasures.AveragePrecision(rankedList, groundTruth, 4);
+ Assert.assertEquals(1.0 / 2.0 * (1.0 / 1.0 + 2.0 / 3.0), actual, 0.0001d);
- actual = BinaryResponsesMeasures.MAP(rankedList, groundTruth, 2);
- Assert.assertEquals(0.3333333333333333d, actual, 0.0001d);
+ actual = BinaryResponsesMeasures.AveragePrecision(rankedList, groundTruth, 3);
+ Assert.assertEquals(1.0 / 2.0 * (1.0 / 1.0 + 2.0 / 3.0), actual, 0.0001d);
+
+ actual = BinaryResponsesMeasures.AveragePrecision(rankedList, groundTruth, 2);
+ Assert.assertEquals(1.0 / 1.0 * (1.0 / 1.0), actual, 0.0001d);
+
+ rankedList = Arrays.asList(3, 1, 2, 6);
+ actual = BinaryResponsesMeasures.AveragePrecision(rankedList, groundTruth, 2);
+ Assert.assertEquals(1.0 / 1.0 * (1.0 / 2.0), actual, 0.0001d);
+
+ groundTruth = Arrays.asList(1, 2, 3);
+ actual = BinaryResponsesMeasures.AveragePrecision(rankedList, groundTruth, 2);
+ Assert.assertEquals(1.0 / 2.0 * (1.0 / 1.0 + 2.0 / 2.0), actual, 0.0001d);
+
+ rankedList = Arrays.asList(3, 1);
+ groundTruth = Arrays.asList(1, 2);
+ actual = BinaryResponsesMeasures.AveragePrecision(rankedList, groundTruth, 2);
+ Assert.assertEquals(1.0 / 1.0 * (1.0 / 2.0), actual, 0.0001d);
+ }
+
+ @Test
+ public void testAPString() {
+ List<String> rankedList = Arrays.asList("a", "b", "c", "d", "e", "f", "g");
+ List<String> groundTruth = Arrays.asList("a", "x", "x", "d", "x", "x");
+
+ double actual = BinaryResponsesMeasures.AveragePrecision(rankedList, groundTruth, 6);
+ Assert.assertEquals(0.75d, actual, 0.0001d);
+ }
+
+ @Test
+ public void testAPString10() {
+ List<String> rankedList = Arrays.asList("a", "b", "c", "d", "e", "f", "g", "h", "i", "j");
+ List<String> groundTruth = Arrays.asList("a", "x", "c", "x", "e", "f");
+
+ double actual = BinaryResponsesMeasures.AveragePrecision(rankedList, groundTruth, 10);
+ Assert.assertEquals(1.0 / 4.0 * (1.0 / 1.0 + 2.0 / 3.0 + 3.0 / 5.0 + 4.0 / 6.0), actual,
+ 0.0001d);
+
+ actual = BinaryResponsesMeasures.AveragePrecision(rankedList, groundTruth, 5);
+ Assert.assertEquals(1.0 / 3.0 * (1.0 / 1.0 + 2.0 / 3.0 + 3.0 / 5.0), actual, 0.0001d);
+
+ groundTruth = Arrays.asList("a", "x", "c", "x", "e", "f", "x", "x", "x", "x");
+ actual = BinaryResponsesMeasures.AveragePrecision(rankedList, groundTruth, 10);
+ Assert.assertEquals(1.0 / 4.0 * (1.0 / 1.0 + 2.0 / 3.0 + 3.0 / 5.0 + 4.0 / 6.0), actual,
+ 0.0001d);
}
@Test
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/c2b95783/docs/gitbook/eval/rank.md
----------------------------------------------------------------------
diff --git a/docs/gitbook/eval/rank.md b/docs/gitbook/eval/rank.md
index 207418e..30d82e5 100644
--- a/docs/gitbook/eval/rank.md
+++ b/docs/gitbook/eval/rank.md
@@ -83,7 +83,8 @@ with truth as (
rec as (
select
userid,
- map_values(to_ordered_map(score, itemid, true)) as rec,
+ -- map_values(to_ordered_map(score, itemid, true)) as rec,
+ to_ordered_list(itemid, score, '-reverse') as rec,
cast(count(itemid) as int) as max_k
from dummy_rec
group by userid
@@ -222,7 +223,7 @@ While the binary response setting simply considers positive-only ranked list of
Unlike separated `dummy_truth` and `dummy_rec` table in the binary setting, we assume the following single table named `dummy_recrel` which contains item-$$\mathrm{rel}_n$$ pairs:
-| userid | itemid | score<br/>(predicted) | rel<br/>(expected) |
+| userid | itemid | score<br/>(predicted) | relscore<br/>(expected) |
| :-: | :-: | :-: | :-: |
| 1 | 1 | 10.0 | 5.0 |
| 1 | 3 | 8.0 | 2.0 |
@@ -244,27 +245,31 @@ The function `ndcg()` can take non-binary `truth` values as the second argument:
```sql
with truth as (
- select userid, map_keys(to_ordered_map(relscore, itemid, true)) as truth
- from dummy_recrel
- group by userid
+ select
+ userid,
+ to_ordered_list(relscore, '-reverse') as truth
+ from
+ dummy_recrel
+ group by
+ userid
),
rec as (
select
userid,
- map_values (
- to_ordered_map(score, struct(relscore, itemid), true)
- ) as rec,
- cast(count(itemid) as int) as max_k
- from dummy_recrel
- group by userid
+ to_ordered_list(struct(relscore, itemid), score, "-reverse") as rec,
+ count(itemid) as max_k
+ from
+ dummy_recrel
+ group by
+ userid
)
select
-- top-2 recommendation
ndcg(t1.rec, t2.truth, 2), -- => 0.8128912838590544
-
-- top-3 recommendation
ndcg(t1.rec, t2.truth, 3) -- => 0.9187707805346093
-from rec t1
-join truth t2 on (t1.userid = t2.userid)
+from
+ rec t1
+ join truth t2 on (t1.userid = t2.userid)
;
```