This is an automated email from the ASF dual-hosted git repository. myui pushed a commit to branch HIVEMALL-301-tfidf in repository https://gitbox.apache.org/repos/asf/incubator-hivemall.git
commit 1a62e7645a7b80bd16a15a0e3e567b08940ad2be Author: Makoto Yui <[email protected]> AuthorDate: Thu Mar 25 17:42:55 2021 +0900 Replaced tfidf macro with UDF --- .../main/java/hivemall/ftvec/text/TfIdfUDF.java | 104 +++++++++++++++++++++ resources/ddl/define-all-as-permanent.hive | 4 + resources/ddl/define-all.hive | 12 +-- resources/ddl/define-all.spark | 3 + resources/ddl/define-macros.hive | 10 -- 5 files changed, 114 insertions(+), 19 deletions(-) diff --git a/core/src/main/java/hivemall/ftvec/text/TfIdfUDF.java b/core/src/main/java/hivemall/ftvec/text/TfIdfUDF.java new file mode 100644 index 0000000..f7cbb3b --- /dev/null +++ b/core/src/main/java/hivemall/ftvec/text/TfIdfUDF.java @@ -0,0 +1,104 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */
+package hivemall.ftvec.text;
+
+import hivemall.utils.hadoop.HiveUtils;
+import hivemall.utils.lang.StringUtils;
+
+import javax.annotation.Nonnegative;
+import javax.annotation.Nonnull;
+
+import org.apache.hadoop.hive.ql.exec.Description;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentLengthException;
+import org.apache.hadoop.hive.ql.metadata.HiveException;
+import org.apache.hadoop.hive.ql.udf.UDFType;
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDF;
+import org.apache.hadoop.hive.serde2.io.DoubleWritable;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.DoubleObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.LongObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
+
+@Description(name = "tfidf",
+        value = "_FUNC_(double termFrequency, long numDocs, const long totalNumDocs) "
+                + "- Return an TFIDF score in double.")
+@UDFType(deterministic = true, stateful = false)
+public final class TfIdfUDF extends GenericUDF {
+
+    private DoubleObjectInspector tfOI;
+    private LongObjectInspector numDocsOI;
+    private LongObjectInspector totalNumDocsOI;
+
+    @Nonnull
+    private final DoubleWritable result = new DoubleWritable();
+
+    @Override
+    public ObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException {
+        if (argOIs.length != 3) {
+            throw new UDFArgumentLengthException(
+                "tfidf takes exactly three arguments but got " + argOIs.length);
+        }
+        // NOTE(review): the replaced macro took (FLOAT, DOUBLE, DOUBLE) -- confirm accepted types
+        this.tfOI = HiveUtils.asDoubleOI(argOIs[0]);
+        this.numDocsOI = HiveUtils.asLongOI(argOIs[1]);
+        this.totalNumDocsOI = HiveUtils.asLongOI(argOIs[2]);
+
+        return PrimitiveObjectInspectorFactory.writableDoubleObjectInspector;
+    }
+
+    /** Computes tf * (log10(totalNumDocs / max(1, numDocs)) + 1); throws on a null argument. */
+    @Override
+    public Object evaluate(DeferredObject[] arguments) throws HiveException {
+        Object arg0 = getObject(arguments, 0);
+        Object arg1 = getObject(arguments, 1);
+        Object arg2 = getObject(arguments, 2);
+
+        double tf = PrimitiveObjectInspectorUtils.getDouble(arg0, tfOI);
+        long numDocs = PrimitiveObjectInspectorUtils.getLong(arg1, numDocsOI);
+        long totalNumDocs = PrimitiveObjectInspectorUtils.getLong(arg2, totalNumDocsOI);
+
+        // Smoothed IDF: idf = log10(N / max(1, n_t)) + 1
+        //   N = totalNumDocs, n_t = numDocs
+        //   max(1, n_t) avoids division by zero when n_t is 0
+        // N is widened to double before dividing: long / long would truncate the
+        // ratio (e.g. 10 / 3 -> 3, 3 / 10 -> 0) before the logarithm is taken.
+        double idf = Math.log10((double) totalNumDocs / Math.max(1L, numDocs)) + 1.0d;
+        double tfidf = tf * idf;
+        result.set(tfidf);
+        return result;
+    }
+
+    @Nonnull
+    protected static Object getObject(@Nonnull final DeferredObject[] arguments,
+            @Nonnegative final int index) throws HiveException {
+        Object obj = arguments[index].get();
+        if (obj == null) {
+            throw new UDFArgumentException(String.format("%d-th argument MUST not be null", index));
+        }
+        return obj;
+    }
+
+    @Override
+    public String getDisplayString(String[] children) {
+        return "tfidf(" + StringUtils.join(children, ',') + ")";
+    }
+
+}
diff --git a/resources/ddl/define-all-as-permanent.hive b/resources/ddl/define-all-as-permanent.hive index f995d55..c5f2669 100644 --- a/resources/ddl/define-all-as-permanent.hive +++ b/resources/ddl/define-all-as-permanent.hive @@ -349,6 +349,9 @@ CREATE FUNCTION tf as 'hivemall.ftvec.text.TermFrequencyUDAF' USING JAR '${hivem DROP FUNCTION IF EXISTS bm25; CREATE FUNCTION bm25 as 'hivemall.ftvec.text.OkapiBM25UDF' USING JAR '${hivemall_jar}'; +DROP FUNCTION IF EXISTS tfidf; +CREATE FUNCTION tfidf as 'hivemall.ftvec.text.TfIdfUDF' USING JAR '${hivemall_jar}'; + -------------------------- -- Regression functions -- -------------------------- @@ -920,3 +923,4 @@ CREATE FUNCTION xgboost_predict_one AS 'hivemall.xgboost.XGBoostPredictOneUDTF' DROP FUNCTION xgboost_predict_triple; CREATE FUNCTION xgboost_predict_triple AS
'hivemall.xgboost.XGBoostPredictTripleUDTF' USING JAR '${hivemall_jar}'; + diff --git a/resources/ddl/define-all.hive b/resources/ddl/define-all.hive index bf9bc7c..b55c59b 100644 --- a/resources/ddl/define-all.hive +++ b/resources/ddl/define-all.hive @@ -345,6 +345,9 @@ create temporary function tf as 'hivemall.ftvec.text.TermFrequencyUDAF'; drop temporary function if exists bm25; create temporary function bm25 as 'hivemall.ftvec.text.OkapiBM25UDF'; +drop temporary function if exists tfidf; +create temporary function tfidf as 'hivemall.ftvec.text.TfIdfUDF'; + -------------------------- -- Regression functions -- -------------------------- @@ -909,12 +912,3 @@ if(x>y,x,y); create temporary macro min2(x DOUBLE, y DOUBLE) if(x<y,x,y); --------------------------- --- Statistics functions -- --------------------------- - -create temporary macro idf(df_t DOUBLE, n_docs DOUBLE) -log(10, n_docs / max2(1,df_t)) + 1.0; - -create temporary macro tfidf(tf FLOAT, df_t DOUBLE, n_docs DOUBLE) -tf * (log(10, n_docs / max2(1,df_t)) + 1.0); diff --git a/resources/ddl/define-all.spark b/resources/ddl/define-all.spark index 8529134..91c6350 100644 --- a/resources/ddl/define-all.spark +++ b/resources/ddl/define-all.spark @@ -348,6 +348,9 @@ sqlContext.sql("CREATE TEMPORARY FUNCTION tf AS 'hivemall.ftvec.text.TermFrequen sqlContext.sql("DROP TEMPORARY FUNCTION IF EXISTS bm25") sqlContext.sql("CREATE TEMPORARY FUNCTION bm25 AS 'hivemall.ftvec.text.OkapiBM25UDF'") +sqlContext.sql("DROP TEMPORARY FUNCTION IF EXISTS tfidf") +sqlContext.sql("CREATE TEMPORARY FUNCTION tfidf AS 'hivemall.ftvec.text.TfIdfUDF'") + /** * Regression functions */ diff --git a/resources/ddl/define-macros.hive b/resources/ddl/define-macros.hive index ff36a44..84c4e06 100644 --- a/resources/ddl/define-macros.hive +++ b/resources/ddl/define-macros.hive @@ -19,16 +19,6 @@ create temporary macro min2(x DOUBLE, y DOUBLE) if(x<y,x,y); -------------------------- --- Statistics functions -- --------------------------- - 
-create temporary macro idf(df_t DOUBLE, n_docs DOUBLE) -log(10, n_docs / max2(1,df_t)) + 1.0; - -create temporary macro tfidf(tf FLOAT, df_t DOUBLE, n_docs DOUBLE) -tf * (log(10, n_docs / max2(1,df_t)) + 1.0); - --------------------------- -- Evaluation functions -- --------------------------
