GitHub user myui commented on a diff in the pull request:
https://github.com/apache/incubator-hivemall/pull/111#discussion_r136904245
--- Diff: core/src/main/java/hivemall/recommend/SlimUDTF.java ---
@@ -0,0 +1,625 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.recommend;
+
+
+import hivemall.UDTFWithOptions;
+import hivemall.annotations.VisibleForTesting;
+import hivemall.common.ConversionState;
+import hivemall.math.matrix.sparse.DoKMatrix;
+import hivemall.math.vector.VectorProcedure;
+import hivemall.utils.collections.maps.Int2FloatOpenHashTable;
+import hivemall.utils.hadoop.HiveUtils;
+import hivemall.utils.io.FileUtils;
+import hivemall.utils.io.NioStatefullSegment;
+import hivemall.utils.lang.NumberUtils;
+import hivemall.utils.lang.Primitives;
+import hivemall.utils.lang.SizeOf;
+import hivemall.utils.lang.mutable.MutableDouble;
+import org.apache.commons.cli.CommandLine;
+import org.apache.commons.cli.Options;
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
+import org.apache.hadoop.hive.ql.metadata.HiveException;
+import org.apache.hadoop.hive.serde2.objectinspector.*;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.*;
+import org.apache.hadoop.io.DoubleWritable;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.mapred.Counters;
+import org.apache.hadoop.mapred.Reporter;
+
+import javax.annotation.Nonnull;
+import java.io.File;
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.util.*;
+
+
+public class SlimUDTF extends UDTFWithOptions {
private static final Log logger = LogFactory.getLog(SlimUDTF.class);

// ---- hyper-parameters (defaults and validation live in processOptions) ----
private double l1;
private double l2;
private int numIterations;

// ---- model / data state (reset to null in initialize) ----
private transient DoKMatrix weightMatrix; // item-item weight matrix
private transient DoKMatrix dataMatrix; // item-user matrix

// ---- per-row training state (reset in initialize) ----
private int previousItemId;
private ConversionState cvState;
private long observedTrainingExamples;

// ---- argument 0 (int i) and argument 1 (map<int, double> r_i) ----
private PrimitiveObjectInspector itemIOI;
private MapObjectInspector riOI;
private PrimitiveObjectInspector riKeyOI;
private PrimitiveObjectInspector riValueOI;

// ---- argument 2 (map<int, map<int, double>> topKRatesOfI) ----
private MapObjectInspector knnItemsOI;
private PrimitiveObjectInspector knnItemsKeyOI;
private MapObjectInspector knnItemsValueOI;
private PrimitiveObjectInspector knnItemsValueKeyOI;
private PrimitiveObjectInspector knnItemsValueValueOI;

// ---- argument 3 (int j) and argument 4 (map<int, double> r_j) ----
private PrimitiveObjectInspector itemJOI;
private MapObjectInspector rjOI;
private PrimitiveObjectInspector rjKeyOI;
private PrimitiveObjectInspector rjValueOI;

// ---- spill-to-disk buffer: keeps KNN data in a temporary file for iterative training ----
private NioStatefullSegment fileIO;
private ByteBuffer inputBuf;

/**
 * No-arg constructor; all per-query state is established in {@link #initialize}.
 */
public SlimUDTF() {}
+
+ @Override
+ public StructObjectInspector initialize(ObjectInspector[] argOIs)
throws UDFArgumentException {
+ final int numArgs = argOIs.length;
+ if (numArgs != 5 && numArgs != 6) {
+ throw new UDFArgumentException(
+ "_FUNC_ takes arguments: int i, map<int, double> r_i,
map<int, map<int, double>> topKRatesOfI, int j, map<int, double> r_j, [,
constant string options]");
+ }
+
+ this.itemIOI = HiveUtils.asIntCompatibleOI(argOIs[0]);
+
+ this.riOI = HiveUtils.asMapOI(argOIs[1]);
+ this.riKeyOI =
HiveUtils.asIntCompatibleOI((this.riOI.getMapKeyObjectInspector()));
+ this.riValueOI =
HiveUtils.asPrimitiveObjectInspector((this.riOI.getMapValueObjectInspector()));
+
+ this.knnItemsOI = HiveUtils.asMapOI(argOIs[2]);
+ this.knnItemsKeyOI =
HiveUtils.asIntCompatibleOI(knnItemsOI.getMapKeyObjectInspector());
+ this.knnItemsValueOI =
HiveUtils.asMapOI(knnItemsOI.getMapValueObjectInspector());
+ this.knnItemsValueKeyOI =
HiveUtils.asIntCompatibleOI(knnItemsValueOI.getMapKeyObjectInspector());
+ this.knnItemsValueValueOI =
HiveUtils.asDoubleCompatibleOI(knnItemsValueOI.getMapValueObjectInspector());
+
+ this.itemJOI = HiveUtils.asIntCompatibleOI(argOIs[3]);
+
+ this.rjOI = HiveUtils.asMapOI(argOIs[4]);
+ this.rjKeyOI =
HiveUtils.asIntCompatibleOI((this.rjOI.getMapKeyObjectInspector()));
+ this.rjValueOI =
HiveUtils.asPrimitiveObjectInspector((this.rjOI.getMapValueObjectInspector()));
+
+ processOptions(argOIs);
+
+ List<String> fieldNames = new ArrayList<>();
+ List<ObjectInspector> fieldOIs = new ArrayList<>();
+
+ fieldNames.add("i");
+ fieldNames.add("j");
+ fieldNames.add("wij");
+
+
fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);
+
fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);
+
fieldOIs.add(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector);
+
+ this.observedTrainingExamples = 0L;
+ this.previousItemId = -2147483648;
+
+ this.dataMatrix = null;
+ this.weightMatrix = null;
+
+ return
ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
+ }
+
+ @Override
+ protected Options getOptions() {
+ Options opts = new Options();
+ opts.addOption("l1", "l1coefficient", true,
+ "Coefficient for l1 regularizer [default: 0.01]");
+ opts.addOption("l2", "l2coefficient", true,
+ "Coefficient for l2 regularizer [default: 0.01]");
+ opts.addOption("numIterations", "iteration", true,
+ "The number of iterations for coordinate descent [default:
40]");
+ opts.addOption("disable_cv", "disable_cvtest", false,
+ "Whether to disable convergence check [default: enabled]");
+ opts.addOption("cv_rate", "convergence_rate", true,
+ "Threshold to determine convergence [default: 0.005]");
+ return opts;
+ }
+
+ @Override
+ protected CommandLine processOptions(ObjectInspector[] argOIs) throws
UDFArgumentException {
+ CommandLine cl = null;
+ double l1 = 0.01d;
+ double l2 = 0.01d;
+ int numIterations = 40;
+ boolean conversionCheck = true;
+ double cv_rate = 0.005d;
+
+ if (argOIs.length >= 6) {
+ String rawArgs = HiveUtils.getConstString(argOIs[5]);
+ cl = parseOptions(rawArgs);
+
+ l1 = Primitives.parseDouble(cl.getOptionValue("l1"), l1);
+ if (l1 < 0.d || l1 > 1.d) {
+ throw new UDFArgumentException("Argument `double l1` must
be within [0., 1.]: "
+ + l1);
+ }
+
+ l2 = Primitives.parseDouble(cl.getOptionValue("l2"), l2);
+ if (l2 < 0.d || l2 > 1.d) {
--- End diff --
`if (l2 < 0.d) {` is enough; the upper-bound check `l2 > 1.d` is unnecessary.
---