Github user myui commented on a diff in the pull request:
https://github.com/apache/incubator-hivemall/pull/116#discussion_r141547506
--- Diff: core/src/main/java/hivemall/embedding/Word2VecUDTF.java ---
@@ -0,0 +1,364 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.embedding;
+
+import hivemall.UDTFWithOptions;
+import hivemall.utils.collections.IMapIterator;
+import hivemall.utils.collections.maps.Int2FloatOpenHashTable;
+import hivemall.utils.collections.maps.OpenHashTable;
+import hivemall.utils.hadoop.HiveUtils;
+import hivemall.utils.lang.Primitives;
+
+import org.apache.commons.cli.CommandLine;
+import org.apache.commons.cli.Options;
+import org.apache.hadoop.hive.ql.exec.Description;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
+import org.apache.hadoop.hive.ql.metadata.HiveException;
+import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
+import org.apache.hadoop.io.FloatWritable;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.Text;
+
+import javax.annotation.Nonnegative;
+import javax.annotation.Nonnull;
+import java.util.List;
+import java.util.Arrays;
+import java.util.ArrayList;
+
+@Description(
+ name = "train_word2vec",
+ value = "_FUNC_(array<array<float | string>> negative_table,
array<int | string> doc [, const string options]) - Returns a prediction model")
+public class Word2VecUDTF extends UDTFWithOptions {
+ protected transient AbstractWord2VecModel model;
+ @Nonnegative
+ private float startingLR;
+ @Nonnegative
+ private long numTrainWords;
+ private OpenHashTable<String, Integer> word2index;
+
+ @Nonnegative
+ private int dim;
+ @Nonnegative
+ private int win;
+ @Nonnegative
+ private int neg;
+ @Nonnegative
+ private int iter;
+ private boolean skipgram;
+ private boolean isStringInput;
+
+ private Int2FloatOpenHashTable S;
+ private int[] aliasWordIds;
+
+ private ListObjectInspector negativeTableOI;
+ private ListObjectInspector negativeTableElementListOI;
+ private PrimitiveObjectInspector negativeTableElementOI;
+
+ private ListObjectInspector docOI;
+ private PrimitiveObjectInspector wordOI;
+
+ @Override
+ public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException {
+ final int numArgs = argOIs.length;
+
+ if (numArgs != 3) {
+ throw new UDFArgumentException(getClass().getSimpleName()
+ + " takes 3 arguments: [, constant string options]: "
+ + Arrays.toString(argOIs));
+ }
+
+ processOptions(argOIs);
+
+ this.negativeTableOI = HiveUtils.asListOI(argOIs[0]);
+ this.negativeTableElementListOI = HiveUtils.asListOI(negativeTableOI.getListElementObjectInspector());
+ this.docOI = HiveUtils.asListOI(argOIs[1]);
+
+ this.isStringInput = HiveUtils.isStringListOI(argOIs[1]);
+
+ if (isStringInput) {
+ this.negativeTableElementOI = HiveUtils.asStringOI(negativeTableElementListOI.getListElementObjectInspector());
+ this.wordOI = HiveUtils.asStringOI(docOI.getListElementObjectInspector());
+ } else {
+ this.negativeTableElementOI = HiveUtils.asFloatingPointOI(negativeTableElementListOI.getListElementObjectInspector());
+ this.wordOI = HiveUtils.asIntCompatibleOI(docOI.getListElementObjectInspector());
+ }
+
+ List<String> fieldNames = new ArrayList<>();
+ List<ObjectInspector> fieldOIs = new ArrayList<>();
+
+ fieldNames.add("word");
+ fieldNames.add("i");
+ fieldNames.add("wi");
+
+ if (isStringInput) {
+ fieldOIs.add(PrimitiveObjectInspectorFactory.writableStringObjectInspector);
+ } else {
+ fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);
+ }
+
+ fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);
+ fieldOIs.add(PrimitiveObjectInspectorFactory.writableFloatObjectInspector);
+
+ this.model = null;
+ this.word2index = null;
+ this.S = null;
+ this.aliasWordIds = null;
+
+ return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
+ }
+
+ @Override
+ public void process(Object[] args) throws HiveException {
+ if (model == null) {
+ parseNegativeTable(args[0]);
+ this.model = createModel();
+ }
+
+ List<?> rawDoc = docOI.getList(args[1]);
+
+ // parse rawDoc
+ final int docLength = rawDoc.size();
+ final int[] doc = new int[docLength];
+ if (isStringInput) {
+ for (int i = 0; i < docLength; i++) {
+ doc[i] = getWordId(PrimitiveObjectInspectorUtils.getString(rawDoc.get(i), wordOI));
+ }
+ } else {
+ for (int i = 0; i < docLength; i++) {
+ doc[i] = PrimitiveObjectInspectorUtils.getInt(rawDoc.get(i), wordOI);
+ }
+ }
+
+ model.trainOnDoc(doc);
+ }
+
+ @Override
+ protected Options getOptions() {
+ Options opts = new Options();
+ opts.addOption("n", "numTrainWords", true,
+ "The number of words in the documents. It is used to update
learning rate");
+ opts.addOption("dim", "dimension", true, "The number of vector
dimension [default: 100]");
+ opts.addOption("win", "window", true, "Context window size
[default: 5]");
+ opts.addOption("neg", "negative", true,
+ "The number of negative sampled words per word [default: 5]");
+ opts.addOption("iter", "iteration", true, "The number of
iterations [default: 5]");
+ opts.addOption("model", "modelName", true,
+ "The model name of word2vec: skipgram or cbow [default:
skipgram]");
+ opts.addOption(
+ "lr",
--- End diff --
Please use consistent naming for the initial learning rate, e.g. `eta0` or `learningRate`, rather than mixing `startingLR` and `lr`.
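For illustration, a minimal hypothetical sketch of the suggested naming applied to the commons-cli option (the description text and the `0.025` default below are assumptions for illustration, not taken from this diff):

```java
import org.apache.commons.cli.Options;

// Hypothetical sketch only: shows the suggested eta0 / learningRate naming.
// The description text and the default value are assumptions, not part of this PR.
final class LearningRateOptionSketch {
    static Options buildOptions() {
        final Options opts = new Options();
        // short name "eta0", long name "learningRate", expects a value
        opts.addOption("eta0", "learningRate", true,
            "The initial learning rate [default: 0.025]");
        return opts;
    }
}
```

The `startingLR` field could then be renamed to match, e.g. to `eta0`.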
---