This is an automated email from the ASF dual-hosted git repository.

myui pushed a commit to branch HIVEMALL-301-tfidf
in repository https://gitbox.apache.org/repos/asf/incubator-hivemall.git

commit 1a62e7645a7b80bd16a15a0e3e567b08940ad2be
Author: Makoto Yui <[email protected]>
AuthorDate: Thu Mar 25 17:42:55 2021 +0900

    Replaced tfidf macro with UDF
---
 .../main/java/hivemall/ftvec/text/TfIdfUDF.java    | 104 +++++++++++++++++++++
 resources/ddl/define-all-as-permanent.hive         |   4 +
 resources/ddl/define-all.hive                      |  12 +--
 resources/ddl/define-all.spark                     |   3 +
 resources/ddl/define-macros.hive                   |  10 --
 5 files changed, 114 insertions(+), 19 deletions(-)

diff --git a/core/src/main/java/hivemall/ftvec/text/TfIdfUDF.java 
b/core/src/main/java/hivemall/ftvec/text/TfIdfUDF.java
new file mode 100644
index 0000000..f7cbb3b
--- /dev/null
+++ b/core/src/main/java/hivemall/ftvec/text/TfIdfUDF.java
@@ -0,0 +1,104 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.ftvec.text;
+
+import hivemall.utils.hadoop.HiveUtils;
+import hivemall.utils.lang.StringUtils;
+
+import javax.annotation.Nonnegative;
+import javax.annotation.Nonnull;
+
+import org.apache.hadoop.hive.ql.exec.Description;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentLengthException;
+import org.apache.hadoop.hive.ql.metadata.HiveException;
+import org.apache.hadoop.hive.ql.udf.UDFType;
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDF;
+import org.apache.hadoop.hive.serde2.io.DoubleWritable;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
+import 
org.apache.hadoop.hive.serde2.objectinspector.primitive.DoubleObjectInspector;
+import 
org.apache.hadoop.hive.serde2.objectinspector.primitive.LongObjectInspector;
+import 
org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
+import 
org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
+
+/**
+ * Hive generic UDF that computes a TF-IDF score.
+ *
+ * <p>The score is {@code tf * idf}, where the inverse document frequency is
+ * smoothed as {@code idf = log10(N / max(1, n_t)) + 1}; {@code N} is the total
+ * number of documents and {@code n_t} is the number of documents containing
+ * the term.
+ */
+@Description(name = "tfidf",
+        value = "_FUNC_(double termFrequency, long numDocs, const long totalNumDocs) "
+                + "- Return an TFIDF score in double.")
+@UDFType(deterministic = true, stateful = false)
+public final class TfIdfUDF extends GenericUDF {
+
+    // Inspectors for the three arguments; assigned in initialize().
+    private DoubleObjectInspector tfOI;
+    private LongObjectInspector numDocsOI;
+    private LongObjectInspector totalNumDocsOI;
+
+    /** Output holder; the same instance is updated and returned by every evaluate() call. */
+    @Nonnull
+    private final DoubleWritable result = new DoubleWritable();
+
+    /**
+     * Checks the argument count and captures the argument inspectors.
+     *
+     * @param argOIs inspectors for (termFrequency, numDocs, totalNumDocs)
+     * @return the writable double inspector describing the UDF result
+     * @throws UDFArgumentException if the number of arguments is not three;
+     *         presumably also raised by the HiveUtils casts on a type
+     *         mismatch (NOTE(review): confirm against HiveUtils)
+     */
+    @Override
+    public ObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException {
+        if (argOIs.length != 3) {
+            throw new UDFArgumentLengthException(
+                "tfidf takes exactly three arguments but got " + argOIs.length);
+        }
+
+        this.tfOI = HiveUtils.asDoubleOI(argOIs[0]);
+        this.numDocsOI = HiveUtils.asLongOI(argOIs[1]);
+        this.totalNumDocsOI = HiveUtils.asLongOI(argOIs[2]);
+
+        return PrimitiveObjectInspectorFactory.writableDoubleObjectInspector;
+    }
+
+    /**
+     * Computes {@code tf * (log10(totalNumDocs / max(1, numDocs)) + 1)}.
+     *
+     * @param arguments deferred values for (termFrequency, numDocs, totalNumDocs)
+     * @return the shared {@link DoubleWritable} holding the TF-IDF score
+     * @throws HiveException if any of the three arguments evaluates to null
+     */
+    @Override
+    public Object evaluate(DeferredObject[] arguments) throws HiveException {
+        Object arg0 = getObject(arguments, 0);
+        Object arg1 = getObject(arguments, 1);
+        Object arg2 = getObject(arguments, 2);
+
+        double tf = PrimitiveObjectInspectorUtils.getDouble(arg0, tfOI);
+        long numDocs = PrimitiveObjectInspectorUtils.getLong(arg1, numDocsOI);
+        long totalNumDocs = PrimitiveObjectInspectorUtils.getLong(arg2, totalNumDocsOI);
+
+        // basic IDF
+        //    idf = log(N/n_t)
+        // IDF with smoothing
+        //    idf = log(N/(1+n_t))+1
+        //    idf = log(N/max(1,n_t))+1 -- max(1,n_t) avoids division by zero
+        //
+        // The numerator is widened to double before dividing: long / long
+        // would truncate the ratio (e.g. 15/10 -> 1, and 0 whenever
+        // n_t > N, which makes log10 return -Infinity).
+        final double ratio = (double) totalNumDocs / Math.max(1L, numDocs);
+        final double idf = Math.log10(ratio) + 1.0d;
+        result.set(tf * idf);
+        return result;
+    }
+
+    /**
+     * Fetches the value of the argument at {@code index}, rejecting null.
+     *
+     * @param arguments deferred UDF arguments
+     * @param index zero-based position of the argument to read
+     * @return the non-null argument value
+     * @throws HiveException a {@link UDFArgumentException} if the value is null
+     */
+    @Nonnull
+    protected static Object getObject(@Nonnull final DeferredObject[] arguments,
+            @Nonnegative final int index) throws HiveException {
+        Object obj = arguments[index].get();
+        if (obj == null) {
+            throw new UDFArgumentException(String.format("%d-th argument MUST not be null", index));
+        }
+        return obj;
+    }
+
+    /** Renders the call as {@code tfidf(child1,child2,...)} for display. */
+    @Override
+    public String getDisplayString(String[] children) {
+        return "tfidf(" + StringUtils.join(children, ',') + ")";
+    }
+
+}
diff --git a/resources/ddl/define-all-as-permanent.hive 
b/resources/ddl/define-all-as-permanent.hive
index f995d55..c5f2669 100644
--- a/resources/ddl/define-all-as-permanent.hive
+++ b/resources/ddl/define-all-as-permanent.hive
@@ -349,6 +349,9 @@ CREATE FUNCTION tf as 
'hivemall.ftvec.text.TermFrequencyUDAF' USING JAR '${hivem
 DROP FUNCTION IF EXISTS bm25;
 CREATE FUNCTION bm25 as 'hivemall.ftvec.text.OkapiBM25UDF' USING JAR 
'${hivemall_jar}';
 
+DROP FUNCTION IF EXISTS tfidf;
+CREATE FUNCTION tfidf as 'hivemall.ftvec.text.TfIdfUDF' USING JAR 
'${hivemall_jar}';
+
 --------------------------
 -- Regression functions --
 --------------------------
@@ -920,3 +923,4 @@ CREATE FUNCTION xgboost_predict_one AS 
'hivemall.xgboost.XGBoostPredictOneUDTF'
 
 DROP FUNCTION xgboost_predict_triple;
 CREATE FUNCTION xgboost_predict_triple AS 
'hivemall.xgboost.XGBoostPredictTripleUDTF' USING JAR '${hivemall_jar}';
+
diff --git a/resources/ddl/define-all.hive b/resources/ddl/define-all.hive
index bf9bc7c..b55c59b 100644
--- a/resources/ddl/define-all.hive
+++ b/resources/ddl/define-all.hive
@@ -345,6 +345,9 @@ create temporary function tf as 
'hivemall.ftvec.text.TermFrequencyUDAF';
 drop temporary function if exists bm25;
 create temporary function bm25 as 'hivemall.ftvec.text.OkapiBM25UDF';
 
+drop temporary function if exists tfidf;
+create temporary function tfidf as 'hivemall.ftvec.text.TfIdfUDF';
+
 --------------------------
 -- Regression functions --
 --------------------------
@@ -909,12 +912,3 @@ if(x>y,x,y);
 create temporary macro min2(x DOUBLE, y DOUBLE)
 if(x<y,x,y);
 
---------------------------
--- Statistics functions --
---------------------------
-
-create temporary macro idf(df_t DOUBLE, n_docs DOUBLE)
-log(10, n_docs / max2(1,df_t)) + 1.0;
-
-create temporary macro tfidf(tf FLOAT, df_t DOUBLE, n_docs DOUBLE)
-tf * (log(10, n_docs / max2(1,df_t)) + 1.0);
diff --git a/resources/ddl/define-all.spark b/resources/ddl/define-all.spark
index 8529134..91c6350 100644
--- a/resources/ddl/define-all.spark
+++ b/resources/ddl/define-all.spark
@@ -348,6 +348,9 @@ sqlContext.sql("CREATE TEMPORARY FUNCTION tf AS 
'hivemall.ftvec.text.TermFrequen
 sqlContext.sql("DROP TEMPORARY FUNCTION IF EXISTS bm25")
 sqlContext.sql("CREATE TEMPORARY FUNCTION bm25 AS 
'hivemall.ftvec.text.OkapiBM25UDF'")
 
+sqlContext.sql("DROP TEMPORARY FUNCTION IF EXISTS tfidf")
+sqlContext.sql("CREATE TEMPORARY FUNCTION tfidf AS 
'hivemall.ftvec.text.TfIdfUDF'")
+
 /**
  * Regression functions
  */
diff --git a/resources/ddl/define-macros.hive b/resources/ddl/define-macros.hive
index ff36a44..84c4e06 100644
--- a/resources/ddl/define-macros.hive
+++ b/resources/ddl/define-macros.hive
@@ -19,16 +19,6 @@ create temporary macro min2(x DOUBLE, y DOUBLE)
 if(x<y,x,y);
 
 --------------------------
--- Statistics functions --
---------------------------
-
-create temporary macro idf(df_t DOUBLE, n_docs DOUBLE)
-log(10, n_docs / max2(1,df_t)) + 1.0;
-
-create temporary macro tfidf(tf FLOAT, df_t DOUBLE, n_docs DOUBLE)
-tf * (log(10, n_docs / max2(1,df_t)) + 1.0);
-
---------------------------
 -- Evaluation functions --
 --------------------------
 

Reply via email to