You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@hivemall.apache.org by my...@apache.org on 2018/11/02 10:35:24 UTC
incubator-hivemall git commit: [HIVEMALL-196] Support BM25 scoring
Repository: incubator-hivemall
Updated Branches:
refs/heads/master 4f795cb9a -> ce70aa482
[HIVEMALL-196] Support BM25 scoring
## What changes were proposed in this pull request?
Adding scoring function Okapi BM25 as a UDF
## What type of PR is it?
Feature
## What is the Jira issue?
https://issues.apache.org/jira/projects/HIVEMALL/issues/HIVEMALL-196
## How was this patch tested?
1. Unit testing
2. Manual testing on Hive
## How to use this feature?
This new function (proposed as `okapi_bm25`; registered as `bm25` by the DDL files in this commit) takes 5 mandatory arguments and an optional constant options string:
1. raw frequency count of a term in a given document
2. length of the given document
3. average length of a document in the corpus
4. number of documents in the corpus
5. number of documents containing the term, i.e. document frequency
6. (*optional*) a constant options string such as `'-k1 1.5 -b 0.75'` that sets the hyperparameters k1 (non-negative, default 1.2) and b (in [0.0, 1.0], default 0.75)
### Step 1: Count frequency of terms
```sql
-- Raw term frequency: how many times each word occurs in each document.
create or replace view frequency as
select
    docid,
    word,
    count(*) as freq
from test_corpus_exploded
group by docid, word;
```
### Step 2: Calculate document lengths
```sql
-- Document length: number of (exploded) words in each document.
create or replace view doc_len as
select
    docid,
    count(1) as cnt
from test_corpus_exploded
group by docid;
```
### Step 3: Calculate document frequency
```sql
-- Document frequency: number of distinct documents that contain each word.
create or replace view document_frequency as
select
    word,
    count(distinct docid) as docs
from test_corpus_exploded
group by word;
```
### Step 4: Set number of documents
```sql
-- Number of documents in the corpus, passed to the UDF as numDocs (argument 4).
-- NOTE(review): must be kept in sync with the number of distinct docids in
-- test_corpus_exploded (3 in this example).
set hivevar:n_docs=3;
```
### Step 5: Use `okapi_bm25`
```sql
-- Okapi BM25 score of every (docid, word) pair.
-- The UDF is registered as `bm25` by the DDL files; `okapi_bm25` does not exist.
-- Sort at query time, e.g. `select * from bm25 order by score desc`, rather than
-- forcing a global sort every time the view is read.
create or replace view bm25 as
with corpus_stats as (
    -- Average document length and corpus size, both derived from doc_len so
    -- that numDocs can never disagree with the data actually being scored.
    select
        avg(cnt) as avgdl,
        count(1) as n_docs
    from doc_len
)
select
    f.docid,
    f.word,
    bm25(
        cast(f.freq as int),
        dl.cnt,
        cast(cs.avgdl as double),
        cs.n_docs,
        df.docs,
        '-k1 1.5 -b 0.75'
    ) as score
from frequency f
join document_frequency df on (f.word = df.word)
join doc_len dl on (f.docid = dl.docid)
cross join corpus_stats cs;
```
## Checklist
(Please remove this section if not needed; check `x` for YES, blank for NO)
- [x] Did you apply source code formatter, i.e., `./bin/format_code.sh`, for your commit?
- [x] Did you run system tests on Hive (or Spark)?
Author: Jackson Huang <hu...@treasure-data.com>
Author: Makoto Yui <my...@apache.org>
Closes #163 from jaxony/feature/bm25.
Project: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/commit/ce70aa48
Tree: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/tree/ce70aa48
Diff: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/diff/ce70aa48
Branch: refs/heads/master
Commit: ce70aa482c766f4d45160c850d28794c39028059
Parents: 4f795cb
Author: Jackson Huang <hu...@treasure-data.com>
Authored: Fri Nov 2 19:35:13 2018 +0900
Committer: Makoto Yui <my...@apache.org>
Committed: Fri Nov 2 19:35:13 2018 +0900
----------------------------------------------------------------------
core/src/main/java/hivemall/UDFWithOptions.java | 24 +++
.../java/hivemall/ftvec/text/OkapiBM25UDF.java | 172 ++++++++++++++++
.../java/hivemall/utils/lang/Primitives.java | 4 +
.../hivemall/ftvec/text/OkapiBM25UDFTest.java | 193 ++++++++++++++++++
docs/gitbook/SUMMARY.md | 1 +
docs/gitbook/ft_engineering/bm25.md | 197 +++++++++++++++++++
docs/gitbook/misc/funcs.md | 2 +
resources/ddl/define-all-as-permanent.hive | 3 +
resources/ddl/define-all.hive | 4 +
resources/ddl/define-all.spark | 4 +
10 files changed, 604 insertions(+)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/ce70aa48/core/src/main/java/hivemall/UDFWithOptions.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/UDFWithOptions.java b/core/src/main/java/hivemall/UDFWithOptions.java
index 9908cd9..04d6fdc 100644
--- a/core/src/main/java/hivemall/UDFWithOptions.java
+++ b/core/src/main/java/hivemall/UDFWithOptions.java
@@ -112,6 +112,30 @@ public abstract class UDFWithOptions extends GenericUDF {
return cl;
}
+    /**
+     * Argument guard: succeeds silently when {@code condition} holds.
+     *
+     * @param condition the condition that is expected to be true
+     * @param errMsg message of the exception raised otherwise
+     * @throws UDFArgumentException if {@code condition} is false
+     */
+    protected static void assumeTrue(final boolean condition, @Nonnull final String errMsg)
+            throws UDFArgumentException {
+        if (condition) {
+            return;
+        }
+        throw new UDFArgumentException(errMsg);
+    }
+
+    /**
+     * Argument guard: succeeds silently when {@code condition} does not hold.
+     *
+     * @param condition the condition that is expected to be false
+     * @param errMsg message of the exception raised otherwise
+     * @throws UDFArgumentException if {@code condition} is true
+     */
+    protected static void assumeFalse(final boolean condition, @Nonnull final String errMsg)
+            throws UDFArgumentException {
+        assumeTrue(!condition, errMsg);
+    }
+
@Nonnull
protected abstract CommandLine processOptions(@Nonnull String optionValue)
throws UDFArgumentException;
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/ce70aa48/core/src/main/java/hivemall/ftvec/text/OkapiBM25UDF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/ftvec/text/OkapiBM25UDF.java b/core/src/main/java/hivemall/ftvec/text/OkapiBM25UDF.java
new file mode 100644
index 0000000..cd36d6f
--- /dev/null
+++ b/core/src/main/java/hivemall/ftvec/text/OkapiBM25UDF.java
@@ -0,0 +1,172 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.ftvec.text;
+
+import hivemall.UDFWithOptions;
+import hivemall.utils.hadoop.HiveUtils;
+import hivemall.utils.lang.Primitives;
+import hivemall.utils.lang.StringUtils;
+
+import javax.annotation.Nonnull;
+
+import org.apache.commons.cli.CommandLine;
+import org.apache.commons.cli.Options;
+import org.apache.hadoop.hive.ql.exec.Description;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
+import org.apache.hadoop.hive.ql.metadata.HiveException;
+import org.apache.hadoop.hive.ql.udf.UDFType;
+import org.apache.hadoop.hive.serde2.io.DoubleWritable;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
+
+@Description(name = "bm25",
+        value = "_FUNC_(double termFrequency, int docLength, double avgDocLength, int numDocs, int numDocsWithTerm [, const string options]) - Return an Okapi BM25 score in double")
+@UDFType(deterministic = true, stateful = false)
+public final class OkapiBM25UDF extends UDFWithOptions {
+
+    // k1 in tf * (k1 + 1) / (tf + k1 * norm); must be finite and non-negative
+    private double k1 = 1.2d;
+    // weight of document-length normalization in norm = 1 - b + b * |D| / avgdl; in [0, 1]
+    private double b = 0.75d;
+
+    // BM25+ https://en.wikipedia.org/wiki/Okapi_BM25#General_references
+    private double delta = 0.d;
+
+    // epsilon in https://en.wikipedia.org/wiki/Okapi_BM25#The_ranking_function
+    // used as a lower bound of the IDF factor in bm25()
+    private double minIDF = 1e-8;
+
+    private PrimitiveObjectInspector frequencyOI;
+    private PrimitiveObjectInspector docLengthOI;
+    private PrimitiveObjectInspector averageDocLengthOI;
+    private PrimitiveObjectInspector numDocsOI;
+    private PrimitiveObjectInspector numDocsWithTermOI;
+
+    // reused as the return value of every evaluate() call
+    @Nonnull
+    private final DoubleWritable result = new DoubleWritable();
+
+    public OkapiBM25UDF() {}
+
+    @Override
+    protected Options getOptions() {
+        Options opts = new Options();
+        opts.addOption("k1", true,
+            "Hyperparameter with type double, usually in range 1.2 and 2.0 [default: 1.2]");
+        opts.addOption("b", true,
+            "Hyperparameter with type double in range 0.0 and 1.0 [default: 0.75]");
+        opts.addOption("d", "delta", true, "Hyperparameter delta of BM25+ [default: 0.0]");
+        opts.addOption("min_idf", "epsilon", true,
+            "Lower bound (epsilon) of the IDF factor [default: 1e-8]");
+        return opts;
+    }
+
+    /**
+     * Parses the options string and validates every hyperparameter.
+     *
+     * @throws UDFArgumentException if a hyperparameter is outside its valid range
+     */
+    @Override
+    protected CommandLine processOptions(@Nonnull String opts) throws UDFArgumentException {
+        CommandLine cl = parseOptions(opts);
+
+        this.k1 = Primitives.parseDouble(cl.getOptionValue("k1"), k1);
+        if (Primitives.isFinite(k1) == false || k1 < 0.0) {
+            throw new UDFArgumentException("k1 must be a non-negative finite value: " + k1);
+        }
+
+        this.b = Primitives.parseDouble(cl.getOptionValue("b"), b);
+        if (Double.isNaN(b) || b < 0.0 || b > 1.0) {
+            throw new UDFArgumentException(
+                "b hyperparameter must be in the range [0.0, 1.0]: " + b);
+        }
+
+        this.delta = Primitives.parseDouble(cl.getOptionValue("delta"), delta);
+        if (Primitives.isFinite(delta) == false) {
+            throw new UDFArgumentException("Delta must be a finite value: " + delta);
+        }
+
+        this.minIDF = Primitives.parseDouble(cl.getOptionValue("min_idf"), minIDF);
+        // NaN passes a plain "< 0" check and Math.max(NaN, x) is NaN, which would poison
+        // every score, so require a finite value as well.
+        if (Primitives.isFinite(minIDF) == false || minIDF < 0.d) {
+            throw new UDFArgumentException(
+                "min_idf must be a non-negative finite value: " + minIDF);
+        }
+
+        return cl;
+    }
+
+    /**
+     * Expects 5 numeric arguments, optionally followed by a constant options string.
+     *
+     * @throws UDFArgumentException on a wrong argument count, argument type or option value
+     */
+    @Override
+    public ObjectInspector initialize(@Nonnull ObjectInspector[] argOIs)
+            throws UDFArgumentException {
+        final int numArgOIs = argOIs.length;
+        if (numArgOIs < 5 || numArgOIs > 6) {
+            throw new UDFArgumentException(
+                "bm25 takes 5 arguments and an optional options string, but got "
+                        + numArgOIs + " arguments");
+        }
+        if (numArgOIs == 6) {
+            String opts = HiveUtils.getConstString(argOIs[5]);
+            processOptions(opts);
+        }
+
+        this.frequencyOI = HiveUtils.asDoubleCompatibleOI(argOIs[0]);
+        this.docLengthOI = HiveUtils.asIntegerOI(argOIs[1]);
+        this.averageDocLengthOI = HiveUtils.asDoubleCompatibleOI(argOIs[2]);
+        this.numDocsOI = HiveUtils.asIntegerOI(argOIs[3]);
+        this.numDocsWithTermOI = HiveUtils.asIntegerOI(argOIs[4]);
+
+        return PrimitiveObjectInspectorFactory.writableDoubleObjectInspector;
+    }
+
+    /**
+     * Computes the BM25 score of one (term, document) pair.
+     *
+     * @throws UDFArgumentException if a required argument is null or out of range
+     */
+    @Override
+    public DoubleWritable evaluate(@Nonnull DeferredObject[] arguments) throws HiveException {
+        Object arg0 = arguments[0].get();
+        Object arg1 = arguments[1].get();
+        Object arg2 = arguments[2].get();
+        Object arg3 = arguments[3].get();
+        Object arg4 = arguments[4].get();
+
+        if (arg0 == null || arg1 == null || arg2 == null || arg3 == null || arg4 == null) {
+            throw new UDFArgumentException("Required arguments cannot be null");
+        }
+
+        double frequency = PrimitiveObjectInspectorUtils.getDouble(arg0, frequencyOI);
+        int docLength = PrimitiveObjectInspectorUtils.getInt(arg1, docLengthOI);
+        double averageDocLength = PrimitiveObjectInspectorUtils.getDouble(arg2, averageDocLengthOI);
+        int numDocs = PrimitiveObjectInspectorUtils.getInt(arg3, numDocsOI);
+        int numDocsWithTerm = PrimitiveObjectInspectorUtils.getInt(arg4, numDocsWithTermOI);
+
+        assumeFalse(frequency < 0, "#frequency must be non-negative");
+        assumeFalse(docLength < 1, "#docLength must be greater than or equal to 1");
+        assumeFalse(averageDocLength <= 0.0, "#averageDocLength must be positive");
+        assumeFalse(numDocs < 1, "#numDocs must be greater than or equal to 1");
+        assumeFalse(numDocsWithTerm < 1, "#numDocsWithTerm must be greater than or equal to 1");
+
+        double v = bm25(frequency, docLength, averageDocLength, numDocs, numDocsWithTerm);
+        result.set(v);
+        return result;
+    }
+
+    /**
+     * IDF * (saturated tf + delta), with the IDF factor bounded below by {@code minIDF}.
+     */
+    private double bm25(final double tf, final int docLength, final double averageDocLength,
+            final int numDocs, final int numDocsWithTerm) {
+        final double idf = Math.max(minIDF, idf(numDocs, numDocsWithTerm));
+        final double tfComponent;
+        if (tf == 0.d) {
+            // an absent term contributes nothing; also avoids 0/0 = NaN when k1 = 0
+            tfComponent = 0.d;
+        } else {
+            final double numerator = tf * (k1 + 1);
+            final double denominator = tf + k1 * (1 - b + b * docLength / averageDocLength);
+            tfComponent = numerator / denominator;
+        }
+        return idf * (tfComponent + delta);
+    }
+
+    /**
+     * BM25 inverse document frequency log10(1 + (N - n + 0.5) / (n + 0.5)).
+     */
+    private static double idf(final int numDocs, final int numDocsWithTerm) {
+        return Math.log10(1.0d + (numDocs - numDocsWithTerm + 0.5d) / (numDocsWithTerm + 0.5d));
+    }
+
+    @Override
+    public String getDisplayString(String[] children) {
+        return "bm25(" + StringUtils.join(children, ',') + ")";
+    }
+
+}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/ce70aa48/core/src/main/java/hivemall/utils/lang/Primitives.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/utils/lang/Primitives.java b/core/src/main/java/hivemall/utils/lang/Primitives.java
index ab3be9a..1c05102 100644
--- a/core/src/main/java/hivemall/utils/lang/Primitives.java
+++ b/core/src/main/java/hivemall/utils/lang/Primitives.java
@@ -78,6 +78,10 @@ public final class Primitives {
return v.doubleValue();
}
+    /**
+     * Tells whether {@code value} is an ordinary number, i.e. neither NaN nor an infinity.
+     */
+    public static boolean isFinite(final double value) {
+        return !(Double.isNaN(value) || Double.isInfinite(value));
+    }
+
public static int compare(final int x, final int y) {
return (x < y) ? -1 : ((x == y) ? 0 : 1);
}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/ce70aa48/core/src/test/java/hivemall/ftvec/text/OkapiBM25UDFTest.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/hivemall/ftvec/text/OkapiBM25UDFTest.java b/core/src/test/java/hivemall/ftvec/text/OkapiBM25UDFTest.java
new file mode 100644
index 0000000..8ddba23
--- /dev/null
+++ b/core/src/test/java/hivemall/ftvec/text/OkapiBM25UDFTest.java
@@ -0,0 +1,193 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.ftvec.text;
+
+import static org.junit.Assert.assertEquals;
+
+import hivemall.utils.hadoop.HiveUtils;
+import hivemall.utils.hadoop.WritableUtils;
+
+import org.apache.hadoop.hive.ql.metadata.HiveException;
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDF;
+import org.apache.hadoop.hive.serde2.io.DoubleWritable;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
+import org.junit.Before;
+import org.junit.Test;
+
+public class OkapiBM25UDFTest {
+
+    private static final double EPSILON = 1e-8;
+
+    // Valid arguments: tf = 3, |D| = 9, avgdl = 10.35, N = 20, n(q) = 5
+    private static final GenericUDF.DeferredJavaObject VALID_TERM_FREQ =
+            new GenericUDF.DeferredJavaObject(Integer.valueOf(3));
+    private static final GenericUDF.DeferredJavaObject VALID_DOC_LEN =
+            new GenericUDF.DeferredJavaObject(Integer.valueOf(9));
+    private static final GenericUDF.DeferredJavaObject VALID_AVG_DOC_LEN =
+            new GenericUDF.DeferredJavaObject(Double.valueOf(10.35));
+    private static final GenericUDF.DeferredJavaObject VALID_NUM_DOCS =
+            new GenericUDF.DeferredJavaObject(Integer.valueOf(20));
+    private static final GenericUDF.DeferredJavaObject VALID_NUM_DOCS_WITH_TERM =
+            new GenericUDF.DeferredJavaObject(Integer.valueOf(5));
+
+    private OkapiBM25UDF udf = null;
+
+    @Before
+    public void init() throws Exception {
+        udf = new OkapiBM25UDF();
+    }
+
+    @Test
+    public void testEvaluate() throws Exception {
+        initializeUDFWithoutOptions();
+
+        DoubleWritable expected = WritableUtils.val(0.940637195691);
+        DoubleWritable actual = udf.evaluate(validArgs());
+        assertEquals(expected.get(), actual.get(), EPSILON);
+    }
+
+    @Test
+    public void testEvaluateWithCustomK1() throws Exception {
+        initializeUDF("-k1 1.5");
+
+        DoubleWritable expected = WritableUtils.val(1.00244958206);
+        DoubleWritable actual = udf.evaluate(validArgs());
+        assertEquals(expected.get(), actual.get(), EPSILON);
+    }
+
+    @Test
+    public void testEvaluateWithCustomB() throws Exception {
+        initializeUDF("-b 0.8");
+
+        DoubleWritable expected = WritableUtils.val(0.942443797219);
+        DoubleWritable actual = udf.evaluate(validArgs());
+        assertEquals(expected.get(), actual.get(), EPSILON);
+    }
+
+    @Test
+    public void testZeroTermFrequency() throws Exception {
+        initializeUDFWithoutOptions();
+
+        // with the default delta = 0, a term that does not occur scores exactly 0
+        DoubleWritable actual = udf.evaluate(argsWith(0, Integer.valueOf(0)));
+        assertEquals(0.d, actual.get(), EPSILON);
+    }
+
+    @Test(expected = HiveException.class)
+    public void testBOutOfRange() throws Exception {
+        initializeUDF("-b 1.5");
+    }
+
+    @Test(expected = HiveException.class)
+    public void testInputArgIsNull() throws Exception {
+        initializeUDFWithoutOptions();
+        udf.evaluate(argsWith(0, null));
+    }
+
+    @Test(expected = HiveException.class)
+    public void testTermFrequencyIsNegative() throws Exception {
+        initializeUDFWithoutOptions();
+        udf.evaluate(argsWith(0, Integer.valueOf(-1)));
+    }
+
+    @Test(expected = HiveException.class)
+    public void testDocLengthIsLessThanOne() throws Exception {
+        initializeUDFWithoutOptions();
+        udf.evaluate(argsWith(1, Integer.valueOf(0)));
+    }
+
+    @Test(expected = HiveException.class)
+    public void testAvgDocLengthIsNegative() throws Exception {
+        initializeUDFWithoutOptions();
+        udf.evaluate(argsWith(2, Double.valueOf(-10.0)));
+    }
+
+    @Test(expected = HiveException.class)
+    public void testAvgDocLengthIsZero() throws Exception {
+        initializeUDFWithoutOptions();
+        udf.evaluate(argsWith(2, Double.valueOf(0.0)));
+    }
+
+    @Test(expected = HiveException.class)
+    public void testNumDocsIsLessThanOne() throws Exception {
+        initializeUDFWithoutOptions();
+        udf.evaluate(argsWith(3, Integer.valueOf(0)));
+    }
+
+    @Test(expected = HiveException.class)
+    public void testNumDocsWithTermIsLessThanOne() throws Exception {
+        initializeUDFWithoutOptions();
+        udf.evaluate(argsWith(4, Integer.valueOf(0)));
+    }
+
+    /** Returns a fresh array of the five valid arguments, in the order bm25() expects. */
+    private static GenericUDF.DeferredObject[] validArgs() {
+        return new GenericUDF.DeferredObject[] {VALID_TERM_FREQ, VALID_DOC_LEN,
+                VALID_AVG_DOC_LEN, VALID_NUM_DOCS, VALID_NUM_DOCS_WITH_TERM};
+    }
+
+    /** Returns the valid arguments with the one at {@code index} replaced by {@code value}. */
+    private static GenericUDF.DeferredObject[] argsWith(final int index, final Object value) {
+        final GenericUDF.DeferredObject[] args = validArgs();
+        args[index] = new GenericUDF.DeferredJavaObject(value);
+        return args;
+    }
+
+    private void initializeUDFWithoutOptions() throws Exception {
+        initializeUDF(null);
+    }
+
+    /**
+     * Initializes the UDF with (int, int, double, int, int) argument types, followed by a
+     * constant options string when {@code options} is not null.
+     */
+    private void initializeUDF(final String options) throws Exception {
+        final ObjectInspector intOI = PrimitiveObjectInspectorFactory.javaIntObjectInspector;
+        final ObjectInspector doubleOI = PrimitiveObjectInspectorFactory.javaDoubleObjectInspector;
+        if (options == null) {
+            udf.initialize(new ObjectInspector[] {intOI, intOI, doubleOI, intOI, intOI});
+        } else {
+            udf.initialize(new ObjectInspector[] {intOI, intOI, doubleOI, intOI, intOI,
+                    HiveUtils.getConstStringObjectInspector(options)});
+        }
+    }
+}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/ce70aa48/docs/gitbook/SUMMARY.md
----------------------------------------------------------------------
diff --git a/docs/gitbook/SUMMARY.md b/docs/gitbook/SUMMARY.md
index 6c69848..3484bfb 100644
--- a/docs/gitbook/SUMMARY.md
+++ b/docs/gitbook/SUMMARY.md
@@ -66,6 +66,7 @@
* [Feature vectorization](ft_engineering/vectorization.md)
* [Quantify non-number features](ft_engineering/quantify.md)
* [TF-IDF Calculation](ft_engineering/tfidf.md)
+* [BM25](ft_engineering/bm25.md)
## Part IV - Evaluation
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/ce70aa48/docs/gitbook/ft_engineering/bm25.md
----------------------------------------------------------------------
diff --git a/docs/gitbook/ft_engineering/bm25.md b/docs/gitbook/ft_engineering/bm25.md
new file mode 100644
index 0000000..4ca029f
--- /dev/null
+++ b/docs/gitbook/ft_engineering/bm25.md
@@ -0,0 +1,197 @@
+<!--
+ Licensed to the Apache Software Foundation (ASF) under one
+ or more contributor license agreements. See the NOTICE file
+ distributed with this work for additional information
+ regarding copyright ownership. The ASF licenses this file
+ to you under the Apache License, Version 2.0 (the
+ "License"); you may not use this file except in compliance
+ with the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing,
+ software distributed under the License is distributed on an
+ "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ KIND, either express or implied. See the License for the
+ specific language governing permissions and limitations
+ under the License.
+-->
+
+[Okapi BM25](https://en.wikipedia.org/wiki/Okapi_BM25) is a function that ranks documents by their relevance to a given query.
+
+It can also be used as a better alternative to [TF-IDF](https://en.wikipedia.org/wiki/Tf%E2%80%93idf), i.e., to compute a weight for each term in a document.
+
+<!-- toc -->
+
+# The ranking function
+
+Given a query $$Q$$ containing keywords $$q_{1}, \ldots, q_{n}$$, the BM25 score of a document $$D$$ is:
+
+$$
+score(Q, D) = \sum_{i=1}^{n}IDF(q_{i}) \cdot \frac{tf(q_{i},D) \cdot (k_{1}+1)}{tf(q_{i},D) + k_{1} \cdot (1 - b + b \cdot \frac{|D|}{avgdl})}
+$$
+
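+where $$tf(q_{i}, D)$$ is $$q_{i}$$'s term frequency in the document $$D$$, $$|D|$$ is the length of the document $$D$$ in words, and $$avgdl$$ is the average document length in the text collection from which documents are drawn. $$k_{1}$$ and $$b$$ are free parameters, usually chosen, in the absence of an advanced optimization, as $$k_{1} \in [1.2,2.0]$$ and $$b = 0.75$$ (Hivemall's defaults are $$k_{1} = 1.2$$ and $$b = 0.75$$).
+
+$$IDF(q_{i})$$ is the inverse document frequency of $$q_{i}$$. Hivemall computes it as
+
+$$
+IDF(q_{i}) = \max\left(\epsilon, \log_{10}\left(1 + \frac{N - n(q_{i}) + 0.5}{n(q_{i}) + 0.5}\right)\right)
+$$
+
+where $$N$$ is the total number of documents, $$n(q_{i})$$ is the number of documents containing $$q_{i}$$, and $$\epsilon$$ is a small lower bound (option `-min_idf`, default `1e-8`).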
+
+BM25 can also be applied to term weighting, i.e., to measure how important a word is to a document in a collection or corpus, as follows:
+
+$$
+score(t_{i}, D) = IDF(t_{i}) \cdot \frac{tf(t_{i},D) \cdot (k_{1}+1)}{tf(t_{i},D) + k_{1} \cdot (1 - b + b \cdot \frac{|D|}{avgdl})}
+$$
+
+where $$t_{i}$$ is a term that appears in document $$D$$.
+
+# Data preparation
+
+Similar to [TF-IDF](./tfidf.md), you need to prepare a relation consisting of `(docid, word)` tuples to compute BM25 scores.
+
+```sql
+create external table wikipage (
+  docid int,
+  page string
+)
+ROW FORMAT DELIMITED FIELDS TERMINATED BY '|'
+STORED AS TEXTFILE;
+
+-- Download the sample data from a shell (these are not Hive statements):
+--   cd ~/tmp
+--   wget https://gist.githubusercontent.com/myui/190b91a3a792ccfceda0/raw/327acd192da4f96da8276dcdff01b19947a4373c/tfidf_test.tsv
+
+-- Replace the path below with the absolute path of the downloaded file.
+LOAD DATA LOCAL INPATH '/home/myui/tmp/tfidf_test.tsv' INTO TABLE wikipage;
+
+-- One row per (docid, word) occurrence, stopwords removed.
+create or replace view wikipage_exploded
+as
+select
+  docid,
+  word
+from
+  wikipage LATERAL VIEW explode(tokenize(page,true)) t as word
+where
+  not is_stopword(word);
+```
+
+# Define views of term/doc frequency
+
+```sql
+-- Per-document term frequencies.
+-- NOTE(review): tf() returns a float frequency per word, while BM25 is defined on
+-- raw term counts; confirm whether bm25() should instead receive count(1) per
+-- (docid, word).
+create or replace view term_frequency
+as
+select
+  t1.docid,
+  t2.word,
+  t2.freq
+from (
+  select
+    docid,
+    tf(word) as word2freq
+  from
+    wikipage_exploded
+  group by
+    docid
+) t1
+LATERAL VIEW explode(word2freq) t2 as word, freq;
+
+-- Number of distinct documents containing each word.
+create or replace view document_frequency
+as
+select
+  word,
+  count(distinct docid) as docs
+from
+  wikipage_exploded
+group by
+  word;
+
+-- One row per document with its length (dl), the corpus-wide average length
+-- (avgdl) and the number of documents (total_docs). Rows are grouped by docid,
+-- so every row is a distinct document and count(1) over () already equals the
+-- number of documents; no DISTINCT aggregation inside the window is needed.
+create or replace view doc_len
+as
+select
+  docid,
+  count(1) as dl,
+  avg(count(1)) over () as avgdl,
+  count(1) over () as total_docs
+from
+  wikipage_exploded
+group by
+  docid
+;
+```
+
+# Compute Okapi BM25 score
+
+BM25 (and TF-IDF) scores represent the importance of a term in each document and are useful as feature weights in feature engineering.
+
+```sql
+-- BM25 and TF-IDF weight of every (docid, word) pair.
+-- To override the hyperparameters, pass an options string as the 6th argument
+-- of bm25, e.g. '-k1 1.5 -b 0.75'.
+create table scores
+as
+select
+  t.docid,
+  t.word,
+  bm25(t.freq, l.dl, l.avgdl, l.total_docs, d.docs) as bm25,
+  tfidf(t.freq, d.docs, l.total_docs) as tfidf
+from
+  term_frequency t
+  JOIN document_frequency d ON (t.word = d.word)
+  JOIN doc_len l ON (t.docid = l.docid)
+;
+```
+
+## Show important terms
+
+```sql
+-- Important terms of each document, by BM25 and by TF-IDF ('-k 10': top 10).
+select
+  docid,
+  to_ordered_list(feature(word, bm25), bm25, '-k 10') as bm25_scores,
+  to_ordered_list(feature(word, tfidf), tfidf, '-k 10') as tfidf_scores
+from scores
+group by docid
+limit 10;
+```
+
+# Retrieve relevant documents for given search terms
+
+You can retrieve relevant documents for a given search query `wisdom, justice, discussion` as follows:
+
+```sql
+-- The BM25 score of a document for a query is the sum of the BM25 weights of
+-- the query terms that occur in it. Only the columns needed for that sum are
+-- computed, and the CTE is named so that it does not shadow the `scores` table.
+-- Append an options string such as '-k1 1.5 -b 0.75' to bm25() to override
+-- the hyperparameters.
+WITH query_term_scores as (
+  select
+    t.docid,
+    bm25(t.freq, l.dl, l.avgdl, l.total_docs, d.docs) as bm25
+  from
+    term_frequency t
+    JOIN document_frequency d ON (t.word = d.word)
+    JOIN doc_len l ON (t.docid = l.docid)
+  where
+    t.word in ('wisdom', 'justice', 'discussion')
+)
+select
+  docid,
+  sum(bm25) as score
+from
+  query_term_scores
+group by
+  docid
+order by
+  score DESC
+LIMIT 10
+;
+```
+
+| docid | score |
+|:-:|:-:|
+| 1 | 0.14190456024682774 |
+| 2 | 0.02197354085722251 |
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/ce70aa48/docs/gitbook/misc/funcs.md
----------------------------------------------------------------------
diff --git a/docs/gitbook/misc/funcs.md b/docs/gitbook/misc/funcs.md
index c80128b..a0c7d29 100644
--- a/docs/gitbook/misc/funcs.md
+++ b/docs/gitbook/misc/funcs.md
@@ -532,5 +532,7 @@ This page describes a list of Hivemall functions. See also a [list of generic Hi
WITH dual AS (SELECT 1) SELECT lr_datagen('-n_examples 1k -n_features 10') FROM dual;
```
+- `bm25(double termFrequency, int docLength, double avgDocLength, int numDocs, int numDocsWithTerm [, const string options])` - Return an Okapi BM25 score in double
+
- `tf(string text)` - Return a term frequency in <string, float>
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/ce70aa48/resources/ddl/define-all-as-permanent.hive
----------------------------------------------------------------------
diff --git a/resources/ddl/define-all-as-permanent.hive b/resources/ddl/define-all-as-permanent.hive
index f359aaf..69dcf69 100644
--- a/resources/ddl/define-all-as-permanent.hive
+++ b/resources/ddl/define-all-as-permanent.hive
@@ -343,6 +343,9 @@ CREATE FUNCTION populate_not_in as 'hivemall.ftvec.ranking.PopulateNotInUDTF' US
DROP FUNCTION IF EXISTS tf;
CREATE FUNCTION tf as 'hivemall.ftvec.text.TermFrequencyUDAF' USING JAR '${hivemall_jar}';
+-- Okapi BM25 scoring UDF (hivemall.ftvec.text.OkapiBM25UDF)
+DROP FUNCTION IF EXISTS bm25;
+CREATE FUNCTION bm25 as 'hivemall.ftvec.text.OkapiBM25UDF' USING JAR '${hivemall_jar}';
+
--------------------------
-- Regression functions --
--------------------------
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/ce70aa48/resources/ddl/define-all.hive
----------------------------------------------------------------------
diff --git a/resources/ddl/define-all.hive b/resources/ddl/define-all.hive
index aed1b2f..f39aea3 100644
--- a/resources/ddl/define-all.hive
+++ b/resources/ddl/define-all.hive
@@ -339,6 +339,9 @@ create temporary function populate_not_in as 'hivemall.ftvec.ranking.PopulateNot
drop temporary function if exists tf;
create temporary function tf as 'hivemall.ftvec.text.TermFrequencyUDAF';
+-- Okapi BM25 scoring UDF (hivemall.ftvec.text.OkapiBM25UDF)
+drop temporary function if exists bm25;
+create temporary function bm25 as 'hivemall.ftvec.text.OkapiBM25UDF';
+
--------------------------
-- Regression functions --
--------------------------
@@ -881,3 +884,4 @@ log(10, n_docs / max2(1,df_t)) + 1.0;
create temporary macro tfidf(tf FLOAT, df_t DOUBLE, n_docs DOUBLE)
tf * (log(10, n_docs / max2(1,df_t)) + 1.0);
+
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/ce70aa48/resources/ddl/define-all.spark
----------------------------------------------------------------------
diff --git a/resources/ddl/define-all.spark b/resources/ddl/define-all.spark
index dcb368e..4d46694 100644
--- a/resources/ddl/define-all.spark
+++ b/resources/ddl/define-all.spark
@@ -342,6 +342,9 @@ sqlContext.sql("CREATE TEMPORARY FUNCTION populate_not_in AS 'hivemall.ftvec.ran
sqlContext.sql("DROP TEMPORARY FUNCTION IF EXISTS tf")
sqlContext.sql("CREATE TEMPORARY FUNCTION tf AS 'hivemall.ftvec.text.TermFrequencyUDAF'")
+// Okapi BM25 scoring UDF (hivemall.ftvec.text.OkapiBM25UDF)
+sqlContext.sql("DROP TEMPORARY FUNCTION IF EXISTS bm25")
+sqlContext.sql("CREATE TEMPORARY FUNCTION bm25 AS 'hivemall.ftvec.text.OkapiBM25UDF'")
+
/**
* Regression functions
*/
@@ -834,3 +837,4 @@ sqlContext.sql("CREATE TEMPORARY FUNCTION bloom_or AS 'hivemall.sketch.bloom.Blo
sqlContext.sql("DROP TEMPORARY FUNCTION IF EXISTS bloom_contains_any")
sqlContext.sql("CREATE TEMPORARY FUNCTION bloom_contains_any AS 'hivemall.sketch.bloom.BloomContainsAnyUDF'")
+