http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/995b9a88/core/src/main/java/hivemall/recommend/SlimUDTF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/recommend/SlimUDTF.java b/core/src/main/java/hivemall/recommend/SlimUDTF.java new file mode 100644 index 0000000..e205c18 --- /dev/null +++ b/core/src/main/java/hivemall/recommend/SlimUDTF.java @@ -0,0 +1,759 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package hivemall.recommend; + +import hivemall.UDTFWithOptions; +import hivemall.annotations.VisibleForTesting; +import hivemall.common.ConversionState; +import hivemall.math.matrix.sparse.DoKFloatMatrix; +import hivemall.math.vector.VectorProcedure; +import hivemall.utils.collections.maps.Int2FloatOpenHashTable; +import hivemall.utils.collections.maps.IntOpenHashTable; +import hivemall.utils.collections.maps.IntOpenHashTable.IMapIterator; +import hivemall.utils.hadoop.HiveUtils; +import hivemall.utils.io.FileUtils; +import hivemall.utils.io.NioStatefullSegment; +import hivemall.utils.lang.NumberUtils; +import hivemall.utils.lang.Primitives; +import hivemall.utils.lang.SizeOf; +import hivemall.utils.lang.mutable.MutableDouble; +import hivemall.utils.lang.mutable.MutableInt; +import hivemall.utils.lang.mutable.MutableObject; + +import java.io.File; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; + +import javax.annotation.Nonnull; +import javax.annotation.Nullable; + +import org.apache.commons.cli.CommandLine; +import org.apache.commons.cli.Options; +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.apache.hadoop.hive.ql.exec.Description; +import org.apache.hadoop.hive.ql.exec.UDFArgumentException; +import org.apache.hadoop.hive.ql.metadata.HiveException; +import org.apache.hadoop.hive.serde2.objectinspector.MapObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory; +import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils; +import org.apache.hadoop.io.FloatWritable; +import org.apache.hadoop.io.IntWritable; +import org.apache.hadoop.mapred.Counters; +import org.apache.hadoop.mapred.Reporter; + +/** + * Sparse Linear Methods (SLIM) for Top-N 
Recommender Systems. + * + * <pre> + * Xia Ning and George Karypis, SLIM: Sparse Linear Methods for Top-N Recommender Systems, Proc. ICDM, 2011. + * </pre> + */ +@Description( + name = "train_slim", + value = "_FUNC_( int i, map<int, double> r_i, map<int, map<int, double>> topKRatesOfI, int j, map<int, double> r_j [, constant string options]) " + + "- Returns row index, column index and non-zero weight value of prediction model") +public class SlimUDTF extends UDTFWithOptions { + private static final Log logger = LogFactory.getLog(SlimUDTF.class); + + //-------------------------------------------- + // input OIs + + private PrimitiveObjectInspector itemIOI; + private PrimitiveObjectInspector itemJOI; + private MapObjectInspector riOI; + private MapObjectInspector rjOI; + + private MapObjectInspector knnItemsOI; + private PrimitiveObjectInspector knnItemsKeyOI; + private MapObjectInspector knnItemsValueOI; + private PrimitiveObjectInspector knnItemsValueKeyOI; + private PrimitiveObjectInspector knnItemsValueValueOI; + + private PrimitiveObjectInspector riKeyOI; + private PrimitiveObjectInspector riValueOI; + + private PrimitiveObjectInspector rjKeyOI; + private PrimitiveObjectInspector rjValueOI; + + //-------------------------------------------- + // hyperparameters + + private double l1; + private double l2; + private int numIterations; + + //-------------------------------------------- + // model parameters and other state + + /** item-item weight matrix */ + private transient DoKFloatMatrix _weightMatrix; + + //-------------------------------------------- + // caching for each item i + + private int _previousItemId; + + @Nullable + private transient Int2FloatOpenHashTable _ri; + @Nullable + private transient IntOpenHashTable<Int2FloatOpenHashTable> _kNNi; + /** The number of elements in kNNi */ + @Nullable + private transient MutableInt _nnzKNNi; + + //-------------------------------------------- + // variables for iteration support + + /** item-user matrix holding the input data */ + @Nullable + private transient DoKFloatMatrix _dataMatrix; + + // used to store KNN data in a temporary file for iterative training + private transient NioStatefullSegment _fileIO; + private transient ByteBuffer _inputBuf; + + private ConversionState _cvState; + private long _observedTrainingExamples; + + //-------------------------------------------- + + public SlimUDTF() {} + + @Override + public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException { + final int numArgs = argOIs.length; + + if (numArgs == 1 && HiveUtils.isConstString(argOIs[0])) {// for -help option + String rawArgs = HiveUtils.getConstString(argOIs[0]); + parseOptions(rawArgs); + } + + if (numArgs != 5 && numArgs != 6) { + throw new UDFArgumentException( + "_FUNC_ takes 5 or 6 arguments: int i, map<int, double> r_i, map<int, map<int, double>> topKRatesOfI, int j, map<int, double> r_j [, constant string options]: " + + Arrays.toString(argOIs)); + } + + this.itemIOI = HiveUtils.asIntCompatibleOI(argOIs[0]); + + this.riOI = HiveUtils.asMapOI(argOIs[1]); + this.riKeyOI = HiveUtils.asIntCompatibleOI((riOI.getMapKeyObjectInspector())); + this.riValueOI = HiveUtils.asPrimitiveObjectInspector((riOI.getMapValueObjectInspector())); + + this.knnItemsOI = HiveUtils.asMapOI(argOIs[2]); + this.knnItemsKeyOI = HiveUtils.asIntCompatibleOI(knnItemsOI.getMapKeyObjectInspector()); + this.knnItemsValueOI = HiveUtils.asMapOI(knnItemsOI.getMapValueObjectInspector()); + this.knnItemsValueKeyOI = 
HiveUtils.asIntCompatibleOI(knnItemsValueOI.getMapKeyObjectInspector()); + this.knnItemsValueValueOI = HiveUtils.asDoubleCompatibleOI(knnItemsValueOI.getMapValueObjectInspector()); + + this.itemJOI = HiveUtils.asIntCompatibleOI(argOIs[3]); + + this.rjOI = HiveUtils.asMapOI(argOIs[4]); + this.rjKeyOI = HiveUtils.asIntCompatibleOI((rjOI.getMapKeyObjectInspector())); + this.rjValueOI = HiveUtils.asPrimitiveObjectInspector((rjOI.getMapValueObjectInspector())); + + processOptions(argOIs); + + this._observedTrainingExamples = 0L; + this._previousItemId = Integer.MIN_VALUE; + this._weightMatrix = null; + this._dataMatrix = null; + + List<String> fieldNames = new ArrayList<>(); + List<ObjectInspector> fieldOIs = new ArrayList<>(); + + fieldNames.add("j"); + fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector); + fieldNames.add("nn"); + fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector); + fieldNames.add("w"); + fieldOIs.add(PrimitiveObjectInspectorFactory.writableFloatObjectInspector); + + return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs); + } + + @Override + protected Options getOptions() { + Options opts = new Options(); + opts.addOption("l1", "l1coefficient", true, + "Coefficient for l1 regularizer [default: 0.001]"); + opts.addOption("l2", "l2coefficient", true, + "Coefficient for l2 regularizer [default: 0.0005]"); + opts.addOption("iters", "iterations", true, + "The number of iterations for coordinate descent [default: 30]"); + opts.addOption("disable_cv", "disable_cvtest", false, + "Whether to disable convergence check [default: enabled]"); + opts.addOption("cv_rate", "convergence_rate", true, + "Threshold to determine convergence [default: 0.005]"); + return opts; + } + + @Override + protected CommandLine processOptions(@Nonnull ObjectInspector[] argOIs) + throws UDFArgumentException { + CommandLine cl = null; + double l1 = 0.001d; + double l2 = 0.0005d; + int numIterations = 30; + boolean conversionCheck = true; + double cv_rate = 0.005d; + + if (argOIs.length >= 6) { + String rawArgs = HiveUtils.getConstString(argOIs[5]); + cl = parseOptions(rawArgs); + + l1 = Primitives.parseDouble(cl.getOptionValue("l1"), l1); + if (l1 < 0.d) { + throw new UDFArgumentException("Argument `double l1` must be non-negative: " + l1); + } + + l2 = Primitives.parseDouble(cl.getOptionValue("l2"), l2); + if (l2 < 0.d) { + throw new UDFArgumentException("Argument `double l2` must be non-negative: " + l2); + } + + numIterations = Primitives.parseInt(cl.getOptionValue("iters"), numIterations); + if (numIterations <= 0) { + throw new UDFArgumentException("Argument `int iters` must be greater than 0: " + + numIterations); + } + + conversionCheck = !cl.hasOption("disable_cvtest"); + + cv_rate = Primitives.parseDouble(cl.getOptionValue("cv_rate"), cv_rate); + if (cv_rate <= 0) { + throw new UDFArgumentException( + "Argument `double cv_rate` must be greater than 0.0: " + cv_rate); + } + } + + this.l1 = l1; + this.l2 = l2; + this.numIterations = numIterations; + this._cvState = new ConversionState(conversionCheck, cv_rate); + + return cl; + } + + @Override + public void process(@Nonnull Object[] args) throws HiveException { + if (_weightMatrix == null) {// initialize variables + this._weightMatrix = new DoKFloatMatrix(); + if (numIterations >= 2) { + this._dataMatrix = new DoKFloatMatrix(); + } + this._nnzKNNi = new MutableInt(); + } + + final int itemI = PrimitiveObjectInspectorUtils.getInt(args[0], itemIOI); + + if (itemI != _previousItemId 
|| _ri == null) { + // cache Ri and kNNi + this._ri = int2floatMap(itemI, riOI.getMap(args[1]), riKeyOI, riValueOI, _dataMatrix, + _ri); + this._kNNi = kNNentries(args[2], knnItemsOI, knnItemsKeyOI, knnItemsValueOI, + knnItemsValueKeyOI, knnItemsValueValueOI, _kNNi, _nnzKNNi); + + final int numKNNItems = _nnzKNNi.getValue(); + if (numIterations >= 2 && numKNNItems >= 1) { + recordTrainingInput(itemI, _kNNi, numKNNItems); + } + this._previousItemId = itemI; + } + + int itemJ = PrimitiveObjectInspectorUtils.getInt(args[3], itemJOI); + Int2FloatOpenHashTable rj = int2floatMap(itemJ, rjOI.getMap(args[4]), rjKeyOI, rjValueOI, + _dataMatrix); + + train(itemI, _ri, _kNNi, itemJ, rj); + _observedTrainingExamples++; + } + + private void recordTrainingInput(final int itemI, + @Nonnull final IntOpenHashTable<Int2FloatOpenHashTable> knnItems, final int numKNNItems) + throws HiveException { + ByteBuffer buf = this._inputBuf; + NioStatefullSegment dst = this._fileIO; + + if (buf == null) { + // invoke only at task node (initialize is also invoked in compilation) + final File file; + try { + file = File.createTempFile("hivemall_slim", ".sgmt"); // to save KNN data + file.deleteOnExit(); + if (!file.canWrite()) { + throw new UDFArgumentException("Cannot write a temporary file: " + + file.getAbsolutePath()); + } + } catch (IOException ioe) { + throw new UDFArgumentException(ioe); + } + + this._inputBuf = buf = ByteBuffer.allocateDirect(8 * 1024 * 1024); // 8MB + this._fileIO = dst = new NioStatefullSegment(file, false); + } + + int recordBytes = SizeOf.INT + SizeOf.INT + SizeOf.INT * 2 * knnItems.size() + + (SizeOf.INT + SizeOf.FLOAT) * numKNNItems; + int requiredBytes = SizeOf.INT + recordBytes; // need to allocate space for "recordBytes" itself + + int remain = buf.remaining(); + if (remain < requiredBytes) { + writeBuffer(buf, dst); + } + + buf.putInt(recordBytes); + buf.putInt(itemI); + buf.putInt(knnItems.size()); + + final IMapIterator<Int2FloatOpenHashTable> entries = knnItems.entries(); + while (entries.next() != -1) { + int user = entries.getKey(); + buf.putInt(user); + + Int2FloatOpenHashTable ru = entries.getValue(); + buf.putInt(ru.size()); + final Int2FloatOpenHashTable.IMapIterator itor = ru.entries(); + while (itor.next() != -1) { + buf.putInt(itor.getKey()); + buf.putFloat(itor.getValue()); + } + } + } + + private static void writeBuffer(@Nonnull final ByteBuffer srcBuf, + @Nonnull final NioStatefullSegment dst) throws HiveException { + srcBuf.flip(); + try { + dst.write(srcBuf); + } catch (IOException e) { + throw new HiveException("Exception causes while writing a buffer to file", e); + } + srcBuf.clear(); + } + + private void train(final int itemI, @Nonnull final Int2FloatOpenHashTable ri, + @Nonnull final IntOpenHashTable<Int2FloatOpenHashTable> kNNi, final int itemJ, + @Nonnull final Int2FloatOpenHashTable rj) { + final DoKFloatMatrix W = _weightMatrix; + + final int N = rj.size(); + if (N == 0) { + return; + } + + double gradSum = 0.d; + double rateSum = 0.d; + double lossSum = 0.d; + + final Int2FloatOpenHashTable.IMapIterator itor = rj.entries(); + while (itor.next() != -1) { + int user = itor.getKey(); + double ruj = itor.getValue(); + double rui = ri.get(user, 0.f); + + double eui = rui - predict(user, itemI, kNNi, itemJ, W); + gradSum += ruj * eui; + rateSum += ruj * ruj; + lossSum += eui * eui; + } + + gradSum /= N; + rateSum /= N; + + double wij = W.get(itemI, itemJ, 0.d); + double loss = lossSum / N + 0.5d * l2 * wij * wij + l1 * wij; + _cvState.incrLoss(loss); + + 
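+        // Closed-form coordinate-descent update (cf. Ning & Karypis, ICDM 2011):
+        // getUpdateTerm() soft-thresholds the averaged gradient gradSum by l1,
+        // scales it by 1 / (rateSum + l2), and clips negative results to zero,
+        // enforcing SLIM's non-negativity constraint on the weight matrix.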
W.set(itemI, itemJ, getUpdateTerm(gradSum, rateSum, l1, l2)); + } + + private void train(final int itemI, + @Nonnull final IntOpenHashTable<Int2FloatOpenHashTable> knnItems, final int itemJ) { + final DoKFloatMatrix A = _dataMatrix; + final DoKFloatMatrix W = _weightMatrix; + + final int N = A.numColumns(itemJ); + if (N == 0) { + return; + } + + final MutableDouble mutableGradSum = new MutableDouble(0.d); + final MutableDouble mutableRateSum = new MutableDouble(0.d); + final MutableDouble mutableLossSum = new MutableDouble(0.d); + + A.eachNonZeroInRow(itemJ, new VectorProcedure() { + @Override + public void apply(int user, double ruj) { + double rui = A.get(itemI, user, 0.d); + double eui = rui - predict(user, itemI, knnItems, itemJ, W); + + mutableGradSum.addValue(ruj * eui); + mutableRateSum.addValue(ruj * ruj); + mutableLossSum.addValue(eui * eui); + } + }); + + double gradSum = mutableGradSum.getValue() / N; + double rateSum = mutableRateSum.getValue() / N; + + double wij = W.get(itemI, itemJ, 0.d); + double loss = mutableLossSum.getValue() / N + 0.5 * l2 * wij * wij + l1 * wij; + _cvState.incrLoss(loss); + + W.set(itemI, itemJ, getUpdateTerm(gradSum, rateSum, l1, l2)); + } + + private static double predict(final int user, final int itemI, + @Nonnull final IntOpenHashTable<Int2FloatOpenHashTable> knnItems, + final int excludeIndex, @Nonnull final DoKFloatMatrix weightMatrix) { + final Int2FloatOpenHashTable kNNu = knnItems.get(user); + if (kNNu == null) { + return 0.d; + } + + double pred = 0.d; + final Int2FloatOpenHashTable.IMapIterator itor = kNNu.entries(); + while (itor.next() != -1) { + final int itemK = itor.getKey(); + if (itemK == excludeIndex) { + continue; + } + float ruk = itor.getValue(); + pred += ruk * weightMatrix.get(itemI, itemK, 0.d); + } + return pred; + } + + private static double getUpdateTerm(final double gradSum, final double rateSum, + final double l1, final double l2) { + double update = 0.d; + if (Math.abs(gradSum) > l1) { + if (gradSum > 0.d) { + update = (gradSum - l1) / (rateSum + l2); + } else { + update = (gradSum + l1) / (rateSum + l2); + } + // non-negative constraints + if (update < 0.d) { + update = 0.d; + } + } + return update; + } + + @Override + public void close() throws HiveException { + finalizeTraining(); + forwardModel(); + this._weightMatrix = null; + } + + @VisibleForTesting + void finalizeTraining() throws HiveException { + if (numIterations > 1) { + this._ri = null; + this._kNNi = null; + + runIterativeTraining(); + + this._dataMatrix = null; + } + } + + private void runIterativeTraining() throws HiveException { + final ByteBuffer buf = this._inputBuf; + final NioStatefullSegment dst = this._fileIO; + assert (buf != null); + assert (dst != null); + + final Reporter reporter = getReporter(); + final Counters.Counter iterCounter = (reporter == null) ? 
null : reporter.getCounter( + "hivemall.recommend.slim$Counter", "iteration"); + + try { + if (dst.getPosition() == 0L) {// run iterations w/o temporary file + if (buf.position() == 0) { + return; // no training example + } + buf.flip(); + for (int iter = 2; iter < numIterations; iter++) { + _cvState.next(); + reportProgress(reporter); + setCounterValue(iterCounter, iter); + + while (buf.remaining() > 0) { + int recordBytes = buf.getInt(); + assert (recordBytes > 0) : recordBytes; + replayTrain(buf); + } + buf.rewind(); + if (_cvState.isConverged(_observedTrainingExamples)) { + break; + } + } + logger.info("Performed " + + _cvState.getCurrentIteration() + + " iterations of " + + NumberUtils.formatNumber(_observedTrainingExamples) + + " training examples on memory (thus " + + NumberUtils.formatNumber(_observedTrainingExamples + * _cvState.getCurrentIteration()) + " training updates in total) "); + + } else { // read training examples in the temporary file and invoke train for each example + // write KNNi in buffer to a temporary file + if (buf.remaining() > 0) { + writeBuffer(buf, dst); + } + + try { + dst.flush(); + } catch (IOException e) { + throw new HiveException("Failed to flush a file: " + + dst.getFile().getAbsolutePath(), e); + } + + if (logger.isInfoEnabled()) { + File tmpFile = dst.getFile(); + logger.info("Wrote KNN entries of axis items to a temporary file for iterative training: " + + tmpFile.getAbsolutePath() + + " (" + + FileUtils.prettyFileSize(tmpFile) + + ")"); + } + + // run iterations + for (int iter = 2; iter < numIterations; iter++) { + _cvState.next(); + setCounterValue(iterCounter, iter); + + buf.clear(); + dst.resetPosition(); + while (true) { + reportProgress(reporter); + // load a KNNi to a buffer in the temporary file + final int bytesRead; + try { + bytesRead = dst.read(buf); + } catch (IOException e) { + throw new HiveException("Failed to read a file: " + + dst.getFile().getAbsolutePath(), e); + } + if (bytesRead == 0) { // reached file EOF + break; + } + assert (bytesRead > 0) : bytesRead; + + // reads training examples from a buffer + buf.flip(); + int remain = buf.remaining(); + if (remain < SizeOf.INT) { + throw new HiveException("Illegal file format was detected"); + } + while (remain >= SizeOf.INT) { + int pos = buf.position(); + int recordBytes = buf.getInt(); + remain -= SizeOf.INT; + if (remain < recordBytes) { + buf.position(pos); + break; + } + + replayTrain(buf); + remain -= recordBytes; + } + buf.compact(); + } + if (_cvState.isConverged(_observedTrainingExamples)) { + break; + } + } + logger.info("Performed " + + _cvState.getCurrentIteration() + + " iterations of " + + NumberUtils.formatNumber(_observedTrainingExamples) + + " training examples on memory and KNNi data on secondary storage (thus " + + NumberUtils.formatNumber(_observedTrainingExamples + * _cvState.getCurrentIteration()) + " training updates in total) "); + + } + } catch (Throwable e) { + throw new HiveException("Exception caused in the iterative training", e); + } finally { + // delete the temporary file and release resources + try { + dst.close(true); + } catch (IOException e) { + throw new HiveException("Failed to close a file: " + + dst.getFile().getAbsolutePath(), e); + } + this._inputBuf = null; + this._fileIO = null; + } + } + + private void replayTrain(@Nonnull final ByteBuffer buf) { + final int itemI = buf.getInt(); + final int knnSize = buf.getInt(); + + final IntOpenHashTable<Int2FloatOpenHashTable> knnItems = new IntOpenHashTable<>(1024); + final Set<Integer> pairItems 
= new HashSet<>(); + for (int i = 0; i < knnSize; i++) { + int user = buf.getInt(); + int ruSize = buf.getInt(); + Int2FloatOpenHashTable ru = new Int2FloatOpenHashTable(ruSize); + ru.defaultReturnValue(0.f); + + for (int j = 0; j < ruSize; j++) { + int itemK = buf.getInt(); + pairItems.add(itemK); + float ruk = buf.getFloat(); + ru.put(itemK, ruk); + } + knnItems.put(user, ru); + } + + for (int itemJ : pairItems) { + train(itemI, knnItems, itemJ); + } + } + + private void forwardModel() throws HiveException { + final IntWritable f0 = new IntWritable(); // i + final IntWritable f1 = new IntWritable(); // nn + final FloatWritable f2 = new FloatWritable(); // w + final Object[] forwardObj = new Object[] {f0, f1, f2}; + + final MutableObject<HiveException> catched = new MutableObject<>(); + _weightMatrix.eachNonZeroCell(new VectorProcedure() { + @Override + public void apply(int i, int j, float value) { + if (value == 0.f) { + return; + } + f0.set(i); + f1.set(j); + f2.set(value); + try { + forward(forwardObj); + } catch (HiveException e) { + catched.setIfAbsent(e); + } + } + }); + HiveException ex = catched.get(); + if (ex != null) { + throw ex; + } + logger.info("Forwarded SLIM's weights matrix"); + } + + @Nonnull + private static IntOpenHashTable<Int2FloatOpenHashTable> kNNentries( + @Nonnull final Object kNNiObj, @Nonnull final MapObjectInspector knnItemsOI, + @Nonnull final PrimitiveObjectInspector knnItemsKeyOI, + @Nonnull final MapObjectInspector knnItemsValueOI, + @Nonnull final PrimitiveObjectInspector knnItemsValueKeyOI, + @Nonnull final PrimitiveObjectInspector knnItemsValueValueOI, + @Nullable IntOpenHashTable<Int2FloatOpenHashTable> knnItems, + @Nonnull final MutableInt nnzKNNi) { + if (knnItems == null) { + knnItems = new IntOpenHashTable<>(1024); + } else { + knnItems.clear(); + } + + int numElementOfKNNItems = 0; + for (Map.Entry<?, ?> entry : knnItemsOI.getMap(kNNiObj).entrySet()) { + int user = PrimitiveObjectInspectorUtils.getInt(entry.getKey(), knnItemsKeyOI); + Int2FloatOpenHashTable ru = int2floatMap(knnItemsValueOI.getMap(entry.getValue()), + knnItemsValueKeyOI, knnItemsValueValueOI); + knnItems.put(user, ru); + numElementOfKNNItems += ru.size(); + } + + nnzKNNi.setValue(numElementOfKNNItems); + return knnItems; + } + + @Nonnull + private static Int2FloatOpenHashTable int2floatMap(@Nonnull final Map<?, ?> map, + @Nonnull final PrimitiveObjectInspector keyOI, + @Nonnull final PrimitiveObjectInspector valueOI) { + final Int2FloatOpenHashTable result = new Int2FloatOpenHashTable(map.size()); + result.defaultReturnValue(0.f); + + for (Map.Entry<?, ?> entry : map.entrySet()) { + float v = PrimitiveObjectInspectorUtils.getFloat(entry.getValue(), valueOI); + if (v == 0.f) { + continue; + } + int k = PrimitiveObjectInspectorUtils.getInt(entry.getKey(), keyOI); + result.put(k, v); + } + + return result; + } + + @Nonnull + private static Int2FloatOpenHashTable int2floatMap(final int item, + @Nonnull final Map<?, ?> map, @Nonnull final PrimitiveObjectInspector keyOI, + @Nonnull final PrimitiveObjectInspector valueOI, + @Nullable final DoKFloatMatrix dataMatrix) { + return int2floatMap(item, map, keyOI, valueOI, dataMatrix, null); + } + + @Nonnull + private static Int2FloatOpenHashTable int2floatMap(final int item, + @Nonnull final Map<?, ?> map, @Nonnull final PrimitiveObjectInspector keyOI, + @Nonnull final PrimitiveObjectInspector valueOI, + @Nullable final DoKFloatMatrix dataMatrix, @Nullable Int2FloatOpenHashTable dst) { + if (dst == null) { + dst = new 
Int2FloatOpenHashTable(map.size()); + dst.defaultReturnValue(0.f); + } else { + dst.clear(); + } + + for (Map.Entry<?, ?> entry : map.entrySet()) { + float rating = PrimitiveObjectInspectorUtils.getFloat(entry.getValue(), valueOI); + if (rating == 0.f) { + continue; + } + int user = PrimitiveObjectInspectorUtils.getInt(entry.getKey(), keyOI); + dst.put(user, rating); + if (dataMatrix != null) { + dataMatrix.set(item, user, rating); + } + } + + return dst; + } +}
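For reference, the coordinate-descent step implemented by train() and getUpdateTerm() above can be restated compactly. Writing g_ij for gradSum, q_j for rateSum, \lambda_1 and \lambda_2 for the l1/l2 coefficients, and N for the number of users with a rating for item j:

    e_{ui} = r_{ui} - \sum_{k \in \mathrm{kNN}(u),\, k \ne j} r_{uk}\, w_{ik}
    g_{ij} = \frac{1}{N} \sum_{u} r_{uj}\, e_{ui}, \qquad q_{j} = \frac{1}{N} \sum_{u} r_{uj}^{2}
    w_{ij} \leftarrow \max\Bigl(0,\; \frac{g_{ij} - \lambda_1}{q_{j} + \lambda_2}\Bigr)

This is the soft-thresholded elastic-net update of the referenced ICDM'11 paper under a non-negativity constraint; the two-sided branch in getUpdateTerm() reduces to the single max(...) expression above because negative updates are clipped to zero anyway.

As a usage sketch, the UDTF would be invoked from HiveQL roughly as follows. The table and column names here are illustrative assumptions; only the train_slim() signature and its output columns (j, nn, w) come from the code above:

    SELECT
      train_slim(i, r_i, topKRatesOfI, j, r_j, '-l1 0.001 -l2 0.0005 -iters 30')
        as (j, nn, w)
    FROM
      slim_training_input; -- hypothetical table of (i, r_i, topKRatesOfI, j, r_j) rows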
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/995b9a88/core/src/main/java/hivemall/utils/collections/maps/Int2DoubleOpenHashTable.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/utils/collections/maps/Int2DoubleOpenHashTable.java b/core/src/main/java/hivemall/utils/collections/maps/Int2DoubleOpenHashTable.java new file mode 100644 index 0000000..3b5585e --- /dev/null +++ b/core/src/main/java/hivemall/utils/collections/maps/Int2DoubleOpenHashTable.java @@ -0,0 +1,427 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package hivemall.utils.collections.maps; + +import hivemall.utils.math.Primes; + +import java.io.Externalizable; +import java.io.IOException; +import java.io.ObjectInput; +import java.io.ObjectOutput; +import java.util.Arrays; + +/** + * An open-addressing hash table using double hashing. + * + * <pre> + * Primary hash function: h1(k) = k mod m + * Secondary hash function: h2(k) = 1 + (k mod(m-2)) + * </pre> + * + * @see http://en.wikipedia.org/wiki/Double_hashing + */ +public class Int2DoubleOpenHashTable implements Externalizable { + + protected static final byte FREE = 0; + protected static final byte FULL = 1; + protected static final byte REMOVED = 2; + + private static final float DEFAULT_LOAD_FACTOR = 0.75f; + private static final float DEFAULT_GROW_FACTOR = 2.0f; + + protected final transient float _loadFactor; + protected final transient float _growFactor; + + protected int _used = 0; + protected int _threshold; + protected double defaultReturnValue = -1.d; + + protected int[] _keys; + protected double[] _values; + protected byte[] _states; + + protected Int2DoubleOpenHashTable(int size, float loadFactor, float growFactor, + boolean forcePrime) { + if (size < 1) { + throw new IllegalArgumentException(); + } + this._loadFactor = loadFactor; + this._growFactor = growFactor; + int actualSize = forcePrime ? 
Primes.findLeastPrimeNumber(size) : size; + this._keys = new int[actualSize]; + this._values = new double[actualSize]; + this._states = new byte[actualSize]; + this._threshold = (int) (actualSize * _loadFactor); + } + + public Int2DoubleOpenHashTable(int size, float loadFactor, float growFactor) { + this(size, loadFactor, growFactor, true); + } + + public Int2DoubleOpenHashTable(int size) { + this(size, DEFAULT_LOAD_FACTOR, DEFAULT_GROW_FACTOR, true); + } + + /** + * Only for {@link Externalizable} + */ + public Int2DoubleOpenHashTable() {// required for serialization + this._loadFactor = DEFAULT_LOAD_FACTOR; + this._growFactor = DEFAULT_GROW_FACTOR; + } + + public void defaultReturnValue(double v) { + this.defaultReturnValue = v; + } + + public boolean containsKey(final int key) { + return findKey(key) >= 0; + } + + /** + * @return -1.d if not found + */ + public double get(final int key) { + return get(key, defaultReturnValue); + } + + public double get(final int key, final double defaultValue) { + final int i = findKey(key); + if (i < 0) { + return defaultValue; + } + return _values[i]; + } + + public double put(final int key, final double value) { + final int hash = keyHash(key); + int keyLength = _keys.length; + int keyIdx = hash % keyLength; + + boolean expanded = preAddEntry(keyIdx); + if (expanded) { + keyLength = _keys.length; + keyIdx = hash % keyLength; + } + + final int[] keys = _keys; + final double[] values = _values; + final byte[] states = _states; + + if (states[keyIdx] == FULL) {// double hashing + if (keys[keyIdx] == key) { + double old = values[keyIdx]; + values[keyIdx] = value; + return old; + } + // try second hash + final int decr = 1 + (hash % (keyLength - 2)); + for (;;) { + keyIdx -= decr; + if (keyIdx < 0) { + keyIdx += keyLength; + } + if (isFree(keyIdx, key)) { + break; + } + if (states[keyIdx] == FULL && keys[keyIdx] == key) { + double old = values[keyIdx]; + values[keyIdx] = value; + return old; + } + } + } + keys[keyIdx] = key; + values[keyIdx] = value; + states[keyIdx] = FULL; + ++_used; + return defaultReturnValue; + } + + /** Return weather the required slot is free for new entry */ + protected boolean isFree(final int index, final int key) { + final byte stat = _states[index]; + if (stat == FREE) { + return true; + } + if (stat == REMOVED && _keys[index] == key) { + return true; + } + return false; + } + + /** @return expanded or not */ + protected boolean preAddEntry(final int index) { + if ((_used + 1) >= _threshold) {// too filled + int newCapacity = Math.round(_keys.length * _growFactor); + ensureCapacity(newCapacity); + return true; + } + return false; + } + + protected int findKey(final int key) { + final int[] keys = _keys; + final byte[] states = _states; + final int keyLength = keys.length; + + final int hash = keyHash(key); + int keyIdx = hash % keyLength; + if (states[keyIdx] != FREE) { + if (states[keyIdx] == FULL && keys[keyIdx] == key) { + return keyIdx; + } + // try second hash + final int decr = 1 + (hash % (keyLength - 2)); + for (;;) { + keyIdx -= decr; + if (keyIdx < 0) { + keyIdx += keyLength; + } + if (isFree(keyIdx, key)) { + return -1; + } + if (states[keyIdx] == FULL && keys[keyIdx] == key) { + return keyIdx; + } + } + } + return -1; + } + + public double remove(final int key) { + final int[] keys = _keys; + final double[] values = _values; + final byte[] states = _states; + final int keyLength = keys.length; + + final int hash = keyHash(key); + int keyIdx = hash % keyLength; + if (states[keyIdx] != FREE) { + if (states[keyIdx] == 
FULL && keys[keyIdx] == key) { + double old = values[keyIdx]; + states[keyIdx] = REMOVED; + --_used; + return old; + } + // second hash + final int decr = 1 + (hash % (keyLength - 2)); + for (;;) { + keyIdx -= decr; + if (keyIdx < 0) { + keyIdx += keyLength; + } + if (states[keyIdx] == FREE) { + return defaultReturnValue; + } + if (states[keyIdx] == FULL && keys[keyIdx] == key) { + double old = values[keyIdx]; + states[keyIdx] = REMOVED; + --_used; + return old; + } + } + } + return defaultReturnValue; + } + + public int size() { + return _used; + } + + public void clear() { + Arrays.fill(_states, FREE); + this._used = 0; + } + + public IMapIterator entries() { + return new MapIterator(); + } + + @Override + public String toString() { + int len = size() * 10 + 2; + StringBuilder buf = new StringBuilder(len); + buf.append('{'); + IMapIterator i = entries(); + while (i.next() != -1) { + buf.append(i.getKey()); + buf.append('='); + buf.append(i.getValue()); + if (i.hasNext()) { + buf.append(','); + } + } + buf.append('}'); + return buf.toString(); + } + + protected void ensureCapacity(final int newCapacity) { + int prime = Primes.findLeastPrimeNumber(newCapacity); + rehash(prime); + this._threshold = Math.round(prime * _loadFactor); + } + + private void rehash(final int newCapacity) { + int oldCapacity = _keys.length; + if (newCapacity <= oldCapacity) { + throw new IllegalArgumentException("new: " + newCapacity + ", old: " + oldCapacity); + } + final int[] newkeys = new int[newCapacity]; + final double[] newValues = new double[newCapacity]; + final byte[] newStates = new byte[newCapacity]; + int used = 0; + for (int i = 0; i < oldCapacity; i++) { + if (_states[i] == FULL) { + used++; + final int k = _keys[i]; + final double v = _values[i]; + final int hash = keyHash(k); + int keyIdx = hash % newCapacity; + if (newStates[keyIdx] == FULL) {// second hashing + int decr = 1 + (hash % (newCapacity - 2)); + while (newStates[keyIdx] != FREE) { + keyIdx -= decr; + if (keyIdx < 0) { + keyIdx += newCapacity; + } + } + } + newkeys[keyIdx] = k; + newValues[keyIdx] = v; + newStates[keyIdx] = FULL; + } + } + this._keys = newkeys; + this._values = newValues; + this._states = newStates; + this._used = used; + } + + private static int keyHash(int key) { + return key & 0x7fffffff; + } + + public void writeExternal(ObjectOutput out) throws IOException { + out.writeInt(_threshold); + out.writeInt(_used); + + out.writeInt(_keys.length); + IMapIterator i = entries(); + while (i.next() != -1) { + out.writeInt(i.getKey()); + out.writeDouble(i.getValue()); + } + } + + public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException { + this._threshold = in.readInt(); + this._used = in.readInt(); + + int keylen = in.readInt(); + int[] keys = new int[keylen]; + double[] values = new double[keylen]; + byte[] states = new byte[keylen]; + for (int i = 0; i < _used; i++) { + int k = in.readInt(); + double v = in.readDouble(); + int hash = keyHash(k); + int keyIdx = hash % keylen; + if (states[keyIdx] != FREE) {// second hash + int decr = 1 + (hash % (keylen - 2)); + for (;;) { + keyIdx -= decr; + if (keyIdx < 0) { + keyIdx += keylen; + } + if (states[keyIdx] == FREE) { + break; + } + } + } + states[keyIdx] = FULL; + keys[keyIdx] = k; + values[keyIdx] = v; + } + this._keys = keys; + this._values = values; + this._states = states; + } + + public interface IMapIterator { + + public boolean hasNext(); + + /** + * @return -1 if not found + */ + public int next(); + + public int getKey(); + + public double 
getValue(); + + } + + private final class MapIterator implements IMapIterator { + + int nextEntry; + int lastEntry = -1; + + MapIterator() { + this.nextEntry = nextEntry(0); + } + + /** find the index of next full entry */ + int nextEntry(int index) { + while (index < _keys.length && _states[index] != FULL) { + index++; + } + return index; + } + + public boolean hasNext() { + return nextEntry < _keys.length; + } + + public int next() { + if (!hasNext()) { + return -1; + } + int curEntry = nextEntry; + this.lastEntry = curEntry; + this.nextEntry = nextEntry(curEntry + 1); + return curEntry; + } + + public int getKey() { + if (lastEntry == -1) { + throw new IllegalStateException(); + } + return _keys[lastEntry]; + } + + public double getValue() { + if (lastEntry == -1) { + throw new IllegalStateException(); + } + return _values[lastEntry]; + } + } +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/995b9a88/core/src/main/java/hivemall/utils/collections/maps/Int2FloatOpenHashTable.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/utils/collections/maps/Int2FloatOpenHashTable.java b/core/src/main/java/hivemall/utils/collections/maps/Int2FloatOpenHashTable.java index e9b5c8a..22de115 100644 --- a/core/src/main/java/hivemall/utils/collections/maps/Int2FloatOpenHashTable.java +++ b/core/src/main/java/hivemall/utils/collections/maps/Int2FloatOpenHashTable.java @@ -90,23 +90,27 @@ public class Int2FloatOpenHashTable implements Externalizable { this.defaultReturnValue = v; } - public boolean containsKey(int key) { + public boolean containsKey(final int key) { return findKey(key) >= 0; } /** * @return -1.f if not found */ - public float get(int key) { - int i = findKey(key); + public float get(final int key) { + return get(key, defaultReturnValue); + } + + public float get(final int key, final float defaultValue) { + final int i = findKey(key); if (i < 0) { - return defaultReturnValue; + return defaultValue; } return _values[i]; } - public float put(int key, float value) { - int hash = keyHash(key); + public float put(final int key, final float value) { + final int hash = keyHash(key); int keyLength = _keys.length; int keyIdx = hash % keyLength; @@ -116,9 +120,9 @@ public class Int2FloatOpenHashTable implements Externalizable { keyIdx = hash % keyLength; } - int[] keys = _keys; - float[] values = _values; - byte[] states = _states; + final int[] keys = _keys; + final float[] values = _values; + final byte[] states = _states; if (states[keyIdx] == FULL) {// double hashing if (keys[keyIdx] == key) { @@ -127,7 +131,7 @@ public class Int2FloatOpenHashTable implements Externalizable { return old; } // try second hash - int decr = 1 + (hash % (keyLength - 2)); + final int decr = 1 + (hash % (keyLength - 2)); for (;;) { keyIdx -= decr; if (keyIdx < 0) { @@ -151,8 +155,8 @@ public class Int2FloatOpenHashTable implements Externalizable { } /** Return weather the required slot is free for new entry */ - protected boolean isFree(int index, int key) { - byte stat = _states[index]; + protected boolean isFree(final int index, final int key) { + final byte stat = _states[index]; if (stat == FREE) { return true; } @@ -163,7 +167,7 @@ public class Int2FloatOpenHashTable implements Externalizable { } /** @return expanded or not */ - protected boolean preAddEntry(int index) { + protected boolean preAddEntry(final int index) { if ((_used + 1) >= _threshold) {// too filled int newCapacity = Math.round(_keys.length * _growFactor); 
ensureCapacity(newCapacity); @@ -172,19 +176,19 @@ public class Int2FloatOpenHashTable implements Externalizable { return false; } - protected int findKey(int key) { - int[] keys = _keys; - byte[] states = _states; - int keyLength = keys.length; + protected int findKey(final int key) { + final int[] keys = _keys; + final byte[] states = _states; + final int keyLength = keys.length; - int hash = keyHash(key); + final int hash = keyHash(key); int keyIdx = hash % keyLength; if (states[keyIdx] != FREE) { if (states[keyIdx] == FULL && keys[keyIdx] == key) { return keyIdx; } // try second hash - int decr = 1 + (hash % (keyLength - 2)); + final int decr = 1 + (hash % (keyLength - 2)); for (;;) { keyIdx -= decr; if (keyIdx < 0) { @@ -201,13 +205,13 @@ public class Int2FloatOpenHashTable implements Externalizable { return -1; } - public float remove(int key) { - int[] keys = _keys; - float[] values = _values; - byte[] states = _states; - int keyLength = keys.length; + public float remove(final int key) { + final int[] keys = _keys; + final float[] values = _values; + final byte[] states = _states; + final int keyLength = keys.length; - int hash = keyHash(key); + final int hash = keyHash(key); int keyIdx = hash % keyLength; if (states[keyIdx] != FREE) { if (states[keyIdx] == FULL && keys[keyIdx] == key) { @@ -217,7 +221,7 @@ public class Int2FloatOpenHashTable implements Externalizable { return old; } // second hash - int decr = 1 + (hash % (keyLength - 2)); + final int decr = 1 + (hash % (keyLength - 2)); for (;;) { keyIdx -= decr; if (keyIdx < 0) { @@ -242,6 +246,9 @@ public class Int2FloatOpenHashTable implements Externalizable { } public void clear() { + if (_used == 0) { + return; // no need to clear + } Arrays.fill(_states, FREE); this._used = 0; } @@ -274,21 +281,21 @@ public class Int2FloatOpenHashTable implements Externalizable { this._threshold = Math.round(prime * _loadFactor); } - private void rehash(int newCapacity) { + private void rehash(final int newCapacity) { int oldCapacity = _keys.length; if (newCapacity <= oldCapacity) { throw new IllegalArgumentException("new: " + newCapacity + ", old: " + oldCapacity); } - int[] newkeys = new int[newCapacity]; - float[] newValues = new float[newCapacity]; - byte[] newStates = new byte[newCapacity]; + final int[] newkeys = new int[newCapacity]; + final float[] newValues = new float[newCapacity]; + final byte[] newStates = new byte[newCapacity]; int used = 0; for (int i = 0; i < oldCapacity; i++) { if (_states[i] == FULL) { used++; int k = _keys[i]; float v = _values[i]; - int hash = keyHash(k); + final int hash = keyHash(k); int keyIdx = hash % newCapacity; if (newStates[keyIdx] == FULL) {// second hashing int decr = 1 + (hash % (newCapacity - 2)); @@ -310,7 +317,7 @@ public class Int2FloatOpenHashTable implements Externalizable { this._used = used; } - private static int keyHash(int key) { + private static int keyHash(final int key) { return key & 0x7fffffff; } http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/995b9a88/core/src/main/java/hivemall/utils/collections/maps/Int2IntOpenHashTable.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/utils/collections/maps/Int2IntOpenHashTable.java b/core/src/main/java/hivemall/utils/collections/maps/Int2IntOpenHashTable.java index 8e87fce..73431d1 100644 --- a/core/src/main/java/hivemall/utils/collections/maps/Int2IntOpenHashTable.java +++ b/core/src/main/java/hivemall/utils/collections/maps/Int2IntOpenHashTable.java @@ -77,7 
+77,10 @@ public final class Int2IntOpenHashTable implements Externalizable { this(size, DEFAULT_LOAD_FACTOR, DEFAULT_GROW_FACTOR, true); } - public Int2IntOpenHashTable() {// required for serialization + /** + * Only for {@link Externalizable} + */ + public Int2IntOpenHashTable() { this._loadFactor = DEFAULT_LOAD_FACTOR; this._growFactor = DEFAULT_GROW_FACTOR; } http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/995b9a88/core/src/main/java/hivemall/utils/collections/maps/IntOpenHashTable.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/utils/collections/maps/IntOpenHashTable.java b/core/src/main/java/hivemall/utils/collections/maps/IntOpenHashTable.java index dbade74..1c90ae0 100644 --- a/core/src/main/java/hivemall/utils/collections/maps/IntOpenHashTable.java +++ b/core/src/main/java/hivemall/utils/collections/maps/IntOpenHashTable.java @@ -58,7 +58,10 @@ public final class IntOpenHashTable<V> implements Externalizable { protected V[] _values; protected byte[] _states; - public IntOpenHashTable() {} // for Externalizable + /** + * Only for {@link Externalizable} + */ + public IntOpenHashTable() {} public IntOpenHashTable(int size) { this(size, DEFAULT_LOAD_FACTOR, DEFAULT_GROW_FACTOR, true); http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/995b9a88/core/src/main/java/hivemall/utils/collections/maps/Long2DoubleOpenHashTable.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/utils/collections/maps/Long2DoubleOpenHashTable.java b/core/src/main/java/hivemall/utils/collections/maps/Long2DoubleOpenHashTable.java index b4356ff..115571e 100644 --- a/core/src/main/java/hivemall/utils/collections/maps/Long2DoubleOpenHashTable.java +++ b/core/src/main/java/hivemall/utils/collections/maps/Long2DoubleOpenHashTable.java @@ -78,6 +78,9 @@ public final class Long2DoubleOpenHashTable implements Externalizable { this(size, DEFAULT_LOAD_FACTOR, DEFAULT_GROW_FACTOR, true); } + /** + * Only for {@link Externalizable} + */ public Long2DoubleOpenHashTable() {// required for serialization this._loadFactor = DEFAULT_LOAD_FACTOR; this._growFactor = DEFAULT_GROW_FACTOR; http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/995b9a88/core/src/main/java/hivemall/utils/collections/maps/Long2FloatOpenHashTable.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/utils/collections/maps/Long2FloatOpenHashTable.java b/core/src/main/java/hivemall/utils/collections/maps/Long2FloatOpenHashTable.java index 6b0ab59..ba2de76 100644 --- a/core/src/main/java/hivemall/utils/collections/maps/Long2FloatOpenHashTable.java +++ b/core/src/main/java/hivemall/utils/collections/maps/Long2FloatOpenHashTable.java @@ -78,7 +78,10 @@ public final class Long2FloatOpenHashTable implements Externalizable { this(size, DEFAULT_LOAD_FACTOR, DEFAULT_GROW_FACTOR, true); } - public Long2FloatOpenHashTable() {// required for serialization + /** + * Only for {@link Externalizable} + */ + public Long2FloatOpenHashTable() { this._loadFactor = DEFAULT_LOAD_FACTOR; this._growFactor = DEFAULT_GROW_FACTOR; } @@ -113,7 +116,23 @@ public final class Long2FloatOpenHashTable implements Externalizable { return _values[index]; } + public float _set(final int index, final float value) { + float old = _values[index]; + _values[index] = value; + return old; + } + + public float _remove(final int index) { + _states[index] = REMOVED; + 
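+        // REMOVED marks a tombstone: probe sequences in findKey() and put()
+        // continue past such slots, and isFree() lets put() reuse a tombstone
+        // whose key matches, so collision chains stay intact after removal.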
--_used; + return _values[index]; + } + public float put(final long key, final float value) { + return put(key, value, defaultReturnValue); + } + + public float put(final long key, final float value, final float defaultValue) { final int hash = keyHash(key); int keyLength = _keys.length; int keyIdx = hash % keyLength; @@ -155,7 +174,7 @@ public final class Long2FloatOpenHashTable implements Externalizable { values[keyIdx] = value; states[keyIdx] = FULL; ++_used; - return defaultReturnValue; + return defaultValue; } /** Return weather the required slot is free for new entry */ http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/995b9a88/core/src/main/java/hivemall/utils/collections/maps/Long2IntOpenHashTable.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/utils/collections/maps/Long2IntOpenHashTable.java b/core/src/main/java/hivemall/utils/collections/maps/Long2IntOpenHashTable.java index 1ca4c40..6445231 100644 --- a/core/src/main/java/hivemall/utils/collections/maps/Long2IntOpenHashTable.java +++ b/core/src/main/java/hivemall/utils/collections/maps/Long2IntOpenHashTable.java @@ -77,6 +77,9 @@ public final class Long2IntOpenHashTable implements Externalizable { this(size, DEFAULT_LOAD_FACTOR, DEFAULT_GROW_FACTOR, true); } + /** + * Only for {@link Externalizable} + */ public Long2IntOpenHashTable() {// required for serialization this._loadFactor = DEFAULT_LOAD_FACTOR; this._growFactor = DEFAULT_GROW_FACTOR; http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/995b9a88/core/src/main/java/hivemall/utils/collections/maps/OpenHashTable.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/utils/collections/maps/OpenHashTable.java b/core/src/main/java/hivemall/utils/collections/maps/OpenHashTable.java index 4599bfc..c16567a 100644 --- a/core/src/main/java/hivemall/utils/collections/maps/OpenHashTable.java +++ b/core/src/main/java/hivemall/utils/collections/maps/OpenHashTable.java @@ -59,7 +59,10 @@ public final class OpenHashTable<K, V> implements Externalizable { protected V[] _values; protected byte[] _states; - public OpenHashTable() {} // for Externalizable + /** + * Only for {@link Externalizable} + */ + public OpenHashTable() {} public OpenHashTable(int size) { this(size, DEFAULT_LOAD_FACTOR, DEFAULT_GROW_FACTOR); http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/995b9a88/core/src/main/java/hivemall/utils/lang/mutable/MutableObject.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/utils/lang/mutable/MutableObject.java b/core/src/main/java/hivemall/utils/lang/mutable/MutableObject.java new file mode 100644 index 0000000..bea2a9d --- /dev/null +++ b/core/src/main/java/hivemall/utils/lang/mutable/MutableObject.java @@ -0,0 +1,83 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package hivemall.utils.lang.mutable; + +import javax.annotation.Nullable; + +public final class MutableObject<T> { + + @Nullable + private T _value; + + public MutableObject() {} + + public MutableObject(@Nullable T obj) { + this._value = obj; + } + + public boolean isSet() { + return _value != null; + } + + @Nullable + public T get() { + return _value; + } + + public void set(@Nullable T obj) { + this._value = obj; + } + + public void setIfAbsent(@Nullable T obj) { + if (_value == null) { + this._value = obj; + } + } + + @Override + public boolean equals(@Nullable Object obj) { + if (this == obj) { + return true; + } + if (obj == null) { + return false; + } + if (getClass() != obj.getClass()) { + return false; + } + MutableObject<?> other = (MutableObject<?>) obj; + if (_value == null) { + if (other._value != null) { + return false; + } + } + return _value.equals(other._value); + } + + @Override + public int hashCode() { + return _value == null ? 0 : _value.hashCode(); + } + + @Override + public String toString() { + return _value == null ? "null" : _value.toString(); + } + +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/995b9a88/core/src/main/java/hivemall/utils/math/MathUtils.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/utils/math/MathUtils.java b/core/src/main/java/hivemall/utils/math/MathUtils.java index ee533dc..71d4c29 100644 --- a/core/src/main/java/hivemall/utils/math/MathUtils.java +++ b/core/src/main/java/hivemall/utils/math/MathUtils.java @@ -43,7 +43,7 @@ import javax.annotation.Nullable; import org.apache.commons.math3.special.Gamma; public final class MathUtils { - private static final double LOG2 = Math.log(2); + public static final double LOG2 = Math.log(2); private MathUtils() {} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/995b9a88/core/src/test/java/hivemall/evaluation/BinaryResponsesMeasuresTest.java ---------------------------------------------------------------------- diff --git a/core/src/test/java/hivemall/evaluation/BinaryResponsesMeasuresTest.java b/core/src/test/java/hivemall/evaluation/BinaryResponsesMeasuresTest.java index 5e8f253..574fc04 100644 --- a/core/src/test/java/hivemall/evaluation/BinaryResponsesMeasuresTest.java +++ b/core/src/test/java/hivemall/evaluation/BinaryResponsesMeasuresTest.java @@ -103,7 +103,7 @@ public class BinaryResponsesMeasuresTest { List<Integer> groundTruth = Arrays.asList(1, 2, 4); double actual = BinaryResponsesMeasures.ReciprocalRank(rankedList, groundTruth, - rankedList.size()); + rankedList.size()); Assert.assertEquals(1.0d, actual, 0.0001d); Collections.reverse(rankedList); @@ -115,6 +115,22 @@ public class BinaryResponsesMeasuresTest { Assert.assertEquals(0.0d, actual, 0.0001d); } + public void testHit() { + List<Integer> rankedList = Arrays.asList(1, 3, 2, 6); + List<Integer> groundTruth = Arrays.asList(1, 2, 4); + + double actual = BinaryResponsesMeasures.Hit(rankedList, groundTruth, rankedList.size()); + Assert.assertEquals(1.d, actual, 0.0001d); + + actual = 
BinaryResponsesMeasures.Hit(rankedList, groundTruth, 2); + Assert.assertEquals(1.d, actual, 0.0001d); + + // not hitting case + rankedList = Arrays.asList(5, 6); + actual = BinaryResponsesMeasures.Hit(rankedList, groundTruth, 2); + Assert.assertEquals(0.d, actual, 0.0001d); + } + @Test public void testAP() { List<Integer> rankedList = Arrays.asList(1, 3, 2, 6); http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/995b9a88/core/src/test/java/hivemall/evaluation/GradedResponsesMeasuresTest.java ---------------------------------------------------------------------- diff --git a/core/src/test/java/hivemall/evaluation/GradedResponsesMeasuresTest.java b/core/src/test/java/hivemall/evaluation/GradedResponsesMeasuresTest.java index 6a7cc9d..96ac030 100644 --- a/core/src/test/java/hivemall/evaluation/GradedResponsesMeasuresTest.java +++ b/core/src/test/java/hivemall/evaluation/GradedResponsesMeasuresTest.java @@ -18,12 +18,12 @@ */ package hivemall.evaluation; -import org.junit.Assert; -import org.junit.Test; - import java.util.Arrays; import java.util.List; +import org.junit.Assert; +import org.junit.Test; + public class GradedResponsesMeasuresTest { @Test http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/995b9a88/core/src/test/java/hivemall/math/matrix/MatrixBuilderTest.java ---------------------------------------------------------------------- diff --git a/core/src/test/java/hivemall/math/matrix/MatrixBuilderTest.java b/core/src/test/java/hivemall/math/matrix/MatrixBuilderTest.java index decd7df..af3f024 100644 --- a/core/src/test/java/hivemall/math/matrix/MatrixBuilderTest.java +++ b/core/src/test/java/hivemall/math/matrix/MatrixBuilderTest.java @@ -225,7 +225,6 @@ public class MatrixBuilderTest { Assert.assertEquals(Double.NaN, csc2.get(5, 4, Double.NaN), 0.d); } - @Test public void testDoKMatrixFromLibSVM() { Matrix matrix = dokMatrixFromLibSVM(); http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/995b9a88/core/src/test/java/hivemall/math/matrix/sparse/DoKFloatMatrixTest.java ---------------------------------------------------------------------- diff --git a/core/src/test/java/hivemall/math/matrix/sparse/DoKFloatMatrixTest.java b/core/src/test/java/hivemall/math/matrix/sparse/DoKFloatMatrixTest.java new file mode 100644 index 0000000..c9e6afd --- /dev/null +++ b/core/src/test/java/hivemall/math/matrix/sparse/DoKFloatMatrixTest.java @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package hivemall.math.matrix.sparse; + +import java.util.Random; + +import org.junit.Assert; +import org.junit.Test; + +public class DoKFloatMatrixTest { + + @Test + public void testGetSet() { + DoKFloatMatrix matrix = new DoKFloatMatrix(); + Random rnd = new Random(43); + + for (int i = 0; i < 1000; i++) { + int row = Math.abs(rnd.nextInt()); + int col = Math.abs(rnd.nextInt()); + double v = rnd.nextDouble(); + matrix.set(row, col, v); + Assert.assertEquals(v, matrix.get(row, col), 0.00001d); + } + + } + +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/995b9a88/core/src/test/java/hivemall/recommend/SlimUDTFTest.java ---------------------------------------------------------------------- diff --git a/core/src/test/java/hivemall/recommend/SlimUDTFTest.java b/core/src/test/java/hivemall/recommend/SlimUDTFTest.java new file mode 100644 index 0000000..00b78f0 --- /dev/null +++ b/core/src/test/java/hivemall/recommend/SlimUDTFTest.java @@ -0,0 +1,99 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/995b9a88/docs/gitbook/SUMMARY.md
----------------------------------------------------------------------
diff --git a/docs/gitbook/SUMMARY.md b/docs/gitbook/SUMMARY.md
index 3d640f8..8b76a7f 100644
--- a/docs/gitbook/SUMMARY.md
+++ b/docs/gitbook/SUMMARY.md
@@ -155,6 +155,7 @@
     * [Item-based Collaborative Filtering](recommend/movielens_cf.md)
     * [Matrix Factorization](recommend/movielens_mf.md)
     * [Factorization Machine](recommend/movielens_fm.md)
+    * [SLIM for Fast Top-K Recommendation](recommend/movielens_slim.md)
     * [10-fold Cross Validation (Matrix Factorization)](recommend/movielens_cv.md)
 
 ## Part X - Anomaly Detection

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/995b9a88/docs/gitbook/recommend/item_based_cf.md
----------------------------------------------------------------------
diff --git a/docs/gitbook/recommend/item_based_cf.md b/docs/gitbook/recommend/item_based_cf.md
index 053b225..dcd4f57 100644
--- a/docs/gitbook/recommend/item_based_cf.md
+++ b/docs/gitbook/recommend/item_based_cf.md
@@ -325,7 +325,7 @@ similarity as (
     o.other,
     cosine_similarity(t1.feature_vector, t2.feature_vector) as similarity
   from
-    cooccurrence_top100 oã
+    cooccurrence_top100 o
     -- cooccurrence_upper_triangular o
     JOIN item_features t1 ON (o.itemid = t1.itemid)
     JOIN item_features t2 ON (o.other = t2.itemid)
@@ -652,7 +652,8 @@ partial_result as ( -- launch DIMSUM in a MapReduce fashion
   item_features f
   left outer join item_magnitude m
 ),
-similarity as ( -- reduce (i.e., sum up) mappers' partial results
+similarity as (
+  -- reduce (i.e., sum up) mappers' partial results
   select
     itemid,
     other,
@@ -702,7 +703,8 @@ partial_result as (
   item_features f
   left outer join item_magnitude m
 ),
-similarity_upper_triangular as ( -- if similarity of (i1, i2) pair is in this table, (i2, i1)'s similarity is omitted
+similarity_upper_triangular as (
+  -- if similarity of (i1, i2) pair is in this table, (i2, i1)'s similarity is omitted
   select
     itemid,
     other,

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/995b9a88/docs/gitbook/recommend/movielens_cf.md
----------------------------------------------------------------------
diff --git a/docs/gitbook/recommend/movielens_cf.md b/docs/gitbook/recommend/movielens_cf.md
index 1cf5aee..0602611 100644
--- a/docs/gitbook/recommend/movielens_cf.md
+++ b/docs/gitbook/recommend/movielens_cf.md
@@ -66,7 +66,8 @@ partial_result as ( -- launch DIMSUM in a MapReduce fashion
   movie_features f
   left outer join movie_magnitude m
 ),
-similarity as ( -- reduce (i.e., sum up) mappers' partial results
+similarity as (
+  -- reduce (i.e., sum up) mappers' partial results
   select
     movieid,
     other,
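All of the `similarity` CTEs touched above compute item-item cosine similarity over feature vectors. As a generic dense-vector reminder of what that means (a sketch only, not Hivemall's `cosine_similarity` UDF):

    // Plain cosine similarity: dot(a, b) / (|a| * |b|).
    static double cosineSimilarity(double[] a, double[] b) {
        double dot = 0.d, sqA = 0.d, sqB = 0.d;
        for (int i = 0; i < a.length; i++) {
            dot += a[i] * b[i];
            sqA += a[i] * a[i];
            sqB += b[i] * b[i];
        }
        if (sqA == 0.d || sqB == 0.d) {
            return 0.d; // a zero vector has no direction
        }
        return dot / (Math.sqrt(sqA) * Math.sqrt(sqB));
    }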
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/995b9a88/docs/gitbook/recommend/movielens_cv.md
----------------------------------------------------------------------
diff --git a/docs/gitbook/recommend/movielens_cv.md b/docs/gitbook/recommend/movielens_cv.md
index 6ac54c7..80c0d19 100644
--- a/docs/gitbook/recommend/movielens_cv.md
+++ b/docs/gitbook/recommend/movielens_cv.md
@@ -17,7 +17,7 @@
  under the License.
 -->
 
-[Cross-validation](http://en.wikipedia.org/wiki/Cross-validation_(statistics)#k-fold_cross-validationk-fold cross validation) is a model validation technique for assessing how a prediction model will generalize to an independent data set. This example shows a way to perform [k-fold cross validation](http://en.wikipedia.org/wiki/Cross-validation_(statistics)#k-fold_cross-validation) to evaluate prediction performance.
+[Cross-validation](http://en.wikipedia.org/wiki/Cross-validation_%28statistics%29) is a model validation technique for assessing how a prediction model will generalize to an independent data set. This example shows a way to perform [k-fold cross validation](http://en.wikipedia.org/wiki/Cross-validation_%28statistics%29#k-fold_cross-validation) to evaluate prediction performance.
 
 *Caution:* Matrix factorization is supported in Hivemall v0.3 or later.
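As the corrected links describe, k-fold cross validation partitions the ratings into k disjoint folds, trains on k-1 of them, and evaluates on the held-out fold, rotating through all k folds and averaging the scores. A minimal fold-assignment sketch (illustrative only; the guide itself performs the split in HiveQL):

    import java.util.ArrayList;
    import java.util.Collections;
    import java.util.List;
    import java.util.Random;

    // Round-robin k-fold splitter; shuffling first keeps folds unbiased.
    static <T> List<List<T>> kFolds(List<T> data, int k, long seed) {
        List<T> shuffled = new ArrayList<>(data);
        Collections.shuffle(shuffled, new Random(seed));
        List<List<T>> folds = new ArrayList<>();
        for (int i = 0; i < k; i++) {
            folds.add(new ArrayList<T>());
        }
        for (int i = 0; i < shuffled.size(); i++) {
            folds.get(i % k).add(shuffled.get(i));
        }
        return folds; // train on folds \ {i}, test on fold i, for i = 0..k-1
    }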
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/995b9a88/docs/gitbook/recommend/movielens_fm.md
----------------------------------------------------------------------
diff --git a/docs/gitbook/recommend/movielens_fm.md b/docs/gitbook/recommend/movielens_fm.md
index 64039fe..d3d2c82 100644
--- a/docs/gitbook/recommend/movielens_fm.md
+++ b/docs/gitbook/recommend/movielens_fm.md
@@ -19,6 +19,8 @@
 
 _Caution: Factorization Machine is supported from Hivemall v0.4 or later._
 
+<!-- toc -->
+
 # Data preparation
 
 First of all, please create `ratings` table described in [this article](../recommend/movielens_dataset.html).
@@ -89,7 +91,7 @@ set hivevar:factor=10;
 set hivevar:iters=50;
 ```
 
-## Build a prediction mdoel by Factorization Machine
+## Build a prediction model by Factorization Machine
 
 ```sql
 drop table fm_model;
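For context on the heading fixed above: with `factor=10` latent dimensions and 50 iterations, the trained model scores a feature vector with the standard second-order factorization machine equation. A textbook sketch (Rendle's formulation; Hivemall's own predictor may differ in detail):

    // Second-order FM prediction: w0 + sum_i w[i]*x[i]
    //   + sum_f 0.5 * ((sum_i v[i][f]*x[i])^2 - sum_i (v[i][f]*x[i])^2).
    static double fmPredict(double w0, double[] w, double[][] v, double[] x) {
        double y = w0;
        for (int i = 0; i < x.length; i++) {
            y += w[i] * x[i];
        }
        int k = v[0].length; // number of latent factors
        for (int f = 0; f < k; f++) {
            double sum = 0.d, sumSq = 0.d;
            for (int i = 0; i < x.length; i++) {
                double d = v[i][f] * x[i];
                sum += d;
                sumSq += d * d;
            }
            y += 0.5d * (sum * sum - sumSq); // pairwise interactions in O(k * n)
        }
        return y;
    }

The sum-of-squares rewrite keeps the interaction term linear in the number of features, which is what makes FM practical on sparse rating data.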
