Repository: incubator-hivemall Updated Branches: refs/heads/master 4f795cb9a -> ce70aa482
[HIVEMALL-196] Support BM25 scoring ## What changes were proposed in this pull request? Adding scoring function Okapi BM25 as a UDF ## What type of PR is it? Feature ## What is the Jira issue? https://issues.apache.org/jira/projects/HIVEMALL/issues/HIVEMALL-196 ## How was this patch tested? 1. Unit testing 2. Manual testing on Hive ## How to use this feature? This new `okapi_bm25` function requires 5 mandatory arguments and 2 optional hyperparameters: 1. raw frequency count of a term in a given document 2. length of the given document 3. average length of a document in the corpus 4. number of documents in the corpus 5. number of documents containing the term, i.e. document frequency 6. (*optional*) k1 - a smoothing hyperparameter 7. (*optional*) b - a smoothing hyperparameter ### Step 1: Count frequency of terms ```sql create or replace view frequency as select docid, word, count(*) as freq from test_corpus_exploded group by docid, word ; ``` ### Step 2: Calculate document lengths ```sql create or replace view doc_len as select docid, count(1) as cnt from test_corpus_exploded group by docid ; ``` ### Step 3: Calculate document frequency ```sql create or replace view document_frequency as select word, count(distinct docid) docs from test_corpus_exploded group by word ; ``` ### Step 4: Set number of documents ```sql set hivevar:n_docs=3; ``` ### Step 5: Use `okapi_bm25` ```sql create or replace view bm25 as with tmp as ( select avg(cnt) as avgdl from doc_len ) select f.docid, f.word, okapi_bm25( CAST(f.freq AS INT), dl.cnt, CAST(tmp.avgdl AS DOUBLE), ${n_docs}, df.docs, '-k1 1.5 -b 0.75' ) as score from frequency f JOIN document_frequency df ON (f.word = df.word) JOIN doc_len dl ON (f.docid = dl.docid) CROSS JOIN tmp ORDER BY score desc; ``` ## Checklist (Please remove this section if not needed; check `x` for YES, blank for NO) - [x] Did you apply source code formatter, i.e., `./bin/format_code.sh`, for your commit? - [x] Did you run system tests on Hive (or Spark)? 
Author: Jackson Huang <[email protected]> Author: Makoto Yui <[email protected]> Closes #163 from jaxony/feature/bm25. Project: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/commit/ce70aa48 Tree: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/tree/ce70aa48 Diff: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/diff/ce70aa48 Branch: refs/heads/master Commit: ce70aa482c766f4d45160c850d28794c39028059 Parents: 4f795cb Author: Jackson Huang <[email protected]> Authored: Fri Nov 2 19:35:13 2018 +0900 Committer: Makoto Yui <[email protected]> Committed: Fri Nov 2 19:35:13 2018 +0900 ---------------------------------------------------------------------- core/src/main/java/hivemall/UDFWithOptions.java | 24 +++ .../java/hivemall/ftvec/text/OkapiBM25UDF.java | 172 ++++++++++++++++ .../java/hivemall/utils/lang/Primitives.java | 4 + .../hivemall/ftvec/text/OkapiBM25UDFTest.java | 193 ++++++++++++++++++ docs/gitbook/SUMMARY.md | 1 + docs/gitbook/ft_engineering/bm25.md | 197 +++++++++++++++++++ docs/gitbook/misc/funcs.md | 2 + resources/ddl/define-all-as-permanent.hive | 3 + resources/ddl/define-all.hive | 4 + resources/ddl/define-all.spark | 4 + 10 files changed, 604 insertions(+) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/ce70aa48/core/src/main/java/hivemall/UDFWithOptions.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/UDFWithOptions.java b/core/src/main/java/hivemall/UDFWithOptions.java index 9908cd9..04d6fdc 100644 --- a/core/src/main/java/hivemall/UDFWithOptions.java +++ b/core/src/main/java/hivemall/UDFWithOptions.java @@ -112,6 +112,30 @@ public abstract class UDFWithOptions extends GenericUDF { return cl; } + /** + * Raise {@link UDFArgumentException} if the given condition is false. 
+ * + * @throws UDFArgumentException + */ + protected static void assumeTrue(final boolean condition, @Nonnull final String errMsg) + throws UDFArgumentException { + if (!condition) { + throw new UDFArgumentException(errMsg); + } + } + + /** + * Raise {@link UDFArgumentException} if the given condition is true. + * + * @throws UDFArgumentException + */ + protected static void assumeFalse(final boolean condition, @Nonnull final String errMsg) + throws UDFArgumentException { + if (condition) { + throw new UDFArgumentException(errMsg); + } + } + @Nonnull protected abstract CommandLine processOptions(@Nonnull String optionValue) throws UDFArgumentException; http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/ce70aa48/core/src/main/java/hivemall/ftvec/text/OkapiBM25UDF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/ftvec/text/OkapiBM25UDF.java b/core/src/main/java/hivemall/ftvec/text/OkapiBM25UDF.java new file mode 100644 index 0000000..cd36d6f --- /dev/null +++ b/core/src/main/java/hivemall/ftvec/text/OkapiBM25UDF.java @@ -0,0 +1,172 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package hivemall.ftvec.text; + +import hivemall.UDFWithOptions; +import hivemall.utils.hadoop.HiveUtils; +import hivemall.utils.lang.Primitives; +import hivemall.utils.lang.StringUtils; + +import javax.annotation.Nonnull; + +import org.apache.commons.cli.CommandLine; +import org.apache.commons.cli.Options; +import org.apache.hadoop.hive.ql.exec.Description; +import org.apache.hadoop.hive.ql.exec.UDFArgumentException; +import org.apache.hadoop.hive.ql.metadata.HiveException; +import org.apache.hadoop.hive.ql.udf.UDFType; +import org.apache.hadoop.hive.serde2.io.DoubleWritable; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils; + +@Description(name = "bm25", + value = "_FUNC_(double termFrequency, int docLength, double avgDocLength, int numDocs, int numDocsWithTerm [, const string options]) - Return an Okapi BM25 score in double") +@UDFType(deterministic = true, stateful = false) +public final class OkapiBM25UDF extends UDFWithOptions { + + private double k1 = 1.2d; + private double b = 0.75d; + + // BM25+ https://en.wikipedia.org/wiki/Okapi_BM25#General_references + private double delta = 0.d; + + // epsilon in https://en.wikipedia.org/wiki/Okapi_BM25#The_ranking_function + private double minIDF = 1e-8; + + private PrimitiveObjectInspector frequencyOI; + private PrimitiveObjectInspector docLengthOI; + private PrimitiveObjectInspector averageDocLengthOI; + private PrimitiveObjectInspector numDocsOI; + private PrimitiveObjectInspector numDocsWithTermOI; + + @Nonnull + private final DoubleWritable result = new DoubleWritable(); + + public OkapiBM25UDF() {} + + @Override + protected Options getOptions() { + Options opts = new Options(); + opts.addOption("k1", true, + 
"Hyperparameter with type double, usually in range 1.2 and 2.0 [default: 1.2]"); + opts.addOption("b", true, + "Hyperparameter with type double in range 0.0 and 1.0 [default: 0.75]"); + opts.addOption("d", "delta", true, "Hyperparameter delta of BM25+ [default: 0.0]"); + opts.addOption("min_idf", "epsilon", true, "Hyperparameter delta of BM25+ [default: 1e-8]"); + return opts; + } + + @Override + protected CommandLine processOptions(@Nonnull String opts) throws UDFArgumentException { + CommandLine cl = parseOptions(opts); + + this.k1 = Primitives.parseDouble(cl.getOptionValue("k1"), k1); + + if (Primitives.isFinite(k1) == false || k1 < 0.0) { + throw new UDFArgumentException("k1 must be a non-negative finite value: " + k1); + } + + this.b = Primitives.parseDouble(cl.getOptionValue("b"), b); + if (Double.isNaN(b) || b < 0.0 || b > 1.0) { + throw new UDFArgumentException( + "b1 hyperparameter must be in the range [0.0, 1.0]: " + b); + } + + this.delta = Primitives.parseDouble(cl.getOptionValue("delta"), delta); + if (Primitives.isFinite(delta) == false) { + throw new UDFArgumentException("Delta must be a finite value: " + delta); + } + + this.minIDF = Primitives.parseDouble(cl.getOptionValue("min_idf"), minIDF); + if (minIDF < 0.d) { + throw new UDFArgumentException("min_idf must not be negative value: " + minIDF); + } + + return cl; + } + + @Override + public ObjectInspector initialize(@Nonnull ObjectInspector[] argOIs) + throws UDFArgumentException { + final int numArgOIs = argOIs.length; + if (numArgOIs < 5) { + throw new UDFArgumentException("argOIs.length must be greater than or equal to 5"); + } else if (numArgOIs == 6) { + String opts = HiveUtils.getConstString(argOIs[5]); + processOptions(opts); + } + + this.frequencyOI = HiveUtils.asDoubleCompatibleOI(argOIs[0]); + this.docLengthOI = HiveUtils.asIntegerOI(argOIs[1]); + this.averageDocLengthOI = HiveUtils.asDoubleCompatibleOI(argOIs[2]); + this.numDocsOI = HiveUtils.asIntegerOI(argOIs[3]); + 
this.numDocsWithTermOI = HiveUtils.asIntegerOI(argOIs[4]); + + return PrimitiveObjectInspectorFactory.writableDoubleObjectInspector; + } + + @Override + public DoubleWritable evaluate(@Nonnull DeferredObject[] arguments) throws HiveException { + Object arg0 = arguments[0].get(); + Object arg1 = arguments[1].get(); + Object arg2 = arguments[2].get(); + Object arg3 = arguments[3].get(); + Object arg4 = arguments[4].get(); + + if (arg0 == null || arg1 == null || arg2 == null || arg3 == null || arg4 == null) { + throw new UDFArgumentException("Required arguments cannot be null"); + } + + double frequency = PrimitiveObjectInspectorUtils.getDouble(arg0, frequencyOI); + int docLength = PrimitiveObjectInspectorUtils.getInt(arg1, docLengthOI); + double averageDocLength = PrimitiveObjectInspectorUtils.getDouble(arg2, averageDocLengthOI); + int numDocs = PrimitiveObjectInspectorUtils.getInt(arg3, numDocsOI); + int numDocsWithTerm = PrimitiveObjectInspectorUtils.getInt(arg4, numDocsWithTermOI); + + assumeFalse(frequency < 0, "#frequency must be positive"); + assumeFalse(docLength < 1, "#docLength must be greater than or equal to 1"); + assumeFalse(averageDocLength <= 0.0, "#averageDocLength must be positive"); + assumeFalse(numDocs < 1, "#numDocs must be greater than or equal to 1"); + assumeFalse(numDocsWithTerm < 1, "#numDocsWithTerm must be greater than or equal to 1"); + + double v = bm25(frequency, docLength, averageDocLength, numDocs, numDocsWithTerm); + result.set(v); + return result; + } + + private double bm25(final double tf, final int docLength, final double averageDocLength, + final int numDocs, final int numDocsWithTerm) { + double numerator = tf * (k1 + 1); + double denominator = tf + k1 * (1 - b + b * docLength / averageDocLength); + double idf = Math.max(minIDF, idf(numDocs, numDocsWithTerm)); + return idf * (numerator / denominator + delta); + } + + private static double idf(final int numDocs, final int numDocsWithTerm) { + return Math.log10(1.0d + (numDocs - 
numDocsWithTerm + 0.5d) / (numDocsWithTerm + 0.5d)); + } + + @Override + public String getDisplayString(String[] children) { + return "bm25(" + StringUtils.join(children, ',') + ")"; + } + +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/ce70aa48/core/src/main/java/hivemall/utils/lang/Primitives.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/utils/lang/Primitives.java b/core/src/main/java/hivemall/utils/lang/Primitives.java index ab3be9a..1c05102 100644 --- a/core/src/main/java/hivemall/utils/lang/Primitives.java +++ b/core/src/main/java/hivemall/utils/lang/Primitives.java @@ -78,6 +78,10 @@ public final class Primitives { return v.doubleValue(); } + public static boolean isFinite(final double value) { + return Double.NEGATIVE_INFINITY < value && value < Double.POSITIVE_INFINITY; + } + public static int compare(final int x, final int y) { return (x < y) ? -1 : ((x == y) ? 0 : 1); } http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/ce70aa48/core/src/test/java/hivemall/ftvec/text/OkapiBM25UDFTest.java ---------------------------------------------------------------------- diff --git a/core/src/test/java/hivemall/ftvec/text/OkapiBM25UDFTest.java b/core/src/test/java/hivemall/ftvec/text/OkapiBM25UDFTest.java new file mode 100644 index 0000000..8ddba23 --- /dev/null +++ b/core/src/test/java/hivemall/ftvec/text/OkapiBM25UDFTest.java @@ -0,0 +1,193 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package hivemall.ftvec.text; + +import static org.junit.Assert.assertEquals; + +import hivemall.utils.hadoop.HiveUtils; +import hivemall.utils.hadoop.WritableUtils; + +import org.apache.hadoop.hive.ql.metadata.HiveException; +import org.apache.hadoop.hive.ql.udf.generic.GenericUDF; +import org.apache.hadoop.hive.serde2.io.DoubleWritable; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; +import org.junit.Before; +import org.junit.Test; + +public class OkapiBM25UDFTest { + + private static final double EPSILON = 1e-8; + private static final GenericUDF.DeferredJavaObject VALID_TERM_FREQ = + new GenericUDF.DeferredJavaObject(new Integer(3)); + private static final GenericUDF.DeferredJavaObject VALID_DOC_LEN = + new GenericUDF.DeferredJavaObject(new Integer(9)); + private static final GenericUDF.DeferredJavaObject VALID_AVG_DOC_LEN = + new GenericUDF.DeferredJavaObject(new Double(10.35)); + private static final GenericUDF.DeferredJavaObject VALID_NUM_DOCS = + new GenericUDF.DeferredJavaObject(new Integer(20)); + private static final GenericUDF.DeferredJavaObject VALID_NUM_DOCS_WITH_TERM = + new GenericUDF.DeferredJavaObject(new Integer(5)); + + private OkapiBM25UDF udf = null; + + + @Before + public void init() throws Exception { + udf = new OkapiBM25UDF(); + } + + @Test + public void testEvaluate() throws Exception { + + initializeUDFWithoutOptions(); + + GenericUDF.DeferredObject[] args = new GenericUDF.DeferredObject[] 
{VALID_TERM_FREQ, + VALID_DOC_LEN, VALID_AVG_DOC_LEN, VALID_NUM_DOCS, VALID_NUM_DOCS_WITH_TERM}; + + DoubleWritable expected = WritableUtils.val(0.940637195691); + DoubleWritable actual = udf.evaluate(args); + assertEquals(expected.get(), actual.get(), EPSILON); + } + + @Test + public void testEvaluateWithCustomK1() throws Exception { + + udf.initialize( + new ObjectInspector[] {PrimitiveObjectInspectorFactory.javaIntObjectInspector, + PrimitiveObjectInspectorFactory.javaIntObjectInspector, + PrimitiveObjectInspectorFactory.javaDoubleObjectInspector, + PrimitiveObjectInspectorFactory.javaIntObjectInspector, + PrimitiveObjectInspectorFactory.javaIntObjectInspector, + HiveUtils.getConstStringObjectInspector("-k1 1.5")}); + + GenericUDF.DeferredObject[] args = new GenericUDF.DeferredObject[] {VALID_TERM_FREQ, + VALID_DOC_LEN, VALID_AVG_DOC_LEN, VALID_NUM_DOCS, VALID_NUM_DOCS_WITH_TERM}; + + DoubleWritable expected = WritableUtils.val(1.00244958206); + DoubleWritable actual = udf.evaluate(args); + assertEquals(expected.get(), actual.get(), EPSILON); + } + + @Test + public void testEvaluateWithCustomB() throws Exception { + + udf.initialize( + new ObjectInspector[] {PrimitiveObjectInspectorFactory.javaIntObjectInspector, + PrimitiveObjectInspectorFactory.javaIntObjectInspector, + PrimitiveObjectInspectorFactory.javaDoubleObjectInspector, + PrimitiveObjectInspectorFactory.javaIntObjectInspector, + PrimitiveObjectInspectorFactory.javaIntObjectInspector, + HiveUtils.getConstStringObjectInspector("-b 0.8")}); + + GenericUDF.DeferredObject[] args = new GenericUDF.DeferredObject[] {VALID_TERM_FREQ, + VALID_DOC_LEN, VALID_AVG_DOC_LEN, VALID_NUM_DOCS, VALID_NUM_DOCS_WITH_TERM}; + + DoubleWritable expected = WritableUtils.val(0.942443797219); + DoubleWritable actual = udf.evaluate(args); + assertEquals(expected.get(), actual.get(), EPSILON); + } + + @Test(expected = HiveException.class) + public void testInputArgIsNull() throws Exception { + + initializeUDFWithoutOptions(); + + 
GenericUDF.DeferredObject[] args = + new GenericUDF.DeferredObject[] {new GenericUDF.DeferredJavaObject(null), + VALID_DOC_LEN, VALID_AVG_DOC_LEN, VALID_NUM_DOCS, VALID_NUM_DOCS_WITH_TERM}; + + udf.evaluate(args); + } + + @Test(expected = HiveException.class) + public void testTermFrequencyIsNegative() throws Exception { + initializeUDFWithoutOptions(); + + GenericUDF.DeferredObject[] args = + new GenericUDF.DeferredObject[] {new GenericUDF.DeferredJavaObject(new Integer(-1)), + VALID_DOC_LEN, VALID_AVG_DOC_LEN, VALID_NUM_DOCS, VALID_NUM_DOCS_WITH_TERM}; + + udf.evaluate(args); + } + + @Test(expected = HiveException.class) + public void testDocLengthIsLessThanOne() throws Exception { + initializeUDFWithoutOptions(); + + GenericUDF.DeferredObject[] args = new GenericUDF.DeferredObject[] {VALID_TERM_FREQ, + new GenericUDF.DeferredJavaObject(new Integer(0)), VALID_AVG_DOC_LEN, + VALID_NUM_DOCS, VALID_NUM_DOCS_WITH_TERM}; + + udf.evaluate(args); + } + + @Test(expected = HiveException.class) + public void testAvgDocLengthIsNegative() throws Exception { + initializeUDFWithoutOptions(); + + GenericUDF.DeferredObject[] args = new GenericUDF.DeferredObject[] {VALID_TERM_FREQ, + VALID_DOC_LEN, new GenericUDF.DeferredJavaObject(new Double(-10)), VALID_NUM_DOCS, + VALID_NUM_DOCS_WITH_TERM}; + + udf.evaluate(args); + } + + @Test(expected = HiveException.class) + public void testAvgDocLengthIsZero() throws Exception { + initializeUDFWithoutOptions(); + + GenericUDF.DeferredObject[] args = new GenericUDF.DeferredObject[] {VALID_TERM_FREQ, + VALID_DOC_LEN, new GenericUDF.DeferredJavaObject(new Double(0.0)), VALID_NUM_DOCS, + VALID_NUM_DOCS_WITH_TERM}; + + udf.evaluate(args); + } + + @Test(expected = HiveException.class) + public void testNumDocsIsLessThanOne() throws Exception { + initializeUDFWithoutOptions(); + + GenericUDF.DeferredObject[] args = new GenericUDF.DeferredObject[] {VALID_TERM_FREQ, + VALID_DOC_LEN, VALID_AVG_DOC_LEN, new GenericUDF.DeferredJavaObject(new 
Integer(0)), + VALID_NUM_DOCS_WITH_TERM}; + + udf.evaluate(args); + } + + @Test(expected = HiveException.class) + public void testNumDocsWithTermIsLessThanOne() throws Exception { + initializeUDFWithoutOptions(); + + GenericUDF.DeferredObject[] args = + new GenericUDF.DeferredObject[] {VALID_TERM_FREQ, VALID_DOC_LEN, VALID_AVG_DOC_LEN, + VALID_NUM_DOCS, new GenericUDF.DeferredJavaObject(new Integer(0))}; + + udf.evaluate(args); + } + + private void initializeUDFWithoutOptions() throws Exception { + udf.initialize( + new ObjectInspector[] {PrimitiveObjectInspectorFactory.javaIntObjectInspector, + PrimitiveObjectInspectorFactory.javaIntObjectInspector, + PrimitiveObjectInspectorFactory.javaDoubleObjectInspector, + PrimitiveObjectInspectorFactory.javaIntObjectInspector, + PrimitiveObjectInspectorFactory.javaIntObjectInspector}); + } +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/ce70aa48/docs/gitbook/SUMMARY.md ---------------------------------------------------------------------- diff --git a/docs/gitbook/SUMMARY.md b/docs/gitbook/SUMMARY.md index 6c69848..3484bfb 100644 --- a/docs/gitbook/SUMMARY.md +++ b/docs/gitbook/SUMMARY.md @@ -66,6 +66,7 @@ * [Feature vectorization](ft_engineering/vectorization.md) * [Quantify non-number features](ft_engineering/quantify.md) * [TF-IDF Calculation](ft_engineering/tfidf.md) +* [BM25](ft_engineering/bm25.md) ## Part IV - Evaluation http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/ce70aa48/docs/gitbook/ft_engineering/bm25.md ---------------------------------------------------------------------- diff --git a/docs/gitbook/ft_engineering/bm25.md b/docs/gitbook/ft_engineering/bm25.md new file mode 100644 index 0000000..4ca029f --- /dev/null +++ b/docs/gitbook/ft_engineering/bm25.md @@ -0,0 +1,197 @@ +<!-- + Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. 
See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. +--> + +[Okapi BM25](https://en.wikipedia.org/wiki/Okapi_BM25) is a ranking function for documents for a given query. + +It can also be used as a better replacement for [TF-IDF](https://en.wikipedia.org/wiki/Tf%E2%80%93idf), i.e., as a term weight for each document. + +<!-- toc --> + +# The ranking function + +Given a query $$Q$$, containing keywords $$q_{1},\ldots,q_{n}$$, the BM25 score of a document $$D$$ is: + +$$ +score(Q, D) = \sum_{i=1}^{n}IDF(q_{i}) \cdot \frac{tf(q_{i},D) \cdot (k_{1}+1)}{tf(q_{i},D) + k_{1} \cdot (1 - b + b \cdot \frac{|D|}{avgdl})} +$$ + +where $$tf(q_{i}, D)$$ is $$q_{i}$$'s term frequency in the document $$D$$, $$|D|$$ is the length of the document $$D$$ in words, and $$avgdl$$ is the average document length in the text collection from which documents are drawn. $$k_{1}$$ and $$b$$ are free parameters, usually chosen, in the absence of an advanced optimization, as $$k_{1} \in [1.2,2.0]$$ and $$b = 0.75$$. + +BM25 can also be applied for term weighting, showing how important a word is to a document in a collection or corpus, as follows: + +$$ +score(t_{i}, D) = IDF(t_{i}) \cdot \frac{tf(t_{i},D) \cdot (k_{1}+1)}{tf(t_{i},D) + k_{1} \cdot (1 - b + b \cdot \frac{|D|}{avgdl})} +$$ + +where $$t_{i}$$ is a term that appears in document $$D$$. 
+ +# Data preparation + +Similar to [TF-IDF](./tfidf), you need to prepare a relation consisting of (docid,word) tuples to compute BM25 scores. + +```sql +create external table wikipage ( + docid int, + page string +) +ROW FORMAT DELIMITED FIELDS TERMINATED BY '|' +STORED AS TEXTFILE; + +cd ~/tmp +wget https://gist.githubusercontent.com/myui/190b91a3a792ccfceda0/raw/327acd192da4f96da8276dcdff01b19947a4373c/tfidf_test.tsv + +LOAD DATA LOCAL INPATH '/home/myui/tmp/tfidf_test.tsv' INTO TABLE wikipage; + +create or replace view wikipage_exploded +as +select + docid, + word +from + wikipage LATERAL VIEW explode(tokenize(page,true)) t as word +where + not is_stopword(word); +``` + +# Define views of term/doc frequency + +```sql +create or replace view term_frequency +as +select + t1.docid, + t2.word, + t2.freq +from ( + select + docid, + tf(word) as word2freq + from + wikipage_exploded + group by + docid +) t1 +LATERAL VIEW explode(word2freq) t2 as word, freq; + +create or replace view document_frequency +as +select + word, + count(distinct docid) docs +from + wikipage_exploded +group by + word; + +create or replace view doc_len +as +select + docid, + count(1) as dl, + avg(count(1)) over () as avgdl, + count(distinct docid) over () as total_docs +from + wikipage_exploded +group by + docid +; +``` + +# Compute Okapi BM25 score + +The BM25 (and TF-IDF) score, which represents the importance of a term in each document, is useful as a feature weight in feature engineering. 
+ +```sql +create table scores +as +select + tf.docid, + tf.word, + bm25( + tf.freq, + dl.dl, + dl.avgdl, + dl.total_docs, + df.docs + -- , '-k1 1.5 -b 0.75' + ) as bm25, + tfidf(tf.freq, df.docs, dl.total_docs) as tfidf +from + term_frequency tf + JOIN document_frequency df ON (tf.word = df.word) + JOIN doc_len dl ON (tf.docid = dl.docid) +; +``` + +## Show important terms + +```sql +select + docid, + to_ordered_list(feature(word,bm25), bm25, '-k 10') as bm25_scores, + to_ordered_list(feature(word,tfidf),tfidf, '-k 10') as tfidf_scores +from + scores +group by + docid +limit 10; +``` + +# Retrieve relevant documents for given search terms + +You can retrieve relevant documents for a given search query `wisdom, justice, discussion` as follows: + +```sql +WITH scores as ( + select + tf.docid, + tf.word, + bm25( + tf.freq, + dl.dl, + dl.avgdl, + dl.total_docs, + df.docs + -- , '-k1 1.5 -b 0.75' + ) as bm25, + tfidf(tf.freq, df.docs, dl.total_docs) as tfidf + from + term_frequency tf + JOIN document_frequency df ON (tf.word = df.word) + JOIN doc_len dl ON (tf.docid = dl.docid) + where + tf.word in ('wisdom', 'justice', 'discussion') +) +select + docid, + sum(bm25) as score +from + scores +group by + docid +order by + score DESC +LIMIT 10 +; +``` + +| docid | score | +|:-:|:-:| +| 1 | 0.14190456024682774 | +| 2 | 0.02197354085722251 | \ No newline at end of file http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/ce70aa48/docs/gitbook/misc/funcs.md ---------------------------------------------------------------------- diff --git a/docs/gitbook/misc/funcs.md b/docs/gitbook/misc/funcs.md index c80128b..a0c7d29 100644 --- a/docs/gitbook/misc/funcs.md +++ b/docs/gitbook/misc/funcs.md @@ -532,5 +532,7 @@ This page describes a list of Hivemall functions. 
See also a [list of generic Hi WITH dual AS (SELECT 1) SELECT lr_datagen('-n_examples 1k -n_features 10') FROM dual; ``` +- `bm25(int termFrequency, int docLength, double avgDocLength, int numDocs, int numDocsWithTerm [, const string options])` - Return an Okapi BM25 score in double + - `tf(string text)` - Return a term frequency in <string, float> http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/ce70aa48/resources/ddl/define-all-as-permanent.hive ---------------------------------------------------------------------- diff --git a/resources/ddl/define-all-as-permanent.hive b/resources/ddl/define-all-as-permanent.hive index f359aaf..69dcf69 100644 --- a/resources/ddl/define-all-as-permanent.hive +++ b/resources/ddl/define-all-as-permanent.hive @@ -343,6 +343,9 @@ CREATE FUNCTION populate_not_in as 'hivemall.ftvec.ranking.PopulateNotInUDTF' US DROP FUNCTION IF EXISTS tf; CREATE FUNCTION tf as 'hivemall.ftvec.text.TermFrequencyUDAF' USING JAR '${hivemall_jar}'; +DROP FUNCTION IF EXISTS bm25; +CREATE FUNCTION bm25 as 'hivemall.ftvec.text.OkapiBM25UDF' USING JAR '${hivemall_jar}'; + -------------------------- -- Regression functions -- -------------------------- http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/ce70aa48/resources/ddl/define-all.hive ---------------------------------------------------------------------- diff --git a/resources/ddl/define-all.hive b/resources/ddl/define-all.hive index aed1b2f..f39aea3 100644 --- a/resources/ddl/define-all.hive +++ b/resources/ddl/define-all.hive @@ -339,6 +339,9 @@ create temporary function populate_not_in as 'hivemall.ftvec.ranking.PopulateNot drop temporary function if exists tf; create temporary function tf as 'hivemall.ftvec.text.TermFrequencyUDAF'; +drop temporary function if exists bm25; +create temporary function bm25 as 'hivemall.ftvec.text.OkapiBM25UDF'; + -------------------------- -- Regression functions -- -------------------------- @@ -881,3 +884,4 @@ log(10, n_docs / max2(1,df_t)) + 
1.0; create temporary macro tfidf(tf FLOAT, df_t DOUBLE, n_docs DOUBLE) tf * (log(10, n_docs / max2(1,df_t)) + 1.0); + http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/ce70aa48/resources/ddl/define-all.spark ---------------------------------------------------------------------- diff --git a/resources/ddl/define-all.spark b/resources/ddl/define-all.spark index dcb368e..4d46694 100644 --- a/resources/ddl/define-all.spark +++ b/resources/ddl/define-all.spark @@ -342,6 +342,9 @@ sqlContext.sql("CREATE TEMPORARY FUNCTION populate_not_in AS 'hivemall.ftvec.ran sqlContext.sql("DROP TEMPORARY FUNCTION IF EXISTS tf") sqlContext.sql("CREATE TEMPORARY FUNCTION tf AS 'hivemall.ftvec.text.TermFrequencyUDAF'") +sqlContext.sql("DROP TEMPORARY FUNCTION IF EXISTS bm25") +sqlContext.sql("CREATE TEMPORARY FUNCTION bm25 AS 'hivemall.ftvec.text.OkapiBM25UDF'") + /** * Regression functions */ @@ -834,3 +837,4 @@ sqlContext.sql("CREATE TEMPORARY FUNCTION bloom_or AS 'hivemall.sketch.bloom.Blo sqlContext.sql("DROP TEMPORARY FUNCTION IF EXISTS bloom_contains_any") sqlContext.sql("CREATE TEMPORARY FUNCTION bloom_contains_any AS 'hivemall.sketch.bloom.BloomContainsAnyUDF'") +
