Github user myui commented on a diff in the pull request:
https://github.com/apache/incubator-hivemall/pull/116#discussion_r141543986
--- Diff: core/src/main/java/hivemall/embedding/Word2VecUDTF.java ---
@@ -0,0 +1,364 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.embedding;
+
+import hivemall.UDTFWithOptions;
+import hivemall.utils.collections.IMapIterator;
+import hivemall.utils.collections.maps.Int2FloatOpenHashTable;
+import hivemall.utils.collections.maps.OpenHashTable;
+import hivemall.utils.hadoop.HiveUtils;
+import hivemall.utils.lang.Primitives;
+
+import org.apache.commons.cli.CommandLine;
+import org.apache.commons.cli.Options;
+import org.apache.hadoop.hive.ql.exec.Description;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
+import org.apache.hadoop.hive.ql.metadata.HiveException;
+import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
+import org.apache.hadoop.io.FloatWritable;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.Text;
+
+import javax.annotation.Nonnegative;
+import javax.annotation.Nonnull;
+import java.util.List;
+import java.util.Arrays;
+import java.util.ArrayList;
+
+@Description(
+ name = "train_word2vec",
+ value = "_FUNC_(array<array<float | string>> negative_table,
array<int | string> doc [, const string options]) - Returns a prediction model")
+public class Word2VecUDTF extends UDTFWithOptions {
+ protected transient AbstractWord2VecModel model;
+ @Nonnegative
+ private float startingLR;
+ @Nonnegative
+ private long numTrainWords;
+ private OpenHashTable<String, Integer> word2index;
+
+ @Nonnegative
+ private int dim;
+ @Nonnegative
+ private int win;
+ @Nonnegative
+ private int neg;
+ @Nonnegative
+ private int iter;
+ private boolean skipgram;
+ private boolean isStringInput;
+
+ private Int2FloatOpenHashTable S;
+ private int[] aliasWordIds;
+
+ private ListObjectInspector negativeTableOI;
+ private ListObjectInspector negativeTableElementListOI;
+ private PrimitiveObjectInspector negativeTableElementOI;
+
+ private ListObjectInspector docOI;
+ private PrimitiveObjectInspector wordOI;
+
+ @Override
+ public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException {
+ final int numArgs = argOIs.length;
+
+ if (numArgs != 3) {
+ throw new UDFArgumentException(getClass().getSimpleName()
+ + " takes 3 arguments: [, constant string options]: "
+ + Arrays.toString(argOIs));
+ }
+
+ processOptions(argOIs);
+
+ this.negativeTableOI = HiveUtils.asListOI(argOIs[0]);
+ this.negativeTableElementListOI = HiveUtils.asListOI(negativeTableOI.getListElementObjectInspector());
+ this.docOI = HiveUtils.asListOI(argOIs[1]);
+
+ this.isStringInput = HiveUtils.isStringListOI(argOIs[1]);
+
+ if (isStringInput) {
+ this.negativeTableElementOI = HiveUtils.asStringOI(negativeTableElementListOI.getListElementObjectInspector());
+ this.wordOI = HiveUtils.asStringOI(docOI.getListElementObjectInspector());
+ } else {
+ this.negativeTableElementOI = HiveUtils.asFloatingPointOI(negativeTableElementListOI.getListElementObjectInspector());
+ this.wordOI = HiveUtils.asIntCompatibleOI(docOI.getListElementObjectInspector());
+ }
+
+ List<String> fieldNames = new ArrayList<>();
+ List<ObjectInspector> fieldOIs = new ArrayList<>();
+
+ fieldNames.add("word");
+ fieldNames.add("i");
+ fieldNames.add("wi");
+
+ if (isStringInput) {
+ fieldOIs.add(PrimitiveObjectInspectorFactory.writableStringObjectInspector);
+ } else {
+ fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);
+ }
+
+ fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);
+ fieldOIs.add(PrimitiveObjectInspectorFactory.writableFloatObjectInspector);
+
+ this.model = null;
+ this.word2index = null;
+ this.S = null;
+ this.aliasWordIds = null;
+
+ return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
+ }
+
+ @Override
+ public void process(Object[] args) throws HiveException {
+ if (model == null) {
+ parseNegativeTable(args[0]);
+ this.model = createModel();
+ }
+
+ List<?> rawDoc = docOI.getList(args[1]);
+
+ // parse rawDoc
+ final int docLength = rawDoc.size();
+ final int[] doc = new int[docLength];
+ if (isStringInput) {
+ for (int i = 0; i < docLength; i++) {
+ doc[i] = getWordId(PrimitiveObjectInspectorUtils.getString(rawDoc.get(i), wordOI));
+ }
+ } else {
+ for (int i = 0; i < docLength; i++) {
+ doc[i] = PrimitiveObjectInspectorUtils.getInt(rawDoc.get(i), wordOI);
--- End diff ---
`rawDoc.get(i)` may return null.
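For example, a null check before the conversion would fail fast with a clearer error (a minimal sketch only; `rawDoc`, `docLength`, `doc`, and `wordOI` are the locals from the diff above, and whether to throw or skip the element is up to the author):

```java
// Sketch: reject null elements before the primitive conversion,
// instead of letting getInt() hit a NullPointerException.
for (int i = 0; i < docLength; i++) {
    final Object word = rawDoc.get(i);
    if (word == null) {
        throw new HiveException("doc must not contain null: position " + i);
    }
    doc[i] = PrimitiveObjectInspectorUtils.getInt(word, wordOI);
}
```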
---