Close #117, Close #111: [HIVEMALL-17] Support SLIM neighborhood-learning recommendation algorithm
Project: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/commit/995b9a88 Tree: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/tree/995b9a88 Diff: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/diff/995b9a88 Branch: refs/heads/master Commit: 995b9a885f6538138935dbf0fe9aae051ec47f9e Parents: c2b9578 Author: Kento NOZAWA <[email protected]> Authored: Thu Sep 28 12:16:17 2017 +0900 Committer: Makoto Yui <[email protected]> Committed: Thu Sep 28 12:16:45 2017 +0900 ---------------------------------------------------------------------- .../main/java/hivemall/evaluation/AUCUDAF.java | 17 +- .../evaluation/BinaryResponsesMeasures.java | 37 +- .../evaluation/GradedResponsesMeasures.java | 16 +- .../java/hivemall/evaluation/HitRateUDAF.java | 262 +++++++ .../main/java/hivemall/evaluation/MAPUDAF.java | 19 +- .../main/java/hivemall/evaluation/MRRUDAF.java | 19 +- .../main/java/hivemall/evaluation/NDCGUDAF.java | 17 +- .../java/hivemall/evaluation/PrecisionUDAF.java | 24 +- .../java/hivemall/evaluation/RecallUDAF.java | 19 +- .../hivemall/math/matrix/sparse/CSCMatrix.java | 2 + .../hivemall/math/matrix/sparse/CSRMatrix.java | 4 +- .../math/matrix/sparse/DoKFloatMatrix.java | 368 +++++++++ .../hivemall/math/matrix/sparse/DoKMatrix.java | 34 +- .../hivemall/math/vector/VectorProcedure.java | 6 + .../hivemall/mf/BPRMatrixFactorizationUDTF.java | 3 +- .../mf/OnlineMatrixFactorizationUDTF.java | 7 +- .../main/java/hivemall/recommend/SlimUDTF.java | 759 +++++++++++++++++++ .../maps/Int2DoubleOpenHashTable.java | 427 +++++++++++ .../maps/Int2FloatOpenHashTable.java | 71 +- .../collections/maps/Int2IntOpenHashTable.java | 5 +- .../collections/maps/IntOpenHashTable.java | 5 +- .../maps/Long2DoubleOpenHashTable.java | 3 + .../maps/Long2FloatOpenHashTable.java | 23 +- .../collections/maps/Long2IntOpenHashTable.java | 3 + .../utils/collections/maps/OpenHashTable.java | 5 +- .../utils/lang/mutable/MutableObject.java | 83 ++ .../java/hivemall/utils/math/MathUtils.java | 2 +- .../evaluation/BinaryResponsesMeasuresTest.java | 18 +- .../evaluation/GradedResponsesMeasuresTest.java | 6 +- .../hivemall/math/matrix/MatrixBuilderTest.java | 1 - .../math/matrix/sparse/DoKFloatMatrixTest.java | 43 ++ .../java/hivemall/recommend/SlimUDTFTest.java | 99 +++ docs/gitbook/SUMMARY.md | 1 + docs/gitbook/recommend/item_based_cf.md | 8 +- docs/gitbook/recommend/movielens_cf.md | 3 +- docs/gitbook/recommend/movielens_cv.md | 2 +- docs/gitbook/recommend/movielens_fm.md | 4 +- docs/gitbook/recommend/movielens_slim.md | 589 ++++++++++++++ resources/ddl/define-all-as-permanent.hive | 10 + resources/ddl/define-all.hive | 10 + resources/ddl/define-all.spark | 10 + resources/ddl/define-udfs.td.hql | 2 + 42 files changed, 2916 insertions(+), 130 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/995b9a88/core/src/main/java/hivemall/evaluation/AUCUDAF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/evaluation/AUCUDAF.java b/core/src/main/java/hivemall/evaluation/AUCUDAF.java index 7cbdb52..508e36a 100644 --- a/core/src/main/java/hivemall/evaluation/AUCUDAF.java +++ b/core/src/main/java/hivemall/evaluation/AUCUDAF.java @@ -52,7 +52,6 @@ import org.apache.hadoop.hive.serde2.objectinspector.StandardMapObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.StructField; import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils; -import org.apache.hadoop.hive.serde2.objectinspector.primitive.WritableIntObjectInspector; import org.apache.hadoop.hive.serde2.typeinfo.ListTypeInfo; import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo; import org.apache.hadoop.io.LongWritable; @@ -430,7 +429,7 @@ public final class AUCUDAF extends AbstractGenericUDAFResolver { private ListObjectInspector recommendListOI; private ListObjectInspector truthListOI; - private WritableIntObjectInspector recommendSizeOI; + private PrimitiveObjectInspector recommendSizeOI; private StructObjectInspector internalMergeOI; private StructField countField; @@ -448,7 +447,7 @@ public final class AUCUDAF extends AbstractGenericUDAFResolver { this.recommendListOI = (ListObjectInspector) parameters[0]; this.truthListOI = (ListObjectInspector) parameters[1]; if (parameters.length == 3) { - this.recommendSizeOI = (WritableIntObjectInspector) parameters[2]; + this.recommendSizeOI = HiveUtils.asIntegerOI(parameters[2]); } } else {// from partial aggregation StructObjectInspector soi = (StructObjectInspector) parameters[0]; @@ -507,12 +506,12 @@ public final class AUCUDAF extends AbstractGenericUDAFResolver { int recommendSize = recommendList.size(); if (parameters.length == 3) { - recommendSize = recommendSizeOI.get(parameters[2]); - } - if (recommendSize < 0 || recommendSize > recommendList.size()) { - throw new UDFArgumentException( - "The third argument `int recommendSize` must be in [0, " + recommendList.size() - + "]"); + recommendSize = PrimitiveObjectInspectorUtils.getInt(parameters[2], recommendSizeOI); + if (recommendSize < 0) { + throw new UDFArgumentException( + "The third argument `int recommendSize` must be in greather than or equals to 0: " + + recommendSize); + } } myAggr.iterate(recommendList, truthList, recommendSize); http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/995b9a88/core/src/main/java/hivemall/evaluation/BinaryResponsesMeasures.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/evaluation/BinaryResponsesMeasures.java b/core/src/main/java/hivemall/evaluation/BinaryResponsesMeasures.java index 7c21849..c3b4f6a 100644 --- a/core/src/main/java/hivemall/evaluation/BinaryResponsesMeasures.java +++ b/core/src/main/java/hivemall/evaluation/BinaryResponsesMeasures.java @@ -45,7 +45,7 @@ public final class BinaryResponsesMeasures { */ public static double nDCG(@Nonnull final List<?> rankedList, @Nonnull final List<?> groundTruth, @Nonnegative final int recommendSize) { - Preconditions.checkArgument(recommendSize > 0); + Preconditions.checkArgument(recommendSize >= 0); double dcg = 0.d; @@ -92,6 +92,8 @@ public final class BinaryResponsesMeasures { */ public static double Precision(@Nonnull final List<?> rankedList, @Nonnull final List<?> groundTruth, @Nonnegative final int recommendSize) { + Preconditions.checkArgument(recommendSize >= 0); + if (rankedList.isEmpty()) { if (groundTruth.isEmpty()) { return 1.d; @@ -99,8 +101,6 @@ public final class BinaryResponsesMeasures { return 0.d; } - Preconditions.checkArgument(recommendSize > 0); // can be zero when groundTruth is empty - int nTruePositive = 0; final int k = Math.min(rankedList.size(), recommendSize); for (int i = 0; i < k; i++) { @@ -135,6 +135,29 @@ public final class BinaryResponsesMeasures { } /** + * Computes Hit@`recommendSize` + * + * @param rankedList a list of ranked item IDs (first item is highest-ranked) + * @param groundTruth a collection of positive/correct item IDs + * @param recommendSize top-`recommendSize` items in `rankedList` are recommended + * @return 1.0 if hit 0.0 if no hit + */ + public static double Hit(@Nonnull final List<?> rankedList, @Nonnull final List<?> groundTruth, + @Nonnegative final int recommendSize) { + Preconditions.checkArgument(recommendSize >= 0); + + final int k = Math.min(rankedList.size(), recommendSize); + for (int i = 0; i < k; i++) { + Object item_id = rankedList.get(i); + if (groundTruth.contains(item_id)) { + return 1.d; + } + } + + return 0.d; + } + + /** * Counts the number of true positives * * @param rankedList a list of ranked item IDs (first item is highest-ranked) @@ -144,7 +167,7 @@ public final class BinaryResponsesMeasures { */ public static int TruePositives(final List<?> rankedList, final List<?> groundTruth, @Nonnegative final int recommendSize) { - Preconditions.checkArgument(recommendSize > 0); + Preconditions.checkArgument(recommendSize >= 0); int nTruePositive = 0; @@ -170,7 +193,7 @@ public final class BinaryResponsesMeasures { */ public static double ReciprocalRank(@Nonnull final List<?> rankedList, @Nonnull final List<?> groundTruth, @Nonnegative final int recommendSize) { - Preconditions.checkArgument(recommendSize > 0); + Preconditions.checkArgument(recommendSize >= 0); final int k = Math.min(rankedList.size(), recommendSize); for (int i = 0; i < k; i++) { @@ -193,7 +216,7 @@ public final class BinaryResponsesMeasures { */ public static double AveragePrecision(@Nonnull final List<?> rankedList, @Nonnull final List<?> groundTruth, @Nonnegative final int recommendSize) { - Preconditions.checkArgument(recommendSize > 0); + Preconditions.checkArgument(recommendSize >= 0); if (groundTruth.isEmpty()) { if (rankedList.isEmpty()) { @@ -231,7 +254,7 @@ public final class BinaryResponsesMeasures { */ public static double AUC(@Nonnull final List<?> rankedList, @Nonnull final List<?> groundTruth, @Nonnegative final int recommendSize) { - Preconditions.checkArgument(recommendSize > 0); + Preconditions.checkArgument(recommendSize >= 0); int nTruePositive = 0, nCorrectPairs = 0; http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/995b9a88/core/src/main/java/hivemall/evaluation/GradedResponsesMeasures.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/evaluation/GradedResponsesMeasures.java b/core/src/main/java/hivemall/evaluation/GradedResponsesMeasures.java index 688ba53..5bbbb7e 100644 --- a/core/src/main/java/hivemall/evaluation/GradedResponsesMeasures.java +++ b/core/src/main/java/hivemall/evaluation/GradedResponsesMeasures.java @@ -18,8 +18,12 @@ */ package hivemall.evaluation; +import hivemall.utils.lang.Preconditions; +import hivemall.utils.math.MathUtils; + import java.util.List; +import javax.annotation.Nonnegative; import javax.annotation.Nonnull; /** @@ -32,7 +36,7 @@ public final class GradedResponsesMeasures { private GradedResponsesMeasures() {} public static double nDCG(@Nonnull final List<Double> recommendTopRelScoreList, - @Nonnull final List<Double> truthTopRelScoreList, @Nonnull final int recommendSize) { + @Nonnull final List<Double> truthTopRelScoreList, @Nonnegative final int recommendSize) { double dcg = DCG(recommendTopRelScoreList, recommendSize); double idcg = DCG(truthTopRelScoreList, recommendSize); return dcg / idcg; @@ -45,11 +49,15 @@ public final class GradedResponsesMeasures { * @param recommendSize the number of positive items * @return DCG */ - public static double DCG(final List<Double> topRelScoreList, final int recommendSize) { + public static double DCG(@Nonnull final List<Double> topRelScoreList, + @Nonnegative final int recommendSize) { + Preconditions.checkArgument(recommendSize >= 0); + double dcg = 0.d; - for (int i = 0; i < recommendSize; i++) { + final int k = Math.min(topRelScoreList.size(), recommendSize); + for (int i = 0; i < k; i++) { double relScore = topRelScoreList.get(i); - dcg += ((Math.pow(2, relScore) - 1) * Math.log(2)) / Math.log(i + 2); + dcg += ((Math.pow(2, relScore) - 1) * MathUtils.LOG2) / Math.log(i + 2); } return dcg; } http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/995b9a88/core/src/main/java/hivemall/evaluation/HitRateUDAF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/evaluation/HitRateUDAF.java b/core/src/main/java/hivemall/evaluation/HitRateUDAF.java new file mode 100644 index 0000000..6df6087 --- /dev/null +++ b/core/src/main/java/hivemall/evaluation/HitRateUDAF.java @@ -0,0 +1,262 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/* +* Licensed to the Apache Software Foundation (ASF) under one +* or more contributor license agreements. See the NOTICE file +* distributed with this work for additional information +* regarding copyright ownership. The ASF licenses this file +* to you under the Apache License, Version 2.0 (the +* "License"); you may not use this file except in compliance +* with the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an +* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +* KIND, either express or implied. See the License for the +* specific language governing permissions and limitations +* under the License. +*/ +package hivemall.evaluation; + +import hivemall.utils.hadoop.HiveUtils; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; + +import javax.annotation.Nonnegative; +import javax.annotation.Nonnull; + +import org.apache.hadoop.hive.ql.exec.Description; +import org.apache.hadoop.hive.ql.exec.UDFArgumentException; +import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException; +import org.apache.hadoop.hive.ql.metadata.HiveException; +import org.apache.hadoop.hive.ql.parse.SemanticException; +import org.apache.hadoop.hive.ql.udf.generic.AbstractGenericUDAFResolver; +import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator; +import org.apache.hadoop.hive.serde2.io.DoubleWritable; +import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory; +import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.StructField; +import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils; +import org.apache.hadoop.hive.serde2.typeinfo.ListTypeInfo; +import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo; +import org.apache.hadoop.io.LongWritable; + +@Description( + name = "hitrate", + value = "_FUNC_(array rankItems, array correctItems [, const int recommendSize = rankItems.size])" + + " - Returns HitRate") +public final class HitRateUDAF extends AbstractGenericUDAFResolver { + + // prevent instantiation + private HitRateUDAF() {} + + @Override + public GenericUDAFEvaluator getEvaluator(@Nonnull TypeInfo[] typeInfo) throws SemanticException { + if (typeInfo.length != 2 && typeInfo.length != 3) { + throw new UDFArgumentTypeException(typeInfo.length - 1, + "_FUNC_ takes two or three arguments"); + } + + ListTypeInfo arg1type = HiveUtils.asListTypeInfo(typeInfo[0]); + if (!HiveUtils.isPrimitiveTypeInfo(arg1type.getListElementTypeInfo())) { + throw new UDFArgumentTypeException(0, + "The first argument `array rankItems` is invalid form: " + typeInfo[0]); + } + ListTypeInfo arg2type = HiveUtils.asListTypeInfo(typeInfo[1]); + if (!HiveUtils.isPrimitiveTypeInfo(arg2type.getListElementTypeInfo())) { + throw new UDFArgumentTypeException(1, + "The second argument `array correctItems` is invalid form: " + typeInfo[1]); + } + + return new HitRateUDAF.Evaluator(); + } + + public static class Evaluator extends GenericUDAFEvaluator { + + private ListObjectInspector recommendListOI; + private ListObjectInspector truthListOI; + private PrimitiveObjectInspector recommendSizeOI; + + private StructObjectInspector internalMergeOI; + private StructField countField; + private StructField sumField; + + public Evaluator() {} + + @Override + public ObjectInspector init(Mode mode, ObjectInspector[] parameters) throws HiveException { + assert (parameters.length == 2 || parameters.length == 3) : parameters.length; + super.init(mode, parameters); + + // initialize input + if (mode == Mode.PARTIAL1 || mode == Mode.COMPLETE) {// from original data + this.recommendListOI = (ListObjectInspector) parameters[0]; + this.truthListOI = (ListObjectInspector) parameters[1]; + if (parameters.length == 3) { + this.recommendSizeOI = HiveUtils.asIntegerOI(parameters[2]); + } + } else {// from partial aggregation + StructObjectInspector soi = (StructObjectInspector) parameters[0]; + this.internalMergeOI = soi; + this.countField = soi.getStructFieldRef("count"); + this.sumField = soi.getStructFieldRef("sum"); + } + + // initialize output + final ObjectInspector outputOI; + if (mode == Mode.PARTIAL1 || mode == Mode.PARTIAL2) {// terminatePartial + outputOI = internalMergeOI(); + } else {// terminate + outputOI = PrimitiveObjectInspectorFactory.writableDoubleObjectInspector; + } + return outputOI; + } + + private static StructObjectInspector internalMergeOI() { + ArrayList<String> fieldNames = new ArrayList<String>(); + ArrayList<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>(); + + fieldNames.add("sum"); + fieldOIs.add(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector); + fieldNames.add("count"); + fieldOIs.add(PrimitiveObjectInspectorFactory.writableLongObjectInspector); + + return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs); + } + + @Override + public HitRateAggregationBuffer getNewAggregationBuffer() throws HiveException { + HitRateAggregationBuffer myAggr = new HitRateAggregationBuffer(); + reset(myAggr); + return myAggr; + } + + @Override + public void reset(@SuppressWarnings("deprecation") AggregationBuffer agg) + throws HiveException { + HitRateAggregationBuffer myAggr = (HitRateAggregationBuffer) agg; + myAggr.reset(); + } + + @Override + public void iterate(@SuppressWarnings("deprecation") AggregationBuffer agg, + Object[] parameters) throws HiveException { + HitRateAggregationBuffer myAggr = (HitRateAggregationBuffer) agg; + + List<?> recommendList = recommendListOI.getList(parameters[0]); + if (recommendList == null) { + recommendList = Collections.emptyList(); + } + List<?> truthList = truthListOI.getList(parameters[1]); + if (truthList == null) { + return; + } + + int recommendSize = recommendList.size(); + if (parameters.length == 3) { + recommendSize = PrimitiveObjectInspectorUtils.getInt(parameters[2], recommendSizeOI); + if (recommendSize < 0) { + throw new UDFArgumentException( + "The third argument `int recommendSize` must be in greather than or equals to 0: " + + recommendSize); + } + } + + myAggr.iterate(recommendList, truthList, recommendSize); + } + + @Override + public Object terminatePartial(@SuppressWarnings("deprecation") AggregationBuffer agg) + throws HiveException { + HitRateAggregationBuffer myAggr = (HitRateAggregationBuffer) agg; + + Object[] partialResult = new Object[2]; + partialResult[0] = new DoubleWritable(myAggr.sum); + partialResult[1] = new LongWritable(myAggr.count); + return partialResult; + } + + @Override + public void merge(@SuppressWarnings("deprecation") AggregationBuffer agg, Object partial) + throws HiveException { + if (partial == null) { + return; + } + + Object sumObj = internalMergeOI.getStructFieldData(partial, sumField); + Object countObj = internalMergeOI.getStructFieldData(partial, countField); + double sum = PrimitiveObjectInspectorFactory.writableDoubleObjectInspector.get(sumObj); + long count = PrimitiveObjectInspectorFactory.writableLongObjectInspector.get(countObj); + + HitRateAggregationBuffer myAggr = (HitRateAggregationBuffer) agg; + myAggr.merge(sum, count); + } + + @Override + public DoubleWritable terminate(@SuppressWarnings("deprecation") AggregationBuffer agg) + throws HiveException { + HitRateAggregationBuffer myAggr = (HitRateAggregationBuffer) agg; + double result = myAggr.get(); + return new DoubleWritable(result); + } + + } + + public static final class HitRateAggregationBuffer extends + GenericUDAFEvaluator.AbstractAggregationBuffer { + + private double sum; + private long count; + + public HitRateAggregationBuffer() { + super(); + } + + void reset() { + this.sum = 0.d; + this.count = 0; + } + + void merge(double o_sum, long o_count) { + this.sum += o_sum; + this.count += o_count; + } + + double get() { + if (count == 0) { + return 0.d; + } + return sum / count; + } + + void iterate(@Nonnull List<?> recommendList, @Nonnull List<?> truthList, + @Nonnegative int recommendSize) { + this.sum += BinaryResponsesMeasures.Hit(recommendList, truthList, recommendSize); + this.count++; + } + } +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/995b9a88/core/src/main/java/hivemall/evaluation/MAPUDAF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/evaluation/MAPUDAF.java b/core/src/main/java/hivemall/evaluation/MAPUDAF.java index 3878684..45e64cb 100644 --- a/core/src/main/java/hivemall/evaluation/MAPUDAF.java +++ b/core/src/main/java/hivemall/evaluation/MAPUDAF.java @@ -38,10 +38,11 @@ import org.apache.hadoop.hive.serde2.io.DoubleWritable; import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory; +import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.StructField; import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; -import org.apache.hadoop.hive.serde2.objectinspector.primitive.WritableIntObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils; import org.apache.hadoop.hive.serde2.typeinfo.ListTypeInfo; import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo; import org.apache.hadoop.io.LongWritable; @@ -80,7 +81,7 @@ public final class MAPUDAF extends AbstractGenericUDAFResolver { private ListObjectInspector recommendListOI; private ListObjectInspector truthListOI; - private WritableIntObjectInspector recommendSizeOI; + private PrimitiveObjectInspector recommendSizeOI; private StructObjectInspector internalMergeOI; private StructField countField; @@ -98,7 +99,7 @@ public final class MAPUDAF extends AbstractGenericUDAFResolver { this.recommendListOI = (ListObjectInspector) parameters[0]; this.truthListOI = (ListObjectInspector) parameters[1]; if (parameters.length == 3) { - this.recommendSizeOI = (WritableIntObjectInspector) parameters[2]; + this.recommendSizeOI = HiveUtils.asIntegerOI(parameters[2]); } } else {// from partial aggregation StructObjectInspector soi = (StructObjectInspector) parameters[0]; @@ -159,12 +160,12 @@ public final class MAPUDAF extends AbstractGenericUDAFResolver { int recommendSize = recommendList.size(); if (parameters.length == 3) { - recommendSize = recommendSizeOI.get(parameters[2]); - } - if (recommendSize < 0 || recommendSize > recommendList.size()) { - throw new UDFArgumentException( - "The third argument `int recommendSize` must be in [0, " + recommendList.size() - + "]"); + recommendSize = PrimitiveObjectInspectorUtils.getInt(parameters[2], recommendSizeOI); + if (recommendSize < 0) { + throw new UDFArgumentException( + "The third argument `int recommendSize` must be in greather than or equals to 0: " + + recommendSize); + } } myAggr.iterate(recommendList, truthList, recommendSize); http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/995b9a88/core/src/main/java/hivemall/evaluation/MRRUDAF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/evaluation/MRRUDAF.java b/core/src/main/java/hivemall/evaluation/MRRUDAF.java index f5aba3b..98b8c3d 100644 --- a/core/src/main/java/hivemall/evaluation/MRRUDAF.java +++ b/core/src/main/java/hivemall/evaluation/MRRUDAF.java @@ -38,10 +38,11 @@ import org.apache.hadoop.hive.serde2.io.DoubleWritable; import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory; +import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.StructField; import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; -import org.apache.hadoop.hive.serde2.objectinspector.primitive.WritableIntObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils; import org.apache.hadoop.hive.serde2.typeinfo.ListTypeInfo; import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo; import org.apache.hadoop.io.LongWritable; @@ -80,7 +81,7 @@ public final class MRRUDAF extends AbstractGenericUDAFResolver { private ListObjectInspector recommendListOI; private ListObjectInspector truthListOI; - private WritableIntObjectInspector recommendSizeOI; + private PrimitiveObjectInspector recommendSizeOI; private StructObjectInspector internalMergeOI; private StructField countField; @@ -98,7 +99,7 @@ public final class MRRUDAF extends AbstractGenericUDAFResolver { this.recommendListOI = (ListObjectInspector) parameters[0]; this.truthListOI = (ListObjectInspector) parameters[1]; if (parameters.length == 3) { - this.recommendSizeOI = (WritableIntObjectInspector) parameters[2]; + this.recommendSizeOI = HiveUtils.asIntegerOI(parameters[2]); } } else {// from partial aggregation StructObjectInspector soi = (StructObjectInspector) parameters[0]; @@ -159,12 +160,12 @@ public final class MRRUDAF extends AbstractGenericUDAFResolver { int recommendSize = recommendList.size(); if (parameters.length == 3) { - recommendSize = recommendSizeOI.get(parameters[2]); - } - if (recommendSize < 0 || recommendSize > recommendList.size()) { - throw new UDFArgumentException( - "The third argument `int recommendSize` must be in [0, " + recommendList.size() - + "]"); + recommendSize = PrimitiveObjectInspectorUtils.getInt(parameters[2], recommendSizeOI); + if (recommendSize < 0) { + throw new UDFArgumentException( + "The third argument `int recommendSize` must be in greather than or equals to 0: " + + recommendSize); + } } myAggr.iterate(recommendList, truthList, recommendSize); http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/995b9a88/core/src/main/java/hivemall/evaluation/NDCGUDAF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/evaluation/NDCGUDAF.java b/core/src/main/java/hivemall/evaluation/NDCGUDAF.java index f1ba832..4e4fde6 100644 --- a/core/src/main/java/hivemall/evaluation/NDCGUDAF.java +++ b/core/src/main/java/hivemall/evaluation/NDCGUDAF.java @@ -45,7 +45,6 @@ import org.apache.hadoop.hive.serde2.objectinspector.StructField; import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils; -import org.apache.hadoop.hive.serde2.objectinspector.primitive.WritableIntObjectInspector; import org.apache.hadoop.hive.serde2.typeinfo.ListTypeInfo; import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo; import org.apache.hadoop.io.LongWritable; @@ -85,7 +84,7 @@ public final class NDCGUDAF extends AbstractGenericUDAFResolver { private ListObjectInspector recommendListOI; private ListObjectInspector truthListOI; - private WritableIntObjectInspector recommendSizeOI; + private PrimitiveObjectInspector recommendSizeOI; private StructObjectInspector internalMergeOI; private StructField countField; @@ -103,7 +102,7 @@ public final class NDCGUDAF extends AbstractGenericUDAFResolver { this.recommendListOI = (ListObjectInspector) parameters[0]; this.truthListOI = (ListObjectInspector) parameters[1]; if (parameters.length == 3) { - this.recommendSizeOI = (WritableIntObjectInspector) parameters[2]; + this.recommendSizeOI = HiveUtils.asIntegerOI(parameters[2]); } } else {// from partial aggregation StructObjectInspector soi = (StructObjectInspector) parameters[0]; @@ -164,12 +163,12 @@ public final class NDCGUDAF extends AbstractGenericUDAFResolver { int recommendSize = recommendList.size(); if (parameters.length == 3) { - recommendSize = recommendSizeOI.get(parameters[2]); - } - if (recommendSize < 0 || recommendSize > recommendList.size()) { - throw new UDFArgumentException( - "The third argument `int recommendSize` must be in [0, " + recommendList.size() - + "]"); + recommendSize = PrimitiveObjectInspectorUtils.getInt(parameters[2], recommendSizeOI); + if (recommendSize < 0) { + throw new UDFArgumentException( + "The third argument `int recommendSize` must be in greather than or equals to 0: " + + recommendSize); + } } boolean isBinary = !HiveUtils.isStructOI(recommendListOI.getListElementObjectInspector()); http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/995b9a88/core/src/main/java/hivemall/evaluation/PrecisionUDAF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/evaluation/PrecisionUDAF.java b/core/src/main/java/hivemall/evaluation/PrecisionUDAF.java index 93af519..de8a876 100644 --- a/core/src/main/java/hivemall/evaluation/PrecisionUDAF.java +++ b/core/src/main/java/hivemall/evaluation/PrecisionUDAF.java @@ -38,10 +38,11 @@ import org.apache.hadoop.hive.serde2.io.DoubleWritable; import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory; +import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.StructField; import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; -import org.apache.hadoop.hive.serde2.objectinspector.primitive.WritableIntObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils; import org.apache.hadoop.hive.serde2.typeinfo.ListTypeInfo; import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo; import org.apache.hadoop.io.LongWritable; @@ -80,7 +81,7 @@ public final class PrecisionUDAF extends AbstractGenericUDAFResolver { private ListObjectInspector recommendListOI; private ListObjectInspector truthListOI; - private WritableIntObjectInspector recommendSizeOI; + private PrimitiveObjectInspector recommendSizeOI; private StructObjectInspector internalMergeOI; private StructField countField; @@ -98,7 +99,7 @@ public final class PrecisionUDAF extends AbstractGenericUDAFResolver { this.recommendListOI = (ListObjectInspector) parameters[0]; this.truthListOI = (ListObjectInspector) parameters[1]; if (parameters.length == 3) { - this.recommendSizeOI = (WritableIntObjectInspector) parameters[2]; + this.recommendSizeOI = HiveUtils.asIntegerOI(parameters[2]); } } else {// from partial aggregation StructObjectInspector soi = (StructObjectInspector) parameters[0]; @@ -117,9 +118,10 @@ public final class PrecisionUDAF extends AbstractGenericUDAFResolver { return outputOI; } + @Nonnull private static StructObjectInspector internalMergeOI() { - ArrayList<String> fieldNames = new ArrayList<String>(); - ArrayList<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>(); + List<String> fieldNames = new ArrayList<>(); + List<ObjectInspector> fieldOIs = new ArrayList<>(); fieldNames.add("sum"); fieldOIs.add(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector); @@ -159,12 +161,12 @@ public final class PrecisionUDAF extends AbstractGenericUDAFResolver { int recommendSize = recommendList.size(); if (parameters.length == 3) { - recommendSize = recommendSizeOI.get(parameters[2]); - } - if (recommendSize < 0 || recommendSize > recommendList.size()) { - throw new UDFArgumentException( - "The third argument `int recommendSize` must be in [0, " + recommendList.size() - + "]"); + recommendSize = PrimitiveObjectInspectorUtils.getInt(parameters[2], recommendSizeOI); + if (recommendSize < 0) { + throw new UDFArgumentException( + "The third argument `int recommendSize` must be in greather than or equals to 0: " + + recommendSize); + } } myAggr.iterate(recommendList, truthList, recommendSize); http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/995b9a88/core/src/main/java/hivemall/evaluation/RecallUDAF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/evaluation/RecallUDAF.java b/core/src/main/java/hivemall/evaluation/RecallUDAF.java index fed9f71..30b1712 100644 --- a/core/src/main/java/hivemall/evaluation/RecallUDAF.java +++ b/core/src/main/java/hivemall/evaluation/RecallUDAF.java @@ -38,10 +38,11 @@ import org.apache.hadoop.hive.serde2.io.DoubleWritable; import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory; +import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.StructField; import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; -import org.apache.hadoop.hive.serde2.objectinspector.primitive.WritableIntObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils; import org.apache.hadoop.hive.serde2.typeinfo.ListTypeInfo; import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo; import org.apache.hadoop.io.LongWritable; @@ -80,7 +81,7 @@ public final class RecallUDAF extends AbstractGenericUDAFResolver { private ListObjectInspector recommendListOI; private ListObjectInspector truthListOI; - private WritableIntObjectInspector recommendSizeOI; + private PrimitiveObjectInspector recommendSizeOI; private StructObjectInspector internalMergeOI; private StructField countField; @@ -98,7 +99,7 @@ public final class RecallUDAF extends AbstractGenericUDAFResolver { this.recommendListOI = (ListObjectInspector) parameters[0]; this.truthListOI = (ListObjectInspector) parameters[1]; if (parameters.length == 3) { - this.recommendSizeOI = (WritableIntObjectInspector) parameters[2]; + this.recommendSizeOI = HiveUtils.asIntegerOI(parameters[2]); } } else {// from partial aggregation StructObjectInspector soi = (StructObjectInspector) parameters[0]; @@ -159,12 +160,12 @@ public final class RecallUDAF extends AbstractGenericUDAFResolver { int recommendSize = recommendList.size(); if (parameters.length == 3) { - recommendSize = recommendSizeOI.get(parameters[2]); - } - if (recommendSize < 0 || recommendSize > recommendList.size()) { - throw new UDFArgumentException( - "The third argument `int recommendSize` must be in [0, " + recommendList.size() - + "]"); + recommendSize = PrimitiveObjectInspectorUtils.getInt(parameters[2], recommendSizeOI); + if (recommendSize < 0) { + throw new UDFArgumentException( + "The third argument `int recommendSize` must be in greather than or equals to 0: " + + recommendSize); + } } myAggr.iterate(recommendList, truthList, recommendSize); http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/995b9a88/core/src/main/java/hivemall/math/matrix/sparse/CSCMatrix.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/math/matrix/sparse/CSCMatrix.java b/core/src/main/java/hivemall/math/matrix/sparse/CSCMatrix.java index d2232b2..f8eb02f 100644 --- a/core/src/main/java/hivemall/math/matrix/sparse/CSCMatrix.java +++ b/core/src/main/java/hivemall/math/matrix/sparse/CSCMatrix.java @@ -31,6 +31,8 @@ import javax.annotation.Nonnegative; import javax.annotation.Nonnull; /** + * Compressed Sparse Column matrix optimized for colum major access. + * * @link http://netlib.org/linalg/html_templates/node92.html#SECTION00931200000000000000 */ public final class CSCMatrix extends ColumnMajorMatrix { http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/995b9a88/core/src/main/java/hivemall/math/matrix/sparse/CSRMatrix.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/math/matrix/sparse/CSRMatrix.java b/core/src/main/java/hivemall/math/matrix/sparse/CSRMatrix.java index dd89521..805bbd1 100644 --- a/core/src/main/java/hivemall/math/matrix/sparse/CSRMatrix.java +++ b/core/src/main/java/hivemall/math/matrix/sparse/CSRMatrix.java @@ -29,8 +29,8 @@ import javax.annotation.Nonnegative; import javax.annotation.Nonnull; /** - * Read-only CSR double Matrix. - * + * Compressed Sparse Row Matrix optimized for row major access. + * * @link http://netlib.org/linalg/html_templates/node91.html#SECTION00931100000000000000 * @link http://www.cs.colostate.edu/~mcrob/toolbox/c++/sparseMatrix/sparse_matrix_compression.html */ http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/995b9a88/core/src/main/java/hivemall/math/matrix/sparse/DoKFloatMatrix.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/math/matrix/sparse/DoKFloatMatrix.java b/core/src/main/java/hivemall/math/matrix/sparse/DoKFloatMatrix.java new file mode 100644 index 0000000..16b4b64 --- /dev/null +++ b/core/src/main/java/hivemall/math/matrix/sparse/DoKFloatMatrix.java @@ -0,0 +1,368 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package hivemall.math.matrix.sparse; + +import hivemall.annotations.Experimental; +import hivemall.math.matrix.AbstractMatrix; +import hivemall.math.matrix.ColumnMajorMatrix; +import hivemall.math.matrix.RowMajorMatrix; +import hivemall.math.matrix.builders.DoKMatrixBuilder; +import hivemall.math.vector.Vector; +import hivemall.math.vector.VectorProcedure; +import hivemall.utils.collections.maps.Long2FloatOpenHashTable; +import hivemall.utils.collections.maps.Long2FloatOpenHashTable.IMapIterator; +import hivemall.utils.lang.Preconditions; +import hivemall.utils.lang.Primitives; + +import javax.annotation.Nonnegative; +import javax.annotation.Nonnull; + +/** + * Dictionary Of Keys based sparse matrix. + * + * This is an efficient structure for constructing a sparse matrix incrementally. + */ +@Experimental +public final class DoKFloatMatrix extends AbstractMatrix { + + @Nonnull + private final Long2FloatOpenHashTable elements; + @Nonnegative + private int numRows; + @Nonnegative + private int numColumns; + @Nonnegative + private int nnz; + + public DoKFloatMatrix() { + this(0, 0); + } + + public DoKFloatMatrix(@Nonnegative int numRows, @Nonnegative int numCols) { + this(numRows, numCols, 0.05f); + } + + public DoKFloatMatrix(@Nonnegative int numRows, @Nonnegative int numCols, + @Nonnegative float sparsity) { + super(); + Preconditions.checkArgument(sparsity >= 0.f && sparsity <= 1.f, "Invalid Sparsity value: " + + sparsity); + int initialCapacity = Math.max(16384, Math.round(numRows * numCols * sparsity)); + this.elements = new Long2FloatOpenHashTable(initialCapacity); + elements.defaultReturnValue(0.f); + this.numRows = numRows; + this.numColumns = numCols; + this.nnz = 0; + } + + public DoKFloatMatrix(@Nonnegative int initSize) { + super(); + int initialCapacity = Math.max(initSize, 16384); + this.elements = new Long2FloatOpenHashTable(initialCapacity); + elements.defaultReturnValue(0.f); + this.numRows = 0; + this.numColumns = 0; + this.nnz = 0; + } + + @Override + public boolean isSparse() { + return true; + } + + @Override + public boolean isRowMajorMatrix() { + return false; + } + + @Override + public boolean isColumnMajorMatrix() { + return false; + } + + @Override + public boolean readOnly() { + return false; + } + + @Override + public boolean swappable() { + return true; + } + + @Override + public int nnz() { + return nnz; + } + + @Override + public int numRows() { + return numRows; + } + + @Override + public int numColumns() { + return numColumns; + } + + @Override + public int numColumns(@Nonnegative final int row) { + int count = 0; + for (int j = 0; j < numColumns; j++) { + long index = index(row, j); + if (elements.containsKey(index)) { + count++; + } + } + return count; + } + + @Override + public double[] getRow(@Nonnegative final int index) { + double[] dst = row(); + return getRow(index, dst); + } + + @Override + public double[] getRow(@Nonnegative final int row, @Nonnull final double[] dst) { + checkRowIndex(row, numRows); + + final int end = Math.min(dst.length, numColumns); + for (int col = 0; col < end; col++) { + long k = index(row, col); + float v = elements.get(k); + dst[col] = v; + } + + return dst; + } + + @Override + public void getRow(@Nonnegative final int index, @Nonnull final Vector row) { + checkRowIndex(index, numRows); + row.clear(); + + for (int col = 0; col < numColumns; col++) { + long k = index(index, col); + final float v = elements.get(k, 0.f); + if (v != 0.f) { + row.set(col, v); + } + } + } + + @Override + public double get(@Nonnegative final int row, @Nonnegative final int col, + final double defaultValue) { + return get(row, col, (float) defaultValue); + } + + public float get(@Nonnegative final int row, @Nonnegative final int col, + final float defaultValue) { + long index = index(row, col); + return elements.get(index, defaultValue); + } + + @Override + public void set(@Nonnegative final int row, @Nonnegative final int col, final double value) { + set(row, col, (float) value); + } + + public void set(@Nonnegative final int row, @Nonnegative final int col, final float value) { + checkIndex(row, col); + + final long index = index(row, col); + if (value == 0.f && elements.containsKey(index) == false) { + return; + } + + if (elements.put(index, value, 0.f) == 0.f) { + nnz++; + this.numRows = Math.max(numRows, row + 1); + this.numColumns = Math.max(numColumns, col + 1); + } + } + + @Override + public double getAndSet(@Nonnegative final int row, @Nonnegative final int col, + final double value) { + return getAndSet(row, col, (float) value); + } + + public float getAndSet(@Nonnegative final int row, @Nonnegative final int col, final float value) { + checkIndex(row, col); + + final long index = index(row, col); + if (value == 0.f && elements.containsKey(index) == false) { + return 0.f; + } + + final float old = elements.put(index, value, 0.f); + if (old == 0.f) { + nnz++; + this.numRows = Math.max(numRows, row + 1); + this.numColumns = Math.max(numColumns, col + 1); + } + return old; + } + + @Override + public void swap(@Nonnegative final int row1, @Nonnegative final int row2) { + checkRowIndex(row1, numRows); + checkRowIndex(row2, numRows); + + for (int j = 0; j < numColumns; j++) { + final long i1 = index(row1, j); + final long i2 = index(row2, j); + + final int k1 = elements._findKey(i1); + final int k2 = elements._findKey(i2); + + if (k1 >= 0) { + if (k2 >= 0) { + float v1 = elements._get(k1); + float v2 = elements._set(k2, v1); + elements._set(k1, v2); + } else {// k1>=0 and k2<0 + float v1 = elements._remove(k1); + elements.put(i2, v1); + } + } else if (k2 >= 0) {// k2>=0 and k1 < 0 + float v2 = elements._remove(k2); + elements.put(i1, v2); + } else {//k1<0 and k2<0 + continue; + } + } + } + + @Override + public void eachInRow(@Nonnegative final int row, @Nonnull final VectorProcedure procedure, + final boolean nullOutput) { + checkRowIndex(row, numRows); + + for (int col = 0; col < numColumns; col++) { + long i = index(row, col); + final int key = elements._findKey(i); + if (key < 0) { + if (nullOutput) { + procedure.apply(col, 0.d); + } + } else { + float v = elements._get(key); + procedure.apply(col, v); + } + } + } + + @Override + public void eachNonZeroInRow(@Nonnegative final int row, + @Nonnull final VectorProcedure procedure) { + checkRowIndex(row, numRows); + + for (int col = 0; col < numColumns; col++) { + long i = index(row, col); + final float v = elements.get(i, 0.f); + if (v != 0.f) { + procedure.apply(col, v); + } + } + } + + @Override + public void eachColumnIndexInRow(int row, VectorProcedure procedure) { + checkRowIndex(row, numRows); + + for (int col = 0; col < numColumns; col++) { + long i = index(row, col); + final int key = elements._findKey(i); + if (key != -1) { + procedure.apply(col); + } + } + } + + @Override + public void eachInColumn(@Nonnegative final int col, @Nonnull final VectorProcedure procedure, + final boolean nullOutput) { + checkColIndex(col, numColumns); + + for (int row = 0; row < numRows; row++) { + long i = index(row, col); + final int key = elements._findKey(i); + if (key < 0) { + if (nullOutput) { + procedure.apply(row, 0.d); + } + } else { + float v = elements._get(key); + procedure.apply(row, v); + } + } + } + + @Override + public void eachNonZeroInColumn(@Nonnegative final int col, + @Nonnull final VectorProcedure procedure) { + checkColIndex(col, numColumns); + + for (int row = 0; row < numRows; row++) { + long i = index(row, col); + final float v = elements.get(i, 0.f); + if (v != 0.f) { + procedure.apply(row, v); + } + } + } + + public void eachNonZeroCell(@Nonnull final VectorProcedure procedure) { + if (nnz == 0) { + return; + } + final IMapIterator itor = elements.entries(); + while (itor.next() != -1) { + long k = itor.getKey(); + int row = Primitives.getHigh(k); + int col = Primitives.getLow(k); + float value = itor.getValue(); + procedure.apply(row, col, value); + } + } + + @Override + public RowMajorMatrix toRowMajorMatrix() { + throw new UnsupportedOperationException("Not yet supported"); + } + + @Override + public ColumnMajorMatrix toColumnMajorMatrix() { + throw new UnsupportedOperationException("Not yet supported"); + } + + @Override + public DoKMatrixBuilder builder() { + return new DoKMatrixBuilder(elements.size()); + } + + @Nonnegative + private static long index(@Nonnegative final int row, @Nonnegative final int col) { + return Primitives.toLong(row, col); + } + +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/995b9a88/core/src/main/java/hivemall/math/matrix/sparse/DoKMatrix.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/math/matrix/sparse/DoKMatrix.java b/core/src/main/java/hivemall/math/matrix/sparse/DoKMatrix.java index bcfd152..054d62a 100644 --- a/core/src/main/java/hivemall/math/matrix/sparse/DoKMatrix.java +++ b/core/src/main/java/hivemall/math/matrix/sparse/DoKMatrix.java @@ -26,12 +26,18 @@ import hivemall.math.matrix.builders.DoKMatrixBuilder; import hivemall.math.vector.Vector; import hivemall.math.vector.VectorProcedure; import hivemall.utils.collections.maps.Long2DoubleOpenHashTable; +import hivemall.utils.collections.maps.Long2DoubleOpenHashTable.IMapIterator; import hivemall.utils.lang.Preconditions; import hivemall.utils.lang.Primitives; import javax.annotation.Nonnegative; import javax.annotation.Nonnull; +/** + * Dictionary Of Keys based sparse matrix. + * + * This is an efficient structure for constructing a sparse matrix incrementally. + */ @Experimental public final class DoKMatrix extends AbstractMatrix { @@ -163,8 +169,6 @@ public final class DoKMatrix extends AbstractMatrix { @Override public double get(@Nonnegative final int row, @Nonnegative final int col, final double defaultValue) { - checkIndex(row, col, numRows, numColumns); - long index = index(row, col); return elements.get(index, defaultValue); } @@ -173,11 +177,11 @@ public final class DoKMatrix extends AbstractMatrix { public void set(@Nonnegative final int row, @Nonnegative final int col, final double value) { checkIndex(row, col); - if (value == 0.d) { + final long index = index(row, col); + if (value == 0.d && elements.containsKey(index) == false) { return; } - long index = index(row, col); if (elements.put(index, value, 0.d) == 0.d) { nnz++; this.numRows = Math.max(numRows, row + 1); @@ -190,8 +194,12 @@ public final class DoKMatrix extends AbstractMatrix { final double value) { checkIndex(row, col); - long index = index(row, col); - double old = elements.put(index, value, 0.d); + final long index = index(row, col); + if (value == 0.d && elements.containsKey(index) == false) { + return 0.d; + } + + final double old = elements.put(index, value, 0.d); if (old == 0.d) { nnz++; this.numRows = Math.max(numRows, row + 1); @@ -309,6 +317,20 @@ public final class DoKMatrix extends AbstractMatrix { } } + public void eachNonZeroCell(@Nonnull final VectorProcedure procedure) { + if (nnz == 0) { + return; + } + final IMapIterator itor = elements.entries(); + while (itor.next() != -1) { + long k = itor.getKey(); + int row = Primitives.getHigh(k); + int col = Primitives.getLow(k); + double value = itor.getValue(); + procedure.apply(row, col, value); + } + } + @Override public RowMajorMatrix toRowMajorMatrix() { throw new UnsupportedOperationException("Not yet supported"); http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/995b9a88/core/src/main/java/hivemall/math/vector/VectorProcedure.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/math/vector/VectorProcedure.java b/core/src/main/java/hivemall/math/vector/VectorProcedure.java index 266c531..3f3c390 100644 --- a/core/src/main/java/hivemall/math/vector/VectorProcedure.java +++ b/core/src/main/java/hivemall/math/vector/VectorProcedure.java @@ -24,6 +24,12 @@ public abstract class VectorProcedure { public VectorProcedure() {} + public void apply(@Nonnegative int i, @Nonnegative int j, float value) { + apply(i, j, (double) value); + } + + public void apply(@Nonnegative int i, @Nonnegative int j, double value) {} + public void apply(@Nonnegative int i, double value) {} public void apply(@Nonnegative int i, int value) {} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/995b9a88/core/src/main/java/hivemall/mf/BPRMatrixFactorizationUDTF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/mf/BPRMatrixFactorizationUDTF.java b/core/src/main/java/hivemall/mf/BPRMatrixFactorizationUDTF.java index 141b261..0f9b5fd 100644 --- a/core/src/main/java/hivemall/mf/BPRMatrixFactorizationUDTF.java +++ b/core/src/main/java/hivemall/mf/BPRMatrixFactorizationUDTF.java @@ -512,9 +512,8 @@ public final class BPRMatrixFactorizationUDTF extends UDTFWithOptions implements // write training examples in buffer to a temporary file if (inputBuf.position() > 0) { writeBuffer(inputBuf, fileIO, lastWritePos); - } else if (lastWritePos == 0) { - return; // no training example } + try { fileIO.flush(); } catch (IOException e) { http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/995b9a88/core/src/main/java/hivemall/mf/OnlineMatrixFactorizationUDTF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/mf/OnlineMatrixFactorizationUDTF.java b/core/src/main/java/hivemall/mf/OnlineMatrixFactorizationUDTF.java index 66ec60d..ee549c5 100644 --- a/core/src/main/java/hivemall/mf/OnlineMatrixFactorizationUDTF.java +++ b/core/src/main/java/hivemall/mf/OnlineMatrixFactorizationUDTF.java @@ -148,7 +148,7 @@ public abstract class OnlineMatrixFactorizationUDTF extends UDTFWithOptions impl this.iterations = Primitives.parseInt(cl.getOptionValue("iterations"), 1); if (iterations < 1) { throw new UDFArgumentException( - "'-iterations' must be greater than or equals to 1: " + iterations); + "'-iterations' must be greater than or equal to 1: " + iterations); } conversionCheck = !cl.hasOption("disable_cvtest"); convergenceRate = Primitives.parseDouble(cl.getOptionValue("cv_rate"), convergenceRate); @@ -239,7 +239,7 @@ public abstract class OnlineMatrixFactorizationUDTF extends UDTFWithOptions impl } int item = PrimitiveObjectInspectorUtils.getInt(args[1], itemOI); if (item < 0) { - throw new HiveException("Illegal item index: " + user); + throw new HiveException("Illegal item index: " + item); } double rating = PrimitiveObjectInspectorUtils.getDouble(args[2], ratingOI); @@ -505,9 +505,8 @@ public abstract class OnlineMatrixFactorizationUDTF extends UDTFWithOptions impl // write training examples in buffer to a temporary file if (inputBuf.position() > 0) { writeBuffer(inputBuf, fileIO, lastWritePos); - } else if (lastWritePos == 0) { - return; // no training example } + try { fileIO.flush(); } catch (IOException e) {
