Github user myui commented on a diff in the pull request:

    https://github.com/apache/incubator-hivemall/pull/117#discussion_r141026059
  
    --- Diff: core/src/main/java/hivemall/recommend/SlimUDTF.java ---
    @@ -0,0 +1,750 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one
    + * or more contributor license agreements.  See the NOTICE file
    + * distributed with this work for additional information
    + * regarding copyright ownership.  The ASF licenses this file
    + * to you under the Apache License, Version 2.0 (the
    + * "License"); you may not use this file except in compliance
    + * with the License.  You may obtain a copy of the License at
    + *
    + *   http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing,
    + * software distributed under the License is distributed on an
    + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
    + * KIND, either express or implied.  See the License for the
    + * specific language governing permissions and limitations
    + * under the License.
    + */
    +package hivemall.recommend;
    +
    +import hivemall.UDTFWithOptions;
    +import hivemall.annotations.VisibleForTesting;
    +import hivemall.common.ConversionState;
    +import hivemall.math.matrix.sparse.DoKFloatMatrix;
    +import hivemall.math.vector.VectorProcedure;
    +import hivemall.utils.collections.maps.Int2FloatOpenHashTable;
    +import hivemall.utils.collections.maps.IntOpenHashTable;
    +import hivemall.utils.collections.maps.IntOpenHashTable.IMapIterator;
    +import hivemall.utils.hadoop.HiveUtils;
    +import hivemall.utils.io.FileUtils;
    +import hivemall.utils.io.NioStatefullSegment;
    +import hivemall.utils.lang.NumberUtils;
    +import hivemall.utils.lang.Primitives;
    +import hivemall.utils.lang.SizeOf;
    +import hivemall.utils.lang.mutable.MutableDouble;
    +import hivemall.utils.lang.mutable.MutableInt;
    +import hivemall.utils.lang.mutable.MutableObject;
    +
    +import java.io.File;
    +import java.io.IOException;
    +import java.nio.ByteBuffer;
    +import java.util.ArrayList;
    +import java.util.Arrays;
    +import java.util.HashSet;
    +import java.util.List;
    +import java.util.Map;
    +import java.util.Set;
    +
    +import javax.annotation.Nonnull;
    +import javax.annotation.Nullable;
    +
    +import org.apache.commons.cli.CommandLine;
    +import org.apache.commons.cli.Options;
    +import org.apache.commons.logging.Log;
    +import org.apache.commons.logging.LogFactory;
    +import org.apache.hadoop.hive.ql.exec.Description;
    +import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
    +import org.apache.hadoop.hive.ql.metadata.HiveException;
    +import org.apache.hadoop.hive.serde2.objectinspector.MapObjectInspector;
    +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
    +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
    +import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
    +import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
    +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
    +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
    +import org.apache.hadoop.io.FloatWritable;
    +import org.apache.hadoop.io.IntWritable;
    +import org.apache.hadoop.mapred.Counters;
    +import org.apache.hadoop.mapred.Reporter;
    +
    +/**
    + * Sparse Linear Methods (SLIM) for Top-N Recommender Systems.
    + *
    + * <pre>
    + * Xia Ning and George Karypis, SLIM: Sparse Linear Methods for Top-N Recommender Systems, Proc. ICDM, 2011.
    + * </pre>
    + */
    +@Description(
    +        name = "train_slim",
    +        value = "_FUNC_( int i, map<int, double> r_i, map<int, map<int, 
double>> topKRatesOfI, int j, map<int, double> r_j [, constant string options]) 
"
    +                + "- Returns row index, column index and non-zero weight 
value of prediction model")
    +public class SlimUDTF extends UDTFWithOptions {
    +    private static final Log logger = LogFactory.getLog(SlimUDTF.class);
    +
    +    //--------------------------------------------
    +    // input OIs
    +
    +    private PrimitiveObjectInspector itemIOI;
    +    private PrimitiveObjectInspector itemJOI;
    +    private MapObjectInspector riOI;
    +    private MapObjectInspector rjOI;
    +
    +    private MapObjectInspector knnItemsOI;
    +    private PrimitiveObjectInspector knnItemsKeyOI;
    +    private MapObjectInspector knnItemsValueOI;
    +    private PrimitiveObjectInspector knnItemsValueKeyOI;
    +    private PrimitiveObjectInspector knnItemsValueValueOI;
    +
    +    private PrimitiveObjectInspector riKeyOI;
    +    private PrimitiveObjectInspector riValueOI;
    +
    +    private PrimitiveObjectInspector rjKeyOI;
    +    private PrimitiveObjectInspector rjValueOI;
    +
    +    //--------------------------------------------
    +    // hyperparameters
    +
    +    private double l1;
    +    private double l2;
    +    private int numIterations;
    +
    +    //--------------------------------------------
    +    // model parameters and other state
    +
    +    /** item-item weight matrix */
    +    private transient DoKFloatMatrix _weightMatrix;
    +
    +    //--------------------------------------------
    +    // caching for each item i
    +
    +    private int _previousItemId;
    +
    +    @Nullable
    +    private transient Int2FloatOpenHashTable _ri;
    +    @Nullable
    +    private transient IntOpenHashTable<Int2FloatOpenHashTable> _kNNi;
    +    /** The number of elements in kNNi */
    +    @Nullable
    +    private transient MutableInt _nnzKNNi;
    +
    +    //--------------------------------------------
    +    // variables for iteration support
    +
    +    /** item-user matrix holding the input data */
    +    @Nullable
    +    private transient DoKFloatMatrix _dataMatrix;
    +
    +    // used to store KNN data into temporary file for iterative training
    +    private transient NioStatefullSegment _fileIO;
    +    private transient ByteBuffer _inputBuf;
    +
    +    private ConversionState _cvState;
    +    private long _observedTrainingExamples;
    +
    +    //--------------------------------------------
    +
    +    public SlimUDTF() {}
    +
    +    @Override
    +    public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException {
    +        final int numArgs = argOIs.length;
    +
    +        if (numArgs == 1 && HiveUtils.isConstString(argOIs[0])) {// for -help option
    +            String rawArgs = HiveUtils.getConstString(argOIs[0]);
    +            parseOptions(rawArgs);
    +        }
    +
    +        if (numArgs != 5 && numArgs != 6) {
    +            throw new UDFArgumentException(
    +                "_FUNC_ takes 5 or 6 arguments: int i, map<int, double> 
r_i, map<int, map<int, double>> topKRatesOfI, int j, map<int, double> r_j [, 
constant string options]: "
    +                        + Arrays.toString(argOIs));
    +        }
    +
    +        this.itemIOI = HiveUtils.asIntCompatibleOI(argOIs[0]);
    +
    +        this.riOI = HiveUtils.asMapOI(argOIs[1]);
    +        this.riKeyOI = HiveUtils.asIntCompatibleOI((riOI.getMapKeyObjectInspector()));
    +        this.riValueOI = HiveUtils.asPrimitiveObjectInspector((riOI.getMapValueObjectInspector()));
    +
    +        this.knnItemsOI = HiveUtils.asMapOI(argOIs[2]);
    +        this.knnItemsKeyOI = HiveUtils.asIntCompatibleOI(knnItemsOI.getMapKeyObjectInspector());
    +        this.knnItemsValueOI = HiveUtils.asMapOI(knnItemsOI.getMapValueObjectInspector());
    +        this.knnItemsValueKeyOI = HiveUtils.asIntCompatibleOI(knnItemsValueOI.getMapKeyObjectInspector());
    +        this.knnItemsValueValueOI = HiveUtils.asDoubleCompatibleOI(knnItemsValueOI.getMapValueObjectInspector());
    +
    +        this.itemJOI = HiveUtils.asIntCompatibleOI(argOIs[3]);
    +
    +        this.rjOI = HiveUtils.asMapOI(argOIs[4]);
    +        this.rjKeyOI = HiveUtils.asIntCompatibleOI((rjOI.getMapKeyObjectInspector()));
    +        this.rjValueOI = HiveUtils.asPrimitiveObjectInspector((rjOI.getMapValueObjectInspector()));
    +
    +        processOptions(argOIs);
    +
    +        this._observedTrainingExamples = 0L;
    +        this._previousItemId = Integer.MIN_VALUE;
    +        this._weightMatrix = null;
    +        this._dataMatrix = null;
    +
    +        List<String> fieldNames = new ArrayList<>();
    +        List<ObjectInspector> fieldOIs = new ArrayList<>();
    +
    +        fieldNames.add("j");
    +        fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);
    +        fieldNames.add("nn");
    +        fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);
    +        fieldNames.add("w");
    +        fieldOIs.add(PrimitiveObjectInspectorFactory.writableFloatObjectInspector);
    +
    +        return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
    +    }
    +
    +    @Override
    +    protected Options getOptions() {
    +        Options opts = new Options();
    +        opts.addOption("l1", "l1coefficient", true,
    +            "Coefficient for l1 regularizer [default: 0.001]");
    +        opts.addOption("l2", "l2coefficient", true,
    +            "Coefficient for l2 regularizer [default: 0.0005]");
    +        opts.addOption("iters", "iterations", true,
    +            "The number of iterations for coordinate descent [default: 
40]");
    +        opts.addOption("disable_cv", "disable_cvtest", false,
    +            "Whether to disable convergence check [default: enabled]");
    +        opts.addOption("cv_rate", "convergence_rate", true,
    +            "Threshold to determine convergence [default: 0.005]");
    +        return opts;
    +    }
    +
    +    @Override
    +    protected CommandLine processOptions(@Nonnull ObjectInspector[] argOIs)
    +            throws UDFArgumentException {
    +        CommandLine cl = null;
    +        double l1 = 0.001d;
    +        double l2 = 0.0005d;
    +        int numIterations = 40;
    +        boolean conversionCheck = true;
    +        double cv_rate = 0.005d;
    +
    +        if (argOIs.length >= 6) {
    +            String rawArgs = HiveUtils.getConstString(argOIs[5]);
    +            cl = parseOptions(rawArgs);
    +
    +            l1 = Primitives.parseDouble(cl.getOptionValue("l1"), l1);
    +            if (l1 < 0.d) {
    +                throw new UDFArgumentException("Argument `double l1` must 
be non-negative: " + l1);
    +            }
    +
    +            l2 = Primitives.parseDouble(cl.getOptionValue("l2"), l2);
    +            if (l2 < 0.d) {
    +                throw new UDFArgumentException("Argument `double l2` must 
be non-negative: " + l2);
    +            }
    +
    +            numIterations = Primitives.parseInt(cl.getOptionValue("iters"), numIterations);
    +            if (numIterations <= 0) {
    +                throw new UDFArgumentException("Argument `int iters` must 
be greater than 0: "
    +                        + numIterations);
    +            }
    +
    +            conversionCheck = !cl.hasOption("disable_cvtest");
    +
    +            cv_rate = Primitives.parseDouble(cl.getOptionValue("cv_rate"), cv_rate);
    +            if (cv_rate <= 0) {
    +                throw new UDFArgumentException(
    +                    "Argument `double cv_rate` must be greater than 0.0: " 
+ cv_rate);
    +            }
    +        }
    +
    +        this.l1 = l1;
    +        this.l2 = l2;
    +        this.numIterations = numIterations;
    +        this._cvState = new ConversionState(conversionCheck, cv_rate);
    +
    +        return cl;
    +    }
    +
    +    @Override
    +    public void process(@Nonnull Object[] args) throws HiveException {
    +        if (_weightMatrix == null) {// initialize variables
    +            this._weightMatrix = new DoKFloatMatrix();
    +            if (numIterations >= 2) {
    +                this._dataMatrix = new DoKFloatMatrix();
    +            }
    +            this._nnzKNNi = new MutableInt();
    +        }
    +
    +        final int itemI = PrimitiveObjectInspectorUtils.getInt(args[0], itemIOI);
    +
    +        if (itemI != _previousItemId || _ri == null) {
    +            // cache Ri and kNNi
    +            this._ri = int2floatMap(itemI, riOI.getMap(args[1]), riKeyOI, riValueOI, _dataMatrix,
    +                _ri);
    +            this._kNNi = getKNNi(args[2], knnItemsOI, knnItemsKeyOI, knnItemsValueOI,
    +                knnItemsValueKeyOI, knnItemsValueValueOI, _kNNi, _nnzKNNi);
    +
    +            final int numKNNItems = _nnzKNNi.getValue();
    +            if (numIterations >= 2 && numKNNItems >= 1) {
    +                recordTrainingInput(itemI, _kNNi, numKNNItems);
    +            }
    +            this._previousItemId = itemI;
    +        }
    +
    +        int itemJ = PrimitiveObjectInspectorUtils.getInt(args[3], itemJOI);
    +        Int2FloatOpenHashTable rj = int2floatMap(itemJ, rjOI.getMap(args[4]), rjKeyOI, rjValueOI,
    +            _dataMatrix);
    +
    +        train(itemI, _ri, _kNNi, itemJ, rj);
    +        _observedTrainingExamples++;
    +    }
    +
    +    private void recordTrainingInput(final int itemI,
    +            @Nonnull final IntOpenHashTable<Int2FloatOpenHashTable> knnItems, final int numKNNItems)
    +            throws HiveException {
    +        ByteBuffer buf = this._inputBuf;
    +        NioStatefullSegment dst = this._fileIO;
    +
    +        if (buf == null) {
    +            // invoked only at the task node (initialize is also invoked at compilation)
    +            final File file;
    +            try {
    +                file = File.createTempFile("hivemall_slim", ".sgmt"); // to save KNN data
    +                file.deleteOnExit();
    +                if (!file.canWrite()) {
    +                    throw new UDFArgumentException("Cannot write a 
temporary file: "
    +                            + file.getAbsolutePath());
    +                }
    +            } catch (IOException ioe) {
    +                throw new UDFArgumentException(ioe);
    +            }
    +
    +            this._inputBuf = buf = ByteBuffer.allocateDirect(8 * 1024 * 1024); // 8MB
    +            this._fileIO = dst = new NioStatefullSegment(file, false);
    +        }
    +
    +        int recordBytes = SizeOf.INT + SizeOf.INT + SizeOf.INT * 2 * knnItems.size()
    +                + (SizeOf.INT + SizeOf.FLOAT) * numKNNItems;
    +        int requiredBytes = SizeOf.INT + recordBytes; // need to allocate space for "recordBytes" itself
    +
    +        int remain = buf.remaining();
    +        if (remain < requiredBytes) {
    +            writeBuffer(buf, dst);
    +        }
    +
    +        buf.putInt(recordBytes);
    +        buf.putInt(itemI);
    +        buf.putInt(knnItems.size());
    +
    +        final IMapIterator<Int2FloatOpenHashTable> entries = knnItems.entries();
    +        while (entries.next() != -1) {
    +            int user = entries.getKey();
    +            buf.putInt(user);
    +
    +            Int2FloatOpenHashTable ru = entries.getValue();
    +            buf.putInt(ru.size());
    +            final Int2FloatOpenHashTable.IMapIterator itor = ru.entries();
    +            while (itor.next() != -1) {
    +                buf.putInt(itor.getKey());
    +                buf.putFloat(itor.getValue());
    +            }
    +        }
    +    }
    +
    +    private static void writeBuffer(@Nonnull final ByteBuffer srcBuf,
    +            @Nonnull final NioStatefullSegment dst) throws HiveException {
    +        srcBuf.flip();
    +        try {
    +            dst.write(srcBuf);
    +        } catch (IOException e) {
    +            throw new HiveException("Exception causes while writing a 
buffer to file", e);
    +        }
    +        srcBuf.clear();
    +    }
    +
    +    private void train(final int itemI, @Nonnull final Int2FloatOpenHashTable ri,
    +            @Nonnull final IntOpenHashTable<Int2FloatOpenHashTable> kNNi, final int itemJ,
    +            @Nonnull final Int2FloatOpenHashTable rj) {
    +        final DoKFloatMatrix W = _weightMatrix;
    +
    +        final int N = rj.size();
    +        if (N == 0) {
    +            return;
    +        }
    +
    +        double gradSum = 0.d;
    +        double rateSum = 0.d;
    +        double lossSum = 0.d;
    +
    +        final Int2FloatOpenHashTable.IMapIterator itor = rj.entries();
    +        while (itor.next() != -1) {
    +            int user = itor.getKey();
    +            double ruj = itor.getValue();
    +            double rui = ri.get(user, 0.f);
    +
    +            double eui = rui - predict(user, itemI, kNNi, itemJ, W);
    +            gradSum += ruj * eui;
    +            rateSum += ruj * ruj;
    +            lossSum += eui * eui;
    +        }
    +
    +        gradSum /= N;
    +        rateSum /= N;
    +
    +        double wij = W.get(itemI, itemJ, 0.d);
    +        double loss = lossSum / N + 0.5d * l2 * wij * wij + l1 * wij;
    +        _cvState.incrLoss(loss);
    +
    +        W.set(itemI, itemJ, getUpdateTerm(gradSum, rateSum, l1, l2));
    +    }
    +
    +    private void train(final int itemI,
    +            @Nonnull final IntOpenHashTable<Int2FloatOpenHashTable> knnItems, final int itemJ) {
    +        final DoKFloatMatrix A = _dataMatrix;
    +        final DoKFloatMatrix W = _weightMatrix;
    +
    +        final int N = A.numColumns(itemJ);
    +        if (N == 0) {
    +            return;
    +        }
    +
    +        final MutableDouble mutableGradSum = new MutableDouble(0.d);
    +        final MutableDouble mutableRateSum = new MutableDouble(0.d);
    +        final MutableDouble mutableLossSum = new MutableDouble(0.d);
    +
    +        A.eachNonZeroInRow(itemJ, new VectorProcedure() {
    --- End diff ---
    
    DoKMatrix A should be converted to a CSRMatrix for this access pattern: `eachNonZeroInRow` scans one row's non-zeros, which is a contiguous array walk in CSR but a hash lookup per (row, col) key in a DoK (dictionary-of-keys) matrix.
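    
    For context, here is a minimal, self-contained sketch of the CSR layout and why it favors row-wise iteration. The class and method names below are hypothetical, for illustration only, not Hivemall's actual `CSRMatrix` API:
    
    ```java
    // Hypothetical sketch of a CSR (Compressed Sparse Row) matrix; names are
    // illustrative and do NOT match Hivemall's actual CSRMatrix class.
    public final class CsrSketch {
    
        // All non-zeros are stored row by row: the entries of row i live at
        // indices rowPtr[i] .. rowPtr[i+1]-1 of colInd/values.
        private final int[] rowPtr;   // length = numRows + 1
        private final int[] colInd;   // column index of each non-zero
        private final float[] values; // value of each non-zero
    
        public CsrSketch(int[] rowPtr, int[] colInd, float[] values) {
            this.rowPtr = rowPtr;
            this.colInd = colInd;
            this.values = values;
        }
    
        public interface NonZeroProcedure {
            void apply(int col, float value);
        }
    
        // Row-wise iteration is a contiguous array scan, O(nnz(row)) with good
        // cache locality; a DoK matrix pays a hash lookup for every entry.
        public void eachNonZeroInRow(int row, NonZeroProcedure proc) {
            for (int k = rowPtr[row]; k < rowPtr[row + 1]; k++) {
                proc.apply(colInd[k], values[k]);
            }
        }
    
        public static void main(String[] args) {
            // The 2x3 matrix [[0, 1.5, 0], [2.0, 0, 3.0]] in CSR form
            CsrSketch a = new CsrSketch(new int[] {0, 1, 3}, new int[] {1, 0, 2},
                new float[] {1.5f, 2.0f, 3.0f});
            a.eachNonZeroInRow(1, (col, v) -> System.out.println(col + " -> " + v));
        }
    }
    ```
    
    Since the DoK data matrix is filled incrementally during `process()`, the conversion would happen once, after all input is observed and before the iterative training passes begin.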

