Github user myui commented on a diff in the pull request:
https://github.com/apache/incubator-hivemall/pull/117#discussion_r141026059
--- Diff: core/src/main/java/hivemall/recommend/SlimUDTF.java ---
@@ -0,0 +1,750 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.recommend;
+
+import hivemall.UDTFWithOptions;
+import hivemall.annotations.VisibleForTesting;
+import hivemall.common.ConversionState;
+import hivemall.math.matrix.sparse.DoKFloatMatrix;
+import hivemall.math.vector.VectorProcedure;
+import hivemall.utils.collections.maps.Int2FloatOpenHashTable;
+import hivemall.utils.collections.maps.IntOpenHashTable;
+import hivemall.utils.collections.maps.IntOpenHashTable.IMapIterator;
+import hivemall.utils.hadoop.HiveUtils;
+import hivemall.utils.io.FileUtils;
+import hivemall.utils.io.NioStatefullSegment;
+import hivemall.utils.lang.NumberUtils;
+import hivemall.utils.lang.Primitives;
+import hivemall.utils.lang.SizeOf;
+import hivemall.utils.lang.mutable.MutableDouble;
+import hivemall.utils.lang.mutable.MutableInt;
+import hivemall.utils.lang.mutable.MutableObject;
+
+import java.io.File;
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+
+import javax.annotation.Nonnull;
+import javax.annotation.Nullable;
+
+import org.apache.commons.cli.CommandLine;
+import org.apache.commons.cli.Options;
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.hadoop.hive.ql.exec.Description;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
+import org.apache.hadoop.hive.ql.metadata.HiveException;
+import org.apache.hadoop.hive.serde2.objectinspector.MapObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
+import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
+import org.apache.hadoop.io.FloatWritable;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.mapred.Counters;
+import org.apache.hadoop.mapred.Reporter;
+
+/**
+ * Sparse Linear Methods (SLIM) for Top-N Recommender Systems.
+ *
+ * <pre>
+ * Xia Ning and George Karypis, SLIM: Sparse Linear Methods for Top-N Recommender Systems, Proc. ICDM, 2011.
+ * </pre>
+ */
+@Description(
+ name = "train_slim",
+ value = "_FUNC_( int i, map<int, double> r_i, map<int, map<int, double>> topKRatesOfI, int j, map<int, double> r_j [, constant string options]) "
+ + "- Returns row index, column index and non-zero weight value of prediction model")
+public class SlimUDTF extends UDTFWithOptions {
+ private static final Log logger = LogFactory.getLog(SlimUDTF.class);
+
+ //--------------------------------------------
+ // input OIs
+
+ private PrimitiveObjectInspector itemIOI;
+ private PrimitiveObjectInspector itemJOI;
+ private MapObjectInspector riOI;
+ private MapObjectInspector rjOI;
+
+ private MapObjectInspector knnItemsOI;
+ private PrimitiveObjectInspector knnItemsKeyOI;
+ private MapObjectInspector knnItemsValueOI;
+ private PrimitiveObjectInspector knnItemsValueKeyOI;
+ private PrimitiveObjectInspector knnItemsValueValueOI;
+
+ private PrimitiveObjectInspector riKeyOI;
+ private PrimitiveObjectInspector riValueOI;
+
+ private PrimitiveObjectInspector rjKeyOI;
+ private PrimitiveObjectInspector rjValueOI;
+
+ //--------------------------------------------
+ // hyperparameters
+
+ private double l1;
+ private double l2;
+ private int numIterations;
+
+ //--------------------------------------------
+ // model parameters and other state
+
+ /** item-item weight matrix */
+ private transient DoKFloatMatrix _weightMatrix;
+
+ //--------------------------------------------
+ // caching for each item i
+
+ private int _previousItemId;
+
+ @Nullable
+ private transient Int2FloatOpenHashTable _ri;
+ @Nullable
+ private transient IntOpenHashTable<Int2FloatOpenHashTable> _kNNi;
+ /** The number of elements in kNNi */
+ @Nullable
+ private transient MutableInt _nnzKNNi;
+
+ //--------------------------------------------
+ // variables for iteration support
+
+ /** item-user matrix holding the input data */
+ @Nullable
+ private transient DoKFloatMatrix _dataMatrix;
+
+ // used to store KNN data in a temporary file for iterative training
+ private transient NioStatefullSegment _fileIO;
+ private transient ByteBuffer _inputBuf;
+
+ private ConversionState _cvState;
+ private long _observedTrainingExamples;
+
+ //--------------------------------------------
+
+ public SlimUDTF() {}
+
+ @Override
+ public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException {
+ final int numArgs = argOIs.length;
+
+ if (numArgs == 1 && HiveUtils.isConstString(argOIs[0])) {// for -help option
+ String rawArgs = HiveUtils.getConstString(argOIs[0]);
+ parseOptions(rawArgs);
+ }
+
+ if (numArgs != 5 && numArgs != 6) {
+ throw new UDFArgumentException(
+ "_FUNC_ takes 5 or 6 arguments: int i, map<int, double> r_i, map<int, map<int, double>> topKRatesOfI, int j, map<int, double> r_j [, constant string options]: "
+ + Arrays.toString(argOIs));
+ }
+
+ this.itemIOI = HiveUtils.asIntCompatibleOI(argOIs[0]);
+
+ this.riOI = HiveUtils.asMapOI(argOIs[1]);
+ this.riKeyOI = HiveUtils.asIntCompatibleOI((riOI.getMapKeyObjectInspector()));
+ this.riValueOI = HiveUtils.asPrimitiveObjectInspector((riOI.getMapValueObjectInspector()));
+
+ this.knnItemsOI = HiveUtils.asMapOI(argOIs[2]);
+ this.knnItemsKeyOI = HiveUtils.asIntCompatibleOI(knnItemsOI.getMapKeyObjectInspector());
+ this.knnItemsValueOI = HiveUtils.asMapOI(knnItemsOI.getMapValueObjectInspector());
+ this.knnItemsValueKeyOI = HiveUtils.asIntCompatibleOI(knnItemsValueOI.getMapKeyObjectInspector());
+ this.knnItemsValueValueOI = HiveUtils.asDoubleCompatibleOI(knnItemsValueOI.getMapValueObjectInspector());
+
+ this.itemJOI = HiveUtils.asIntCompatibleOI(argOIs[3]);
+
+ this.rjOI = HiveUtils.asMapOI(argOIs[4]);
+ this.rjKeyOI = HiveUtils.asIntCompatibleOI((rjOI.getMapKeyObjectInspector()));
+ this.rjValueOI = HiveUtils.asPrimitiveObjectInspector((rjOI.getMapValueObjectInspector()));
+
+ processOptions(argOIs);
+
+ this._observedTrainingExamples = 0L;
+ this._previousItemId = Integer.MIN_VALUE;
+ this._weightMatrix = null;
+ this._dataMatrix = null;
+
+ List<String> fieldNames = new ArrayList<>();
+ List<ObjectInspector> fieldOIs = new ArrayList<>();
+
+ fieldNames.add("j");
+ fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);
+ fieldNames.add("nn");
+ fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);
+ fieldNames.add("w");
+ fieldOIs.add(PrimitiveObjectInspectorFactory.writableFloatObjectInspector);
+
+ return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
+ }
+
+ @Override
+ protected Options getOptions() {
+ Options opts = new Options();
+ opts.addOption("l1", "l1coefficient", true,
+ "Coefficient for l1 regularizer [default: 0.001]");
+ opts.addOption("l2", "l2coefficient", true,
+ "Coefficient for l2 regularizer [default: 0.0005]");
+ opts.addOption("iters", "iterations", true,
+ "The number of iterations for coordinate descent [default: 40]");
+ opts.addOption("disable_cv", "disable_cvtest", false,
+ "Whether to disable convergence check [default: enabled]");
+ opts.addOption("cv_rate", "convergence_rate", true,
+ "Threshold to determine convergence [default: 0.005]");
+ return opts;
+ }
+
+ @Override
+ protected CommandLine processOptions(@Nonnull ObjectInspector[] argOIs)
+ throws UDFArgumentException {
+ CommandLine cl = null;
+ double l1 = 0.001d;
+ double l2 = 0.0005d;
+ int numIterations = 40;
+ boolean conversionCheck = true;
+ double cv_rate = 0.005d;
+
+ if (argOIs.length >= 6) {
+ String rawArgs = HiveUtils.getConstString(argOIs[5]);
+ cl = parseOptions(rawArgs);
+
+ l1 = Primitives.parseDouble(cl.getOptionValue("l1"), l1);
+ if (l1 < 0.d) {
+ throw new UDFArgumentException("Argument `double l1` must be non-negative: " + l1);
+ }
+
+ l2 = Primitives.parseDouble(cl.getOptionValue("l2"), l2);
+ if (l2 < 0.d) {
+ throw new UDFArgumentException("Argument `double l2` must be non-negative: " + l2);
+ }
+
+ numIterations = Primitives.parseInt(cl.getOptionValue("iters"), numIterations);
+ if (numIterations <= 0) {
+ throw new UDFArgumentException("Argument `int iters` must be greater than 0: "
+ + numIterations);
+ }
+
+ conversionCheck = !cl.hasOption("disable_cvtest");
+
+ cv_rate = Primitives.parseDouble(cl.getOptionValue("cv_rate"), cv_rate);
+ if (cv_rate <= 0) {
+ throw new UDFArgumentException(
+ "Argument `double cv_rate` must be greater than 0.0: " + cv_rate);
+ }
+ }
+
+ this.l1 = l1;
+ this.l2 = l2;
+ this.numIterations = numIterations;
+ this._cvState = new ConversionState(conversionCheck, cv_rate);
+
+ return cl;
+ }
+
+ @Override
+ public void process(@Nonnull Object[] args) throws HiveException {
+ if (_weightMatrix == null) {// initialize variables
+ this._weightMatrix = new DoKFloatMatrix();
+ if (numIterations >= 2) {
+ this._dataMatrix = new DoKFloatMatrix();
+ }
+ this._nnzKNNi = new MutableInt();
+ }
+
+ final int itemI = PrimitiveObjectInspectorUtils.getInt(args[0], itemIOI);
+
+ if (itemI != _previousItemId || _ri == null) {
+ // cache Ri and kNNi
+ this._ri = int2floatMap(itemI, riOI.getMap(args[1]), riKeyOI, riValueOI, _dataMatrix, _ri);
+ this._kNNi = getKNNi(args[2], knnItemsOI, knnItemsKeyOI, knnItemsValueOI,
+ knnItemsValueKeyOI, knnItemsValueValueOI, _kNNi, _nnzKNNi);
+
+ final int numKNNItems = _nnzKNNi.getValue();
+ if (numIterations >= 2 && numKNNItems >= 1) {
+ recordTrainingInput(itemI, _kNNi, numKNNItems);
+ }
+ this._previousItemId = itemI;
+ }
+
+ int itemJ = PrimitiveObjectInspectorUtils.getInt(args[3], itemJOI);
+ Int2FloatOpenHashTable rj = int2floatMap(itemJ, rjOI.getMap(args[4]), rjKeyOI, rjValueOI,
+ _dataMatrix);
+
+ train(itemI, _ri, _kNNi, itemJ, rj);
+ _observedTrainingExamples++;
+ }
+
+ private void recordTrainingInput(final int itemI,
+ @Nonnull final IntOpenHashTable<Int2FloatOpenHashTable> knnItems, final int numKNNItems)
+ throws HiveException {
+ ByteBuffer buf = this._inputBuf;
+ NioStatefullSegment dst = this._fileIO;
+
+ if (buf == null) {
+ // invoked only at the task node (initialize is also invoked during query compilation)
+ final File file;
+ try {
+ file = File.createTempFile("hivemall_slim", ".sgmt"); // to save KNN data
+ file.deleteOnExit();
+ if (!file.canWrite()) {
+ throw new UDFArgumentException("Cannot write a temporary file: " + file.getAbsolutePath());
+ }
+ } catch (IOException ioe) {
+ throw new UDFArgumentException(ioe);
+ }
+
+ this._inputBuf = buf = ByteBuffer.allocateDirect(8 * 1024 * 1024); // 8MB
+ this._fileIO = dst = new NioStatefullSegment(file, false);
+ }
+
+ int recordBytes = SizeOf.INT + SizeOf.INT + SizeOf.INT * 2 * knnItems.size()
+ + (SizeOf.INT + SizeOf.FLOAT) * numKNNItems;
+ int requiredBytes = SizeOf.INT + recordBytes; // need to allocate space for "recordBytes" itself
+
+ int remain = buf.remaining();
+ if (remain < requiredBytes) {
+ writeBuffer(buf, dst);
+ }
+
+ buf.putInt(recordBytes);
+ buf.putInt(itemI);
+ buf.putInt(knnItems.size());
+
+ final IMapIterator<Int2FloatOpenHashTable> entries = knnItems.entries();
+ while (entries.next() != -1) {
+ int user = entries.getKey();
+ buf.putInt(user);
+
+ Int2FloatOpenHashTable ru = entries.getValue();
+ buf.putInt(ru.size());
+ final Int2FloatOpenHashTable.IMapIterator itor = ru.entries();
+ while (itor.next() != -1) {
+ buf.putInt(itor.getKey());
+ buf.putFloat(itor.getValue());
+ }
+ }
+ }
+
+ private static void writeBuffer(@Nonnull final ByteBuffer srcBuf,
+ @Nonnull final NioStatefullSegment dst) throws HiveException {
+ srcBuf.flip();
+ try {
+ dst.write(srcBuf);
+ } catch (IOException e) {
+ throw new HiveException("Exception caused while writing a buffer to a file", e);
+ }
+ srcBuf.clear();
+ }
+
+ private void train(final int itemI, @Nonnull final Int2FloatOpenHashTable ri,
+ @Nonnull final IntOpenHashTable<Int2FloatOpenHashTable> kNNi, final int itemJ,
+ @Nonnull final Int2FloatOpenHashTable rj) {
+ final DoKFloatMatrix W = _weightMatrix;
+
+ final int N = rj.size();
+ if (N == 0) {
+ return;
+ }
+
+ double gradSum = 0.d;
+ double rateSum = 0.d;
+ double lossSum = 0.d;
+
+ final Int2FloatOpenHashTable.IMapIterator itor = rj.entries();
+ while (itor.next() != -1) {
+ int user = itor.getKey();
+ double ruj = itor.getValue();
+ double rui = ri.get(user, 0.f);
+
+ double eui = rui - predict(user, itemI, kNNi, itemJ, W);
+ gradSum += ruj * eui;
+ rateSum += ruj * ruj;
+ lossSum += eui * eui;
+ }
+
+ gradSum /= N;
+ rateSum /= N;
+
+ double wij = W.get(itemI, itemJ, 0.d);
+ double loss = lossSum / N + 0.5d * l2 * wij * wij + l1 * wij;
+ _cvState.incrLoss(loss);
+
+ W.set(itemI, itemJ, getUpdateTerm(gradSum, rateSum, l1, l2));
+ }
+
+ private void train(final int itemI,
+ @Nonnull final IntOpenHashTable<Int2FloatOpenHashTable> knnItems, final int itemJ) {
+ final DoKFloatMatrix A = _dataMatrix;
+ final DoKFloatMatrix W = _weightMatrix;
+
+ final int N = A.numColumns(itemJ);
+ if (N == 0) {
+ return;
+ }
+
+ final MutableDouble mutableGradSum = new MutableDouble(0.d);
+ final MutableDouble mutableRateSum = new MutableDouble(0.d);
+ final MutableDouble mutableLossSum = new MutableDouble(0.d);
+
+ A.eachNonZeroInRow(itemJ, new VectorProcedure() {
--- End diff ---
DoKMatrix A should be converted to a CSRMatrix for this row-wise access pattern; a sketch of the idea follows.
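For illustration, here is a minimal, self-contained sketch of the DoK-to-CSR conversion I have in mind. The class, field, and method names below are mine for this example only, not Hivemall's actual CSRMatrix API:

```java
import java.util.Map;
import java.util.TreeMap;

/**
 * Sketch: a DoK (dictionary-of-keys) matrix is fast for random writes while
 * the data is loaded, but a row-wise scan such as eachNonZeroInRow() then
 * pays a hash lookup per element. CSR stores each row's non-zeros
 * contiguously, so a row scan becomes a plain array walk.
 */
public final class DokToCsrSketch {

    public interface RowVisitor {
        void visit(int col, float value);
    }

    private final int[] rowPtr;   // rowPtr[i]..rowPtr[i+1] spans row i
    private final int[] colInd;   // column index of each non-zero
    private final float[] values; // value of each non-zero

    public DokToCsrSketch(Map<Integer, Map<Integer, Float>> dok, int numRows) {
        int nnz = 0;
        for (Map<Integer, Float> row : dok.values()) {
            nnz += row.size();
        }
        this.rowPtr = new int[numRows + 1];
        this.colInd = new int[nnz];
        this.values = new float[nnz];
        int pos = 0;
        for (int i = 0; i < numRows; i++) {
            rowPtr[i] = pos;
            final Map<Integer, Float> row = dok.get(i);
            if (row == null) {
                continue; // empty row
            }
            // TreeMap keeps the column indices of this row sorted
            for (Map.Entry<Integer, Float> e : new TreeMap<>(row).entrySet()) {
                colInd[pos] = e.getKey().intValue();
                values[pos] = e.getValue().floatValue();
                pos++;
            }
        }
        rowPtr[numRows] = pos;
    }

    /** Row-wise scan: contiguous memory, no per-element hashing. */
    public void eachNonZeroInRow(int i, RowVisitor visitor) {
        for (int k = rowPtr[i]; k < rowPtr[i + 1]; k++) {
            visitor.visit(colInd[k], values[k]);
        }
    }
}
```

Since _dataMatrix is only written while the input is being consumed, the conversion cost would be paid once before the iterative phase, and every later row scan in the inner training loop benefits.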
---