Close #123: [HIVEMALL-154] Refactor Field-aware Factorization Machines to support Instance-wise L2 normalization
Project: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/commit/ad15923a Tree: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/tree/ad15923a Diff: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/diff/ad15923a Branch: refs/heads/master Commit: ad15923a1f88ee10b1d66efa92f5f79bbc6ea804 Parents: f469cce Author: Makoto Yui <[email protected]> Authored: Wed Oct 25 00:40:54 2017 +0900 Committer: Makoto Yui <[email protected]> Committed: Wed Oct 25 00:40:54 2017 +0900 ---------------------------------------------------------------------- NOTICE | 15 +- core/pom.xml | 11 +- .../src/main/java/hivemall/LearnerBaseUDTF.java | 14 +- .../KernelExpansionPassiveAggressiveUDTF.java | 34 +- .../java/hivemall/common/ConversionState.java | 45 +- .../hivemall/fm/FFMStringFeatureMapModel.java | 61 +-- .../java/hivemall/fm/FMHyperParameters.java | 7 +- .../java/hivemall/fm/FMIntFeatureMapModel.java | 14 +- .../hivemall/fm/FMStringFeatureMapModel.java | 14 +- .../hivemall/fm/FactorizationMachineUDTF.java | 48 +- core/src/main/java/hivemall/fm/Feature.java | 55 +-- .../fm/FieldAwareFactorizationMachineUDTF.java | 24 +- core/src/main/java/hivemall/fm/IntFeature.java | 6 +- .../ftvec/ranking/PositiveOnlyFeedback.java | 19 +- .../builders/ColumnMajorDenseMatrixBuilder.java | 16 +- .../main/java/hivemall/mf/FactorizedModel.java | 46 +- .../hivemall/model/AbstractPredictionModel.java | 14 +- .../java/hivemall/model/NewSparseModel.java | 6 +- .../main/java/hivemall/model/SparseModel.java | 6 +- .../optimizer/DenseOptimizerFactory.java | 4 +- .../main/java/hivemall/optimizer/Optimizer.java | 8 +- .../optimizer/SparseOptimizerFactory.java | 31 +- .../main/java/hivemall/recommend/SlimUDTF.java | 101 ++-- .../tools/mapred/DistributedCacheLookupUDF.java | 10 +- .../maps/Int2DoubleOpenHashTable.java | 427 ---------------- .../maps/Int2FloatOpenHashTable.java | 430 ---------------- .../collections/maps/Int2IntOpenHashTable.java | 422 ---------------- .../collections/maps/Int2LongOpenHashMap.java | 346 ------------- .../collections/maps/Int2LongOpenHashTable.java | 494 ------------------- .../collections/maps/IntOpenHashTable.java | 485 ------------------ .../hivemall/utils/collections/maps/LRUMap.java | 41 -- .../maps/Long2DoubleOpenHashTable.java | 321 ++++++------ .../maps/Long2FloatOpenHashTable.java | 319 ++++++------ .../collections/maps/Long2IntOpenHashTable.java | 354 +++++++------ .../utils/collections/maps/OpenHashMap.java | 409 --------------- .../utils/collections/maps/OpenHashTable.java | 337 +++++++------ .../java/hivemall/utils/lambda/Throwing.java | 46 ++ .../hivemall/utils/lambda/ThrowingConsumer.java | 37 ++ .../main/java/hivemall/utils/math/FastMath.java | 466 +++++++++++++++++ .../java/hivemall/utils/math/MathUtils.java | 20 +- .../fm/FactorizationMachineUDTFTest.java | 17 + core/src/test/java/hivemall/fm/FeatureTest.java | 75 ++- .../FieldAwareFactorizationMachineUDTFTest.java | 18 +- .../maps/Int2FloatOpenHashTableTest.java | 98 ---- .../maps/Int2LongOpenHashMapTest.java | 106 ---- .../maps/Int2LongOpenHashTableTest.java | 130 ----- .../collections/maps/IntOpenHashTableTest.java | 75 --- .../maps/Long2DoubleOpenHashTableTest.java | 187 +++++++ .../maps/Long2FloatOpenHashTableTest.java | 187 +++++++ .../maps/Long2IntOpenHashTableTest.java | 95 +++- .../utils/collections/maps/OpenHashMapTest.java | 93 ---- .../collections/maps/OpenHashTableTest.java | 62 ++- .../hivemall/utils/lambda/ThrowingTest.java | 66 +++ .../java/hivemall/utils/math/FastMathTest.java | 109 ++++ mixserv/pom.xml | 2 +- nlp/pom.xml | 2 +- pom.xml | 69 ++- spark/spark-2.0/pom.xml | 2 +- spark/spark-2.1/pom.xml | 2 +- spark/spark-2.2/pom.xml | 2 +- xgboost/pom.xml | 2 +- 61 files changed, 2460 insertions(+), 4502 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/ad15923a/NOTICE ---------------------------------------------------------------------- diff --git a/NOTICE b/NOTICE index 79a944c..bfc4af8 100644 --- a/NOTICE +++ b/NOTICE @@ -33,6 +33,17 @@ o hivemall/core/src/main/java/hivemall/utils/collections/OpenHashMap.java https://github.com/slipperyseal/atomicobjects/ Licensed under the Apache License, Version 2.0 +o hivemall/core/src/main/java/hivemall/utils/math/FastMath.java + + Copyright 2012-2015 Jeff Hain + + https://github.com/jeffhain/jafama/ + Licensed under the Apache License, Version 2.0 + + Copyright (C) 1993 by Sun Microsystems, Inc. + + Permission to use, copy, modify, and distribute this software is freely granted, provided that this notice is preserved. + ------------------------------------------------------------------------------------------------------ Copyright notifications which have been relocated from ASF projects @@ -50,7 +61,7 @@ o hivemall/core/src/main/java/hivemall/utils/buffer/DynamicByteArray.java https://orc.apache.org/ Licensed under the Apache License, Version 2.0 - hivemall/spark/spark-2.0/extra-src/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala +o hivemall/spark/spark-2.0/extra-src/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala hivemall/spark/spark-2.0/src/test/scala/org/apache/spark/sql/QueryTest.scala hivemall/spark/spark-2.0/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala hivemall/spark/spark-2.0/src/test/scala/org/apache/spark/sql/hive/test/TestHiveSingleton.scala @@ -66,3 +77,5 @@ o hivemall/core/src/main/java/hivemall/utils/buffer/DynamicByteArray.java http://spark.apache.org/ Licensed under the Apache License, Version 2.0 + + \ No newline at end of file http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/ad15923a/core/pom.xml ---------------------------------------------------------------------- diff --git a/core/pom.xml b/core/pom.xml index 9368993..b440946 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -139,7 +139,13 @@ <dependency> <groupId>org.roaringbitmap</groupId> <artifactId>RoaringBitmap</artifactId> - <version>[0.6,)</version> + <version>[0.6,0.7)</version> + <scope>compile</scope> + </dependency> + <dependency> + <groupId>it.unimi.dsi</groupId> + <artifactId>fastutil</artifactId> + <version>[8.1.0,8.2)</version> <scope>compile</scope> </dependency> @@ -190,7 +196,7 @@ <plugin> <groupId>org.apache.maven.plugins</groupId> <artifactId>maven-shade-plugin</artifactId> - <version>2.3</version> + <version>3.1.0</version> <executions> <execution> <id>jar-with-dependencies</id> @@ -212,6 +218,7 @@ <include>org.tukaani:xz</include> <include>org.apache.commons:commons-math3</include> <include>org.roaringbitmap:RoaringBitmap</include> + <include>it.unimi.dsi:fastutil</include> </includes> </artifactSet> <transformers> http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/ad15923a/core/src/main/java/hivemall/LearnerBaseUDTF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/LearnerBaseUDTF.java b/core/src/main/java/hivemall/LearnerBaseUDTF.java index b9ec668..356d739 100644 --- a/core/src/main/java/hivemall/LearnerBaseUDTF.java +++ b/core/src/main/java/hivemall/LearnerBaseUDTF.java @@ -55,6 +55,8 @@ import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectIn public abstract class LearnerBaseUDTF extends UDTFWithOptions { private static final Log logger = LogFactory.getLog(LearnerBaseUDTF.class); + private static final int DEFAULT_SPARSE_DIMS = 16384; + private static final int DEFAULT_DENSE_DIMS = 16777216; protected final boolean enableNewModel; protected boolean dense_model; @@ -120,7 +122,7 @@ public abstract class LearnerBaseUDTF extends UDTFWithOptions { denseModel = cl.hasOption("dense"); if (denseModel) { - modelDims = Primitives.parseInt(cl.getOptionValue("dims"), 16777216); + modelDims = Primitives.parseInt(cl.getOptionValue("dims"), DEFAULT_DENSE_DIMS); } disableHalfFloat = cl.hasOption("disable_halffloat"); @@ -168,7 +170,7 @@ public abstract class LearnerBaseUDTF extends UDTFWithOptions { PredictionModel model; final boolean useCovar = useCovariance(); if (dense_model) { - if (disable_halffloat == false && model_dims > 16777216) { + if (disable_halffloat == false && model_dims > DEFAULT_DENSE_DIMS) { logger.info("Build a space efficient dense model with " + model_dims + " initial dimensions" + (useCovar ? " w/ covariances" : "")); model = new SpaceEfficientDenseModel(model_dims, useCovar); @@ -199,7 +201,7 @@ public abstract class LearnerBaseUDTF extends UDTFWithOptions { PredictionModel model; final boolean useCovar = useCovariance(); if (dense_model) { - if (disable_halffloat == false && model_dims > 16777216) { + if (disable_halffloat == false && model_dims > DEFAULT_DENSE_DIMS) { logger.info("Build a space efficient dense model with " + model_dims + " initial dimensions" + (useCovar ? " w/ covariances" : "")); model = new NewSpaceEfficientDenseModel(model_dims, useCovar); @@ -229,9 +231,11 @@ public abstract class LearnerBaseUDTF extends UDTFWithOptions { protected final Optimizer createOptimizer(@CheckForNull Map<String, String> options) { Preconditions.checkNotNull(options); if (dense_model) { - return DenseOptimizerFactory.create(model_dims, options); + return DenseOptimizerFactory.create(model_dims < 0 ? DEFAULT_DENSE_DIMS : model_dims, + options); } else { - return SparseOptimizerFactory.create(model_dims, options); + return SparseOptimizerFactory.create(model_dims < 0 ? DEFAULT_SPARSE_DIMS : model_dims, + options); } } http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/ad15923a/core/src/main/java/hivemall/classifier/KernelExpansionPassiveAggressiveUDTF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/classifier/KernelExpansionPassiveAggressiveUDTF.java b/core/src/main/java/hivemall/classifier/KernelExpansionPassiveAggressiveUDTF.java index 6e6a2a0..c3a1371 100644 --- a/core/src/main/java/hivemall/classifier/KernelExpansionPassiveAggressiveUDTF.java +++ b/core/src/main/java/hivemall/classifier/KernelExpansionPassiveAggressiveUDTF.java @@ -23,11 +23,12 @@ import hivemall.annotations.VisibleForTesting; import hivemall.model.FeatureValue; import hivemall.model.PredictionModel; import hivemall.model.PredictionResult; -import hivemall.utils.collections.maps.Int2FloatOpenHashTable; -import hivemall.utils.collections.maps.Int2FloatOpenHashTable.IMapIterator; import hivemall.optimizer.LossFunctions; import hivemall.utils.hashing.HashFunction; import hivemall.utils.lang.Preconditions; +import it.unimi.dsi.fastutil.ints.Int2FloatMap; +import it.unimi.dsi.fastutil.ints.Int2FloatMaps; +import it.unimi.dsi.fastutil.ints.Int2FloatOpenHashMap; import java.util.ArrayList; import java.util.List; @@ -72,9 +73,9 @@ public final class KernelExpansionPassiveAggressiveUDTF extends BinaryOnlineClas // Model parameters private float _w0; - private Int2FloatOpenHashTable _w1; - private Int2FloatOpenHashTable _w2; - private Int2FloatOpenHashTable _w3; + private Int2FloatMap _w1; + private Int2FloatMap _w2; + private Int2FloatMap _w3; // ------------------------------------ @@ -182,11 +183,11 @@ public final class KernelExpansionPassiveAggressiveUDTF extends BinaryOnlineClas @Override protected PredictionModel createModel() { this._w0 = 0.f; - this._w1 = new Int2FloatOpenHashTable(16384); + this._w1 = new Int2FloatOpenHashMap(16384); _w1.defaultReturnValue(0.f); - this._w2 = new Int2FloatOpenHashTable(16384); + this._w2 = new Int2FloatOpenHashMap(16384); _w2.defaultReturnValue(0.f); - this._w3 = new Int2FloatOpenHashTable(16384); + this._w3 = new Int2FloatOpenHashMap(16384); _w3.defaultReturnValue(0.f); return null; @@ -351,13 +352,12 @@ public final class KernelExpansionPassiveAggressiveUDTF extends BinaryOnlineClas row[2] = w1; row[3] = w2; - final Int2FloatOpenHashTable w2map = _w2; - final IMapIterator w1itor = _w1.entries(); - while (w1itor.next() != -1) { - int k = w1itor.getKey(); + final Int2FloatMap w2map = _w2; + for (Int2FloatMap.Entry e : Int2FloatMaps.fastIterable(_w1)) { + int k = e.getIntKey(); Preconditions.checkArgument(k > 0, HiveException.class); h.set(k); - w1.set(w1itor.getValue()); + w1.set(e.getFloatValue()); w2.set(w2map.get(k)); forward(row); // h(f), w1, w2 } @@ -369,12 +369,12 @@ public final class KernelExpansionPassiveAggressiveUDTF extends BinaryOnlineClas row[3] = null; row[4] = hk; row[5] = w3; - final IMapIterator w3itor = _w3.entries(); - while (w3itor.next() != -1) { - int k = w3itor.getKey(); + + for (Int2FloatMap.Entry e : Int2FloatMaps.fastIterable(_w3)) { + int k = e.getIntKey(); Preconditions.checkArgument(k > 0, HiveException.class); hk.set(k); - w3.set(w3itor.getValue()); + w3.set(e.getFloatValue()); forward(row); // hk(f), w3 } this._w3 = null; http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/ad15923a/core/src/main/java/hivemall/common/ConversionState.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/common/ConversionState.java b/core/src/main/java/hivemall/common/ConversionState.java index 435bf75..fc222d6 100644 --- a/core/src/main/java/hivemall/common/ConversionState.java +++ b/core/src/main/java/hivemall/common/ConversionState.java @@ -18,6 +18,9 @@ */ package hivemall.common; +import javax.annotation.Nonnegative; +import javax.annotation.Nonnull; + import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; @@ -61,6 +64,13 @@ public final class ConversionState { return currLosses; } + public double getAverageLoss(@Nonnegative final long numInstances) { + if (numInstances == 0) { + return 0.d; + } + return currLosses / numInstances; + } + public double getPreviousLoss() { return prevLosses; } @@ -88,36 +98,31 @@ public final class ConversionState { if (currLosses > prevLosses) { if (logger.isInfoEnabled()) { - logger.info("Iteration #" + curIter + " currLoss `" + currLosses - + "` > prevLosses `" + prevLosses + '`'); + logger.info("Iteration #" + curIter + " current cumulative loss `" + currLosses + + "` > previous cumulative loss `" + prevLosses + '`'); } this.readyToFinishIterations = false; return false; } - final double changeRate = (prevLosses - currLosses) / prevLosses; + final double changeRate = getChangeRate(); if (changeRate < convergenceRate) { if (readyToFinishIterations) { // NOTE: never be true at the first iteration where prevLosses == Double.POSITIVE_INFINITY if (logger.isInfoEnabled()) { - logger.info("Training converged at " + curIter + "-th iteration. [curLosses=" - + currLosses + ", prevLosses=" + prevLosses + ", changeRate=" - + changeRate + ']'); + logger.info("Training converged at " + curIter + "-th iteration!\n" + + getInfo(observedTrainingExamples)); } return true; } else { if (logger.isInfoEnabled()) { - logger.info("Iteration #" + curIter + " [curLosses=" + currLosses - + ", prevLosses=" + prevLosses + ", changeRate=" + changeRate - + ", #trainingExamples=" + observedTrainingExamples + ']'); + logger.info(getInfo(observedTrainingExamples)); } this.readyToFinishIterations = true; } } else { if (logger.isInfoEnabled()) { - logger.info("Iteration #" + curIter + " [curLosses=" + currLosses + ", prevLosses=" - + prevLosses + ", changeRate=" + changeRate + ", #trainingExamples=" - + observedTrainingExamples + ']'); + logger.info(getInfo(observedTrainingExamples)); } this.readyToFinishIterations = false; } @@ -125,6 +130,10 @@ public final class ConversionState { return false; } + double getChangeRate() { + return (prevLosses - currLosses) / prevLosses; + } + public void next() { this.prevLosses = currLosses; this.currLosses = 0.d; @@ -135,4 +144,16 @@ public final class ConversionState { return curIter; } + @Nonnull + public String getInfo(@Nonnegative final long observedTrainingExamples) { + final StringBuilder buf = new StringBuilder(); + buf.append("Iteration #").append(curIter).append(" | "); + buf.append("average loss=").append(getAverageLoss(observedTrainingExamples)).append(", "); + buf.append("current cumulative loss=").append(currLosses).append(", "); + buf.append("previous cumulative loss=").append(prevLosses).append(", "); + buf.append("change rate=").append(getChangeRate()).append(", "); + buf.append("#trainingExamples=").append(observedTrainingExamples); + return buf.toString(); + } + } http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/ad15923a/core/src/main/java/hivemall/fm/FFMStringFeatureMapModel.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/fm/FFMStringFeatureMapModel.java b/core/src/main/java/hivemall/fm/FFMStringFeatureMapModel.java index 22b0541..282dc4e 100644 --- a/core/src/main/java/hivemall/fm/FFMStringFeatureMapModel.java +++ b/core/src/main/java/hivemall/fm/FFMStringFeatureMapModel.java @@ -23,9 +23,9 @@ import hivemall.fm.Entry.FTRLEntry; import hivemall.fm.FMHyperParameters.FFMHyperParameters; import hivemall.utils.buffer.HeapBuffer; import hivemall.utils.collections.lists.LongArrayList; -import hivemall.utils.collections.maps.Int2LongOpenHashTable; -import hivemall.utils.collections.maps.Int2LongOpenHashTable.MapIterator; import hivemall.utils.lang.NumberUtils; +import it.unimi.dsi.fastutil.ints.Int2LongMap; +import it.unimi.dsi.fastutil.ints.Int2LongOpenHashMap; import java.text.NumberFormat; import java.util.Locale; @@ -42,9 +42,9 @@ public final class FFMStringFeatureMapModel extends FieldAwareFactorizationMachi // LEARNING PARAMS private float _w0; @Nonnull - private final Int2LongOpenHashTable _map; + final Int2LongMap _map; @Nonnull - private final HeapBuffer _buf; + final HeapBuffer _buf; @Nonnull private final LongArrayList _freelistW; @@ -69,7 +69,8 @@ public final class FFMStringFeatureMapModel extends FieldAwareFactorizationMachi public FFMStringFeatureMapModel(@Nonnull FFMHyperParameters params) { super(params); this._w0 = 0.f; - this._map = new Int2LongOpenHashTable(DEFAULT_MAPSIZE); + this._map = new Int2LongOpenHashMap(DEFAULT_MAPSIZE); + _map.defaultReturnValue(-1L); this._buf = new HeapBuffer(HeapBuffer.DEFAULT_CHUNK_SIZE); this._freelistW = new LongArrayList(); this._freelistV = new LongArrayList(); @@ -326,54 +327,4 @@ public final class FFMStringFeatureMapModel extends FieldAwareFactorizationMachi return getStatistics(); } - @Nonnull - EntryIterator entries() { - return new EntryIterator(this); - } - - static final class EntryIterator { - - @Nonnull - private final MapIterator dictItor; - @Nonnull - private final Entry entryProbeW; - @Nonnull - private final Entry entryProbeV; - - EntryIterator(@Nonnull FFMStringFeatureMapModel model) { - this.dictItor = model._map.entries(); - this.entryProbeW = new Entry(model._buf, 1); - this.entryProbeV = new Entry(model._buf, model._factor); - } - - @Nonnull - Entry getEntryProbeW() { - return entryProbeW; - } - - @Nonnull - Entry getEntryProbeV() { - return entryProbeV; - } - - boolean hasNext() { - return dictItor.hasNext(); - } - - boolean next() { - return dictItor.next() != -1; - } - - int getEntryIndex() { - return dictItor.getKey(); - } - - @Nonnull - void getEntry(@Nonnull final Entry probe) { - long offset = dictItor.getValue(); - probe.setOffset(offset); - } - - } - } http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/ad15923a/core/src/main/java/hivemall/fm/FMHyperParameters.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/fm/FMHyperParameters.java b/core/src/main/java/hivemall/fm/FMHyperParameters.java index 15c1c56..e4254dd 100644 --- a/core/src/main/java/hivemall/fm/FMHyperParameters.java +++ b/core/src/main/java/hivemall/fm/FMHyperParameters.java @@ -60,6 +60,8 @@ class FMHyperParameters { // ------------------------------------- // non-model parameters + boolean l2norm; // enable by default for FFM. disabled by default for FM. + int iters = 1; boolean conversionCheck = true; double convergenceRate = 0.005d; @@ -78,8 +80,8 @@ class FMHyperParameters { + ", lambda=" + lambda + ", lambdaW0=" + lambdaW0 + ", lambdaW=" + lambdaW + ", lambdaV=" + lambdaV + ", sigma=" + sigma + ", seed=" + seed + ", vInit=" + vInit + ", minTarget=" + minTarget + ", maxTarget=" + maxTarget + ", eta=" + eta - + ", numFeatures=" + numFeatures + ", iters=" + iters + ", conversionCheck=" - + conversionCheck + ", convergenceRate=" + convergenceRate + + ", numFeatures=" + numFeatures + ", l2norm=" + l2norm + ", iters=" + iters + + ", conversionCheck=" + conversionCheck + ", convergenceRate=" + convergenceRate + ", adaptiveReglarization=" + adaptiveReglarization + ", validationRatio=" + validationRatio + ", validationThreshold=" + validationThreshold + ", parseFeatureAsInt=" + parseFeatureAsInt + "]"; @@ -102,6 +104,7 @@ class FMHyperParameters { this.maxTarget = Primitives.parseDouble(cl.getOptionValue("max_target"), maxTarget); this.eta = EtaEstimator.get(cl, DEFAULT_ETA0); this.numFeatures = Primitives.parseInt(cl.getOptionValue("num_features"), numFeatures); + this.l2norm = cl.hasOption("enable_norm"); this.iters = Primitives.parseInt(cl.getOptionValue("iterations"), iters); this.conversionCheck = !cl.hasOption("disable_cvtest"); this.convergenceRate = Primitives.parseDouble(cl.getOptionValue("cv_rate"), convergenceRate); http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/ad15923a/core/src/main/java/hivemall/fm/FMIntFeatureMapModel.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/fm/FMIntFeatureMapModel.java b/core/src/main/java/hivemall/fm/FMIntFeatureMapModel.java index be39b0b..cbb0d70 100644 --- a/core/src/main/java/hivemall/fm/FMIntFeatureMapModel.java +++ b/core/src/main/java/hivemall/fm/FMIntFeatureMapModel.java @@ -18,8 +18,10 @@ */ package hivemall.fm; -import hivemall.utils.collections.maps.Int2FloatOpenHashTable; -import hivemall.utils.collections.maps.IntOpenHashTable; +import it.unimi.dsi.fastutil.ints.Int2FloatMap; +import it.unimi.dsi.fastutil.ints.Int2FloatOpenHashMap; +import it.unimi.dsi.fastutil.ints.Int2ObjectMap; +import it.unimi.dsi.fastutil.ints.Int2ObjectOpenHashMap; import java.util.Arrays; @@ -32,17 +34,17 @@ public final class FMIntFeatureMapModel extends FactorizationMachineModel { // LEARNING PARAMS private float _w0; - private final Int2FloatOpenHashTable _w; - private final IntOpenHashTable<float[]> _V; + private final Int2FloatMap _w; + private final Int2ObjectMap<float[]> _V; private int _minIndex, _maxIndex; public FMIntFeatureMapModel(@Nonnull FMHyperParameters params) { super(params); this._w0 = 0.f; - this._w = new Int2FloatOpenHashTable(DEFAULT_MAPSIZE); + this._w = new Int2FloatOpenHashMap(DEFAULT_MAPSIZE); _w.defaultReturnValue(0.f); - this._V = new IntOpenHashTable<float[]>(DEFAULT_MAPSIZE); + this._V = new Int2ObjectOpenHashMap<float[]>(DEFAULT_MAPSIZE); this._minIndex = 0; this._maxIndex = 0; } http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/ad15923a/core/src/main/java/hivemall/fm/FMStringFeatureMapModel.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/fm/FMStringFeatureMapModel.java b/core/src/main/java/hivemall/fm/FMStringFeatureMapModel.java index 4eec280..84b780a 100644 --- a/core/src/main/java/hivemall/fm/FMStringFeatureMapModel.java +++ b/core/src/main/java/hivemall/fm/FMStringFeatureMapModel.java @@ -18,8 +18,8 @@ */ package hivemall.fm; -import hivemall.utils.collections.IMapIterator; -import hivemall.utils.collections.maps.OpenHashMap; +import it.unimi.dsi.fastutil.objects.Object2ObjectMap; +import it.unimi.dsi.fastutil.objects.Object2ObjectOpenHashMap; import javax.annotation.Nonnull; @@ -28,12 +28,13 @@ public final class FMStringFeatureMapModel extends FactorizationMachineModel { // LEARNING PARAMS private float _w0; - private final OpenHashMap<String, Entry> _map; + private final Object2ObjectMap<String, Entry> _map; public FMStringFeatureMapModel(@Nonnull FMHyperParameters params) { super(params); this._w0 = 0.f; - this._map = new OpenHashMap<String, FMStringFeatureMapModel.Entry>(DEFAULT_MAPSIZE); + this._map = new Object2ObjectOpenHashMap<String, FMStringFeatureMapModel.Entry>( + DEFAULT_MAPSIZE); } @Override @@ -41,8 +42,9 @@ public final class FMStringFeatureMapModel extends FactorizationMachineModel { return _map.size(); } - IMapIterator<String, Entry> entries() { - return _map.entries(true); + @Nonnull + Object2ObjectMap<String, Entry> getMap() { + return _map; } @Override http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/ad15923a/core/src/main/java/hivemall/fm/FactorizationMachineUDTF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/fm/FactorizationMachineUDTF.java b/core/src/main/java/hivemall/fm/FactorizationMachineUDTF.java index 5c8af32..bca1365 100644 --- a/core/src/main/java/hivemall/fm/FactorizationMachineUDTF.java +++ b/core/src/main/java/hivemall/fm/FactorizationMachineUDTF.java @@ -18,11 +18,28 @@ */ package hivemall.fm; +import hivemall.UDTFWithOptions; +import hivemall.annotations.VisibleForTesting; +import hivemall.common.ConversionState; +import hivemall.fm.FMStringFeatureMapModel.Entry; +import hivemall.optimizer.EtaEstimator; +import hivemall.optimizer.LossFunctions; +import hivemall.optimizer.LossFunctions.LossFunction; +import hivemall.optimizer.LossFunctions.LossType; +import hivemall.utils.hadoop.HiveUtils; +import hivemall.utils.io.FileUtils; +import hivemall.utils.io.NioStatefullSegment; +import hivemall.utils.lang.NumberUtils; +import hivemall.utils.lang.SizeOf; +import hivemall.utils.math.MathUtils; +import it.unimi.dsi.fastutil.objects.Object2ObjectMaps; + import java.io.File; import java.io.IOException; import java.nio.ByteBuffer; import java.util.ArrayList; import java.util.Arrays; +import java.util.Map; import java.util.Random; import javax.annotation.Nonnull; @@ -48,22 +65,6 @@ import org.apache.hadoop.io.Text; import org.apache.hadoop.mapred.Counters.Counter; import org.apache.hadoop.mapred.Reporter; -import hivemall.UDTFWithOptions; -import hivemall.annotations.VisibleForTesting; -import hivemall.common.ConversionState; -import hivemall.fm.FMStringFeatureMapModel.Entry; -import hivemall.optimizer.EtaEstimator; -import hivemall.optimizer.LossFunctions; -import hivemall.optimizer.LossFunctions.LossFunction; -import hivemall.optimizer.LossFunctions.LossType; -import hivemall.utils.collections.IMapIterator; -import hivemall.utils.hadoop.HiveUtils; -import hivemall.utils.io.FileUtils; -import hivemall.utils.io.NioStatefullSegment; -import hivemall.utils.lang.NumberUtils; -import hivemall.utils.lang.SizeOf; -import hivemall.utils.math.MathUtils; - @Description( name = "train_fm", value = "_FUNC_(array<string> x, double y [, const string options]) - Returns a prediction model") @@ -163,6 +164,8 @@ public class FactorizationMachineUDTF extends UDTFWithOptions { // feature representation opts.addOption("int_feature", "feature_as_integer", false, "Parse a feature as integer [default: OFF]"); + // normalization + opts.addOption("enable_norm", "l2norm", false, "Enable instance-wise L2 normalization"); return opts; } @@ -288,7 +291,11 @@ public class FactorizationMachineUDTF extends UDTFWithOptions { @Nullable protected Feature[] parseFeatures(@Nonnull final Object arg) throws HiveException { - return Feature.parseFeatures(arg, _xOI, _probes, _parseFeatureAsInt); + Feature[] features = Feature.parseFeatures(arg, _xOI, _probes, _parseFeatureAsInt); + if (_params.l2norm) { + Feature.l2normalize(features); + } + return features; } protected void recordTrain(@Nonnull final Feature[] x, final double y) throws HiveException { @@ -509,13 +516,12 @@ public class FactorizationMachineUDTF extends UDTFWithOptions { // Wi, Vif (i starts from 1..P) forwardObjs[2] = Arrays.asList(f_Vi); - final IMapIterator<String, Entry> itor = model.entries(); - while (itor.next() != -1) { - String i = itor.getKey(); + for (Map.Entry<String, Entry> e : Object2ObjectMaps.fastIterable(model.getMap())) { + String i = e.getKey(); assert (i != null); // set i feature.set(i); - Entry entry = itor.getValue(); + Entry entry = e.getValue(); // set Wi f_Wi.set(entry.W); // set Vif http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/ad15923a/core/src/main/java/hivemall/fm/Feature.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/fm/Feature.java b/core/src/main/java/hivemall/fm/Feature.java index 8ae6f20..c4915ca 100644 --- a/core/src/main/java/hivemall/fm/Feature.java +++ b/core/src/main/java/hivemall/fm/Feature.java @@ -222,24 +222,11 @@ public abstract class Feature { final String lead = fv.substring(0, pos1); final String rest = fv.substring(pos1 + 1); final int pos2 = rest.indexOf(':'); - if (pos2 == -1) {// e.g., i1:1.0 (quantitative features) - final int index; - if (NumberUtils.isDigits(lead)) { - index = parseFeatureIndex(lead); - if (index < 0 || index >= numFields) { - throw new HiveException("Invalid index value '" + index - + "' for a quantative features: " + fv + ", expecting index less than " - + numFields); - } - } else { - index = MurmurHash3.murmurhash3(lead, numFields); - } - short field = NumberUtils.castToShort(index); - double value = parseFeatureValue(rest); - return new IntFeature(index, field, value); + if (pos2 == -1) { + throw new HiveException( + "Invalid FFM feature repsentation. Expected <field>:<index>:<value> but got " + fv); } - final short field; if (NumberUtils.isDigits(lead)) { field = parseField(lead, numFields); @@ -306,23 +293,9 @@ public abstract class Feature { final String lead = fv.substring(0, pos1); final String rest = fv.substring(pos1 + 1); final int pos2 = rest.indexOf(':'); - if (pos2 == -1) {// e.g., i1:1.0 (quantitative features) expecting |feature| less than 1024 - final int index; - if (NumberUtils.isDigits(lead)) { - index = parseFeatureIndex(lead); - if (index < 0 || index >= numFields) { - throw new HiveException("Invalid index value '" + index - + "' for a quantative features: " + fv + ", expecting index less than " - + numFields); - } - } else { - index = MurmurHash3.murmurhash3(lead, numFields); - } - short field = NumberUtils.castToShort(index); - probe.setField(field); - probe.setFeatureIndex(index); - probe.value = parseFeatureValue(rest); - return; + if (pos2 == -1) { + throw new HiveException( + "Invalid FFM feature repsentation. Expected <field>:<index>:<value> but got " + fv); } final short field; @@ -392,4 +365,20 @@ public abstract class Feature { return index * numFields + yField; } + public static void l2normalize(@Nonnull final Feature[] features) { + double squaredSum = 0.d; + for (Feature f : features) { + double v = f.value; + squaredSum += (v * v); + } + if (squaredSum == 0.d) { + return; + } + + final double invNorm = 1.d / Math.sqrt(squaredSum); + for (Feature f : features) { + f.value *= invNorm; + } + } + } http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/ad15923a/core/src/main/java/hivemall/fm/FieldAwareFactorizationMachineUDTF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/fm/FieldAwareFactorizationMachineUDTF.java b/core/src/main/java/hivemall/fm/FieldAwareFactorizationMachineUDTF.java index 56d9dc2..d602fd6 100644 --- a/core/src/main/java/hivemall/fm/FieldAwareFactorizationMachineUDTF.java +++ b/core/src/main/java/hivemall/fm/FieldAwareFactorizationMachineUDTF.java @@ -18,13 +18,14 @@ */ package hivemall.fm; -import hivemall.fm.FFMStringFeatureMapModel.EntryIterator; import hivemall.fm.FMHyperParameters.FFMHyperParameters; import hivemall.utils.collections.arrays.DoubleArray3D; import hivemall.utils.collections.lists.IntArrayList; import hivemall.utils.hadoop.HadoopUtils; import hivemall.utils.hadoop.HiveUtils; import hivemall.utils.math.MathUtils; +import it.unimi.dsi.fastutil.ints.Int2LongMap; +import it.unimi.dsi.fastutil.ints.Int2LongMaps; import java.nio.ByteBuffer; import java.util.ArrayList; @@ -173,7 +174,11 @@ public final class FieldAwareFactorizationMachineUDTF extends FactorizationMachi @Override protected Feature[] parseFeatures(@Nonnull final Object arg) throws HiveException { - return Feature.parseFFMFeatures(arg, _xOI, _probes, _numFeatures, _numFields); + Feature[] features = Feature.parseFFMFeatures(arg, _xOI, _probes, _numFeatures, _numFields); + if (_params.l2norm) { + Feature.l2normalize(features); + } + return features; } @Override @@ -288,17 +293,18 @@ public final class FieldAwareFactorizationMachineUDTF extends FactorizationMachi Wi.set(_ffmModel.getW0()); forward(forwardObjs); - final EntryIterator itor = _ffmModel.entries(); - final Entry entryW = itor.getEntryProbeW(); - final Entry entryV = itor.getEntryProbeV(); + final Entry entryW = new Entry(_ffmModel._buf, 1); + final Entry entryV = new Entry(_ffmModel._buf, _ffmModel._factor); final float[] Vf = new float[factors]; - while (itor.next()) { + + for (Int2LongMap.Entry e : Int2LongMaps.fastIterable(_ffmModel._map)) { // set i - int i = itor.getEntryIndex(); + final int i = e.getIntKey(); idx.set(i); + final long offset = e.getLongValue(); if (Entry.isEntryW(i)) {// set Wi - itor.getEntry(entryW); + entryW.setOffset(offset); float w = entryV.getW(); if (w == 0.f) { continue; // skip w_i=0 @@ -307,7 +313,7 @@ public final class FieldAwareFactorizationMachineUDTF extends FactorizationMachi forwardObjs[2] = Wi; forwardObjs[3] = null; } else {// set Vif - itor.getEntry(entryV); + entryV.setOffset(offset); entryV.getV(Vf); for (int f = 0; f < factors; f++) { Vi[f].set(Vf[f]); http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/ad15923a/core/src/main/java/hivemall/fm/IntFeature.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/fm/IntFeature.java b/core/src/main/java/hivemall/fm/IntFeature.java index 64a4daa..881c161 100644 --- a/core/src/main/java/hivemall/fm/IntFeature.java +++ b/core/src/main/java/hivemall/fm/IntFeature.java @@ -18,6 +18,8 @@ */ package hivemall.fm; +import hivemall.utils.lang.SizeOf; + import java.nio.ByteBuffer; import javax.annotation.Nonnegative; @@ -72,7 +74,7 @@ public final class IntFeature extends Feature { @Override public int bytes() { - return (Integer.SIZE + Short.SIZE + Double.SIZE) / Byte.SIZE; + return SizeOf.INT + SizeOf.SHORT + SizeOf.DOUBLE; } @Override @@ -94,7 +96,7 @@ public final class IntFeature extends Feature { if (field == -1) { return index + ":" + value; } else { - return index + ":" + field + ":" + value; + return field + ":" + index + ":" + value; } } http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/ad15923a/core/src/main/java/hivemall/ftvec/ranking/PositiveOnlyFeedback.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/ftvec/ranking/PositiveOnlyFeedback.java b/core/src/main/java/hivemall/ftvec/ranking/PositiveOnlyFeedback.java index cdba00b..aab125f 100644 --- a/core/src/main/java/hivemall/ftvec/ranking/PositiveOnlyFeedback.java +++ b/core/src/main/java/hivemall/ftvec/ranking/PositiveOnlyFeedback.java @@ -19,8 +19,9 @@ package hivemall.ftvec.ranking; import hivemall.utils.collections.lists.IntArrayList; -import hivemall.utils.collections.maps.IntOpenHashTable; -import hivemall.utils.collections.maps.IntOpenHashTable.IMapIterator; +import it.unimi.dsi.fastutil.ints.Int2ObjectMap; +import it.unimi.dsi.fastutil.ints.Int2ObjectOpenHashMap; +import it.unimi.dsi.fastutil.ints.IntIterator; import java.util.BitSet; @@ -30,13 +31,13 @@ import javax.annotation.Nullable; public class PositiveOnlyFeedback { @Nonnull - protected final IntOpenHashTable<IntArrayList> rows; + protected final Int2ObjectMap<IntArrayList> rows; protected int maxItemId; protected int totalFeedbacks; public PositiveOnlyFeedback(int maxItemId) { - this.rows = new IntOpenHashTable<IntArrayList>(1024); + this.rows = new Int2ObjectOpenHashMap<IntArrayList>(1024); this.maxItemId = maxItemId; this.totalFeedbacks = 0; } @@ -61,21 +62,19 @@ public class PositiveOnlyFeedback { public int[] getUsers() { final int size = rows.size(); final int[] keys = new int[size]; - final IMapIterator<IntArrayList> itor = rows.entries(); + final IntIterator itor = rows.keySet().iterator(); for (int i = 0; i < size; i++) { - if (itor.next() == -1) { + if (!itor.hasNext()) { throw new IllegalStateException(); } - int key = itor.getKey(); + int key = itor.nextInt(); keys[i] = key; } return keys; } public void getUsers(@Nonnull final BitSet bitset) { - final IMapIterator<IntArrayList> itor = rows.entries(); - while (itor.next() != -1) { - int key = itor.getKey(); + for (int key : rows.keySet()) { bitset.set(key); } } http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/ad15923a/core/src/main/java/hivemall/math/matrix/builders/ColumnMajorDenseMatrixBuilder.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/math/matrix/builders/ColumnMajorDenseMatrixBuilder.java b/core/src/main/java/hivemall/math/matrix/builders/ColumnMajorDenseMatrixBuilder.java index 9cae1c7..152ac02 100644 --- a/core/src/main/java/hivemall/math/matrix/builders/ColumnMajorDenseMatrixBuilder.java +++ b/core/src/main/java/hivemall/math/matrix/builders/ColumnMajorDenseMatrixBuilder.java @@ -20,8 +20,9 @@ package hivemall.math.matrix.builders; import hivemall.math.matrix.dense.ColumnMajorDenseMatrix2d; import hivemall.utils.collections.arrays.SparseDoubleArray; -import hivemall.utils.collections.maps.IntOpenHashTable; -import hivemall.utils.collections.maps.IntOpenHashTable.IMapIterator; +import it.unimi.dsi.fastutil.ints.Int2ObjectMap; +import it.unimi.dsi.fastutil.ints.Int2ObjectMaps; +import it.unimi.dsi.fastutil.ints.Int2ObjectOpenHashMap; import javax.annotation.Nonnegative; import javax.annotation.Nonnull; @@ -29,13 +30,13 @@ import javax.annotation.Nonnull; public final class ColumnMajorDenseMatrixBuilder extends MatrixBuilder { @Nonnull - private final IntOpenHashTable<SparseDoubleArray> col2rows; + private final Int2ObjectMap<SparseDoubleArray> col2rows; private int row; private int maxNumColumns; private int nnz; public ColumnMajorDenseMatrixBuilder(int initSize) { - this.col2rows = new IntOpenHashTable<SparseDoubleArray>(initSize); + this.col2rows = new Int2ObjectOpenHashMap<SparseDoubleArray>(initSize); this.row = 0; this.maxNumColumns = 0; this.nnz = 0; @@ -68,10 +69,9 @@ public final class ColumnMajorDenseMatrixBuilder extends MatrixBuilder { public ColumnMajorDenseMatrix2d buildMatrix() { final double[][] data = new double[maxNumColumns][]; - final IMapIterator<SparseDoubleArray> itor = col2rows.entries(); - while (itor.next() != -1) { - int col = itor.getKey(); - SparseDoubleArray rows = itor.getValue(); + for (Int2ObjectMap.Entry<SparseDoubleArray> e : Int2ObjectMaps.fastIterable(col2rows)) { + int col = e.getIntKey(); + SparseDoubleArray rows = e.getValue(); data[col] = rows.toArray(); } http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/ad15923a/core/src/main/java/hivemall/mf/FactorizedModel.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/mf/FactorizedModel.java b/core/src/main/java/hivemall/mf/FactorizedModel.java index 1b7140f..2b32dbc 100644 --- a/core/src/main/java/hivemall/mf/FactorizedModel.java +++ b/core/src/main/java/hivemall/mf/FactorizedModel.java @@ -18,8 +18,9 @@ */ package hivemall.mf; -import hivemall.utils.collections.maps.IntOpenHashTable; import hivemall.utils.math.MathUtils; +import it.unimi.dsi.fastutil.ints.Int2ObjectMap; +import it.unimi.dsi.fastutil.ints.Int2ObjectOpenHashMap; import java.util.Random; @@ -42,10 +43,10 @@ public final class FactorizedModel { private int minIndex, maxIndex; @Nonnull private Rating meanRating; - private IntOpenHashTable<Rating[]> users; - private IntOpenHashTable<Rating[]> items; - private IntOpenHashTable<Rating> userBias; - private IntOpenHashTable<Rating> itemBias; + private Int2ObjectMap<Rating[]> users; + private Int2ObjectMap<Rating[]> items; + private Int2ObjectMap<Rating> userBias; + private Int2ObjectMap<Rating> itemBias; private final Random[] randU, randI; @@ -67,10 +68,10 @@ public final class FactorizedModel { this.minIndex = 0; this.maxIndex = 0; this.meanRating = ratingInitializer.newRating(meanRating); - this.users = new IntOpenHashTable<Rating[]>(expectedSize); - this.items = new IntOpenHashTable<Rating[]>(expectedSize); - this.userBias = new IntOpenHashTable<Rating>(expectedSize); - this.itemBias = new IntOpenHashTable<Rating>(expectedSize); + this.users = new Int2ObjectOpenHashMap<Rating[]>(expectedSize); + this.items = new Int2ObjectOpenHashMap<Rating[]>(expectedSize); + this.userBias = new Int2ObjectOpenHashMap<Rating>(expectedSize); + this.itemBias = new Int2ObjectOpenHashMap<Rating>(expectedSize); this.randU = newRandoms(factor, 31L); this.randI = newRandoms(factor, 41L); } @@ -105,6 +106,7 @@ public final class FactorizedModel { } + @Nonnull private static Random[] newRandoms(@Nonnull final int size, final long seed) { final Random[] rand = new Random[size]; for (int i = 0, len = rand.length; i < len; i++) { @@ -130,17 +132,17 @@ public final class FactorizedModel { return meanRating.getWeight(); } - public void setMeanRating(float rating) { + public void setMeanRating(final float rating) { meanRating.setWeight(rating); } @Nullable - public Rating[] getUserVector(int u) { + public Rating[] getUserVector(final int u) { return getUserVector(u, false); } @Nullable - public Rating[] getUserVector(int u, boolean init) { + public Rating[] getUserVector(final int u, final boolean init) { Rating[] v = users.get(u); if (init && v == null) { v = new Rating[factor]; @@ -164,7 +166,7 @@ public final class FactorizedModel { } @Nullable - public Rating[] getItemVector(int i) { + public Rating[] getItemVector(final int i) { return getItemVector(i, false); } @@ -193,7 +195,7 @@ public final class FactorizedModel { } @Nonnull - public Rating userBias(int u) { + public Rating userBias(final int u) { Rating b = userBias.get(u); if (b == null) { b = ratingInitializer.newRating(0.f); // dummy @@ -202,15 +204,15 @@ public final class FactorizedModel { return b; } - public float getUserBias(int u) { - Rating b = userBias.get(u); + public float getUserBias(final int u) { + final Rating b = userBias.get(u); if (b == null) { return 0.f; } return b.getWeight(); } - public void setUserBias(int u, float value) { + public void setUserBias(final int u, final float value) { Rating b = userBias.get(u); if (b == null) { b = ratingInitializer.newRating(value); @@ -220,7 +222,7 @@ public final class FactorizedModel { } @Nonnull - public Rating itemBias(int i) { + public Rating itemBias(final int i) { Rating b = itemBias.get(i); if (b == null) { b = ratingInitializer.newRating(0.f); // dummy @@ -230,19 +232,19 @@ public final class FactorizedModel { } @Nullable - public Rating getItemBiasObject(int i) { + public Rating getItemBiasObject(final int i) { return itemBias.get(i); } - public float getItemBias(int i) { - Rating b = itemBias.get(i); + public float getItemBias(final int i) { + final Rating b = itemBias.get(i); if (b == null) { return 0.f; } return b.getWeight(); } - public void setItemBias(int i, float value) { + public void setItemBias(final int i, final float value) { Rating b = itemBias.get(i); if (b == null) { b = ratingInitializer.newRating(value); http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/ad15923a/core/src/main/java/hivemall/model/AbstractPredictionModel.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/model/AbstractPredictionModel.java b/core/src/main/java/hivemall/model/AbstractPredictionModel.java index cd298a7..be7b2e5 100644 --- a/core/src/main/java/hivemall/model/AbstractPredictionModel.java +++ b/core/src/main/java/hivemall/model/AbstractPredictionModel.java @@ -22,8 +22,10 @@ import hivemall.annotations.InternalAPI; import hivemall.mix.MixedWeight; import hivemall.mix.MixedWeight.WeightWithCovar; import hivemall.mix.MixedWeight.WeightWithDelta; -import hivemall.utils.collections.maps.IntOpenHashTable; -import hivemall.utils.collections.maps.OpenHashMap; +import it.unimi.dsi.fastutil.ints.Int2ObjectMap; +import it.unimi.dsi.fastutil.ints.Int2ObjectOpenHashMap; +import it.unimi.dsi.fastutil.objects.Object2ObjectMap; +import it.unimi.dsi.fastutil.objects.Object2ObjectOpenHashMap; import javax.annotation.Nonnull; import javax.annotation.Nullable; @@ -37,8 +39,8 @@ public abstract class AbstractPredictionModel implements PredictionModel { private long numMixed; private boolean cancelMixRequest; - private IntOpenHashTable<MixedWeight> mixedRequests_i; - private OpenHashMap<Object, MixedWeight> mixedRequests_o; + private Int2ObjectMap<MixedWeight> mixedRequests_i; + private Object2ObjectMap<Object, MixedWeight> mixedRequests_o; public AbstractPredictionModel() { this.numMixed = 0L; @@ -58,9 +60,9 @@ public abstract class AbstractPredictionModel implements PredictionModel { this.cancelMixRequest = cancelMixRequest; if (cancelMixRequest) { if (isDenseModel()) { - this.mixedRequests_i = new IntOpenHashTable<MixedWeight>(327680); + this.mixedRequests_i = new Int2ObjectOpenHashMap<MixedWeight>(327680); } else { - this.mixedRequests_o = new OpenHashMap<Object, MixedWeight>(327680); + this.mixedRequests_o = new Object2ObjectOpenHashMap<Object, MixedWeight>(327680); } } } http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/ad15923a/core/src/main/java/hivemall/model/NewSparseModel.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/model/NewSparseModel.java b/core/src/main/java/hivemall/model/NewSparseModel.java index 5c0a6c7..99034e0 100644 --- a/core/src/main/java/hivemall/model/NewSparseModel.java +++ b/core/src/main/java/hivemall/model/NewSparseModel.java @@ -23,7 +23,7 @@ import hivemall.model.WeightValueWithClock.WeightValueParamsF2Clock; import hivemall.model.WeightValueWithClock.WeightValueParamsF3Clock; import hivemall.model.WeightValueWithClock.WeightValueWithCovarClock; import hivemall.utils.collections.IMapIterator; -import hivemall.utils.collections.maps.OpenHashMap; +import hivemall.utils.collections.maps.OpenHashTable; import javax.annotation.Nonnull; @@ -34,7 +34,7 @@ public final class NewSparseModel extends AbstractPredictionModel { private static final Log logger = LogFactory.getLog(NewSparseModel.class); @Nonnull - private final OpenHashMap<Object, IWeightValue> weights; + private final OpenHashTable<Object, IWeightValue> weights; private final boolean hasCovar; private boolean clockEnabled; @@ -44,7 +44,7 @@ public final class NewSparseModel extends AbstractPredictionModel { public NewSparseModel(int size, boolean hasCovar) { super(); - this.weights = new OpenHashMap<Object, IWeightValue>(size); + this.weights = new OpenHashTable<Object, IWeightValue>(size); this.hasCovar = hasCovar; this.clockEnabled = false; } http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/ad15923a/core/src/main/java/hivemall/model/SparseModel.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/model/SparseModel.java b/core/src/main/java/hivemall/model/SparseModel.java index 65e751d..8028cab 100644 --- a/core/src/main/java/hivemall/model/SparseModel.java +++ b/core/src/main/java/hivemall/model/SparseModel.java @@ -23,7 +23,7 @@ import hivemall.model.WeightValueWithClock.WeightValueParamsF2Clock; import hivemall.model.WeightValueWithClock.WeightValueParamsF3Clock; import hivemall.model.WeightValueWithClock.WeightValueWithCovarClock; import hivemall.utils.collections.IMapIterator; -import hivemall.utils.collections.maps.OpenHashMap; +import hivemall.utils.collections.maps.OpenHashTable; import javax.annotation.Nonnull; @@ -34,13 +34,13 @@ public final class SparseModel extends AbstractPredictionModel { private static final Log logger = LogFactory.getLog(SparseModel.class); @Nonnull - private final OpenHashMap<Object, IWeightValue> weights; + private final OpenHashTable<Object, IWeightValue> weights; private final boolean hasCovar; private boolean clockEnabled; public SparseModel(int size, boolean hasCovar) { super(); - this.weights = new OpenHashMap<Object, IWeightValue>(size); + this.weights = new OpenHashTable<Object, IWeightValue>(size); this.hasCovar = hasCovar; this.clockEnabled = false; } http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/ad15923a/core/src/main/java/hivemall/optimizer/DenseOptimizerFactory.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/optimizer/DenseOptimizerFactory.java b/core/src/main/java/hivemall/optimizer/DenseOptimizerFactory.java index 775d7d0..37c8f7b 100644 --- a/core/src/main/java/hivemall/optimizer/DenseOptimizerFactory.java +++ b/core/src/main/java/hivemall/optimizer/DenseOptimizerFactory.java @@ -26,6 +26,7 @@ import hivemall.utils.math.MathUtils; import java.util.Arrays; import java.util.Map; +import javax.annotation.Nonnegative; import javax.annotation.Nonnull; import javax.annotation.concurrent.NotThreadSafe; @@ -36,7 +37,8 @@ public final class DenseOptimizerFactory { private static final Log LOG = LogFactory.getLog(DenseOptimizerFactory.class); @Nonnull - public static Optimizer create(int ndims, @Nonnull Map<String, String> options) { + public static Optimizer create(@Nonnegative final int ndims, + @Nonnull final Map<String, String> options) { final String optimizerName = options.get("optimizer"); if (optimizerName == null) { throw new IllegalArgumentException("`optimizer` not defined"); http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/ad15923a/core/src/main/java/hivemall/optimizer/Optimizer.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/optimizer/Optimizer.java b/core/src/main/java/hivemall/optimizer/Optimizer.java index 0f82833..7bf1e84 100644 --- a/core/src/main/java/hivemall/optimizer/Optimizer.java +++ b/core/src/main/java/hivemall/optimizer/Optimizer.java @@ -90,7 +90,7 @@ public interface Optimizer { private final IWeightValue weightValueReused; - public SGD(final Map<String, String> options) { + public SGD(@Nonnull Map<String, String> options) { super(options); this.weightValueReused = new WeightValue(0.f); } @@ -114,7 +114,7 @@ public interface Optimizer { private final float eps; private final float scale; - public AdaGrad(Map<String, String> options) { + public AdaGrad(@Nonnull Map<String, String> options) { super(options); this.eps = Primitives.parseFloat(options.get("eps"), 1.0f); this.scale = Primitives.parseFloat(options.get("scale"), 100.0f); @@ -141,7 +141,7 @@ public interface Optimizer { private final float eps; private final float scale; - public AdaDelta(Map<String, String> options) { + public AdaDelta(@Nonnull Map<String, String> options) { super(options); this.decay = Primitives.parseFloat(options.get("decay"), 0.95f); this.eps = Primitives.parseFloat(options.get("eps"), 1e-6f); @@ -184,7 +184,7 @@ public interface Optimizer { private final float gamma; private final float eps_hat; - public Adam(Map<String, String> options) { + public Adam(@Nonnull Map<String, String> options) { super(options); this.beta = Primitives.parseFloat(options.get("beta"), 0.9f); this.gamma = Primitives.parseFloat(options.get("gamma"), 0.999f); http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/ad15923a/core/src/main/java/hivemall/optimizer/SparseOptimizerFactory.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/optimizer/SparseOptimizerFactory.java b/core/src/main/java/hivemall/optimizer/SparseOptimizerFactory.java index 12e0d71..153215d 100644 --- a/core/src/main/java/hivemall/optimizer/SparseOptimizerFactory.java +++ b/core/src/main/java/hivemall/optimizer/SparseOptimizerFactory.java @@ -20,10 +20,12 @@ package hivemall.optimizer; import hivemall.model.IWeightValue; import hivemall.model.WeightValue; -import hivemall.utils.collections.maps.OpenHashMap; +import it.unimi.dsi.fastutil.objects.Object2ObjectMap; +import it.unimi.dsi.fastutil.objects.Object2ObjectOpenHashMap; import java.util.Map; +import javax.annotation.Nonnegative; import javax.annotation.Nonnull; import javax.annotation.concurrent.NotThreadSafe; @@ -34,7 +36,8 @@ public final class SparseOptimizerFactory { private static final Log LOG = LogFactory.getLog(SparseOptimizerFactory.class); @Nonnull - public static Optimizer create(int ndims, @Nonnull Map<String, String> options) { + public static Optimizer create(@Nonnull final int ndims, + @Nonnull final Map<String, String> options) { final String optimizerName = options.get("optimizer"); if (optimizerName == null) { throw new IllegalArgumentException("`optimizer` not defined"); @@ -78,11 +81,11 @@ public final class SparseOptimizerFactory { static final class AdaDelta extends Optimizer.AdaDelta { @Nonnull - private final OpenHashMap<Object, IWeightValue> auxWeights; + private final Object2ObjectMap<Object, IWeightValue> auxWeights; - public AdaDelta(int size, Map<String, String> options) { + public AdaDelta(@Nonnegative int size, @Nonnull Map<String, String> options) { super(options); - this.auxWeights = new OpenHashMap<Object, IWeightValue>(size); + this.auxWeights = new Object2ObjectOpenHashMap<Object, IWeightValue>(size); } @Override @@ -103,11 +106,11 @@ public final class SparseOptimizerFactory { static final class AdaGrad extends Optimizer.AdaGrad { @Nonnull - private final OpenHashMap<Object, IWeightValue> auxWeights; + private final Object2ObjectMap<Object, IWeightValue> auxWeights; - public AdaGrad(int size, Map<String, String> options) { + public AdaGrad(@Nonnegative int size, @Nonnull Map<String, String> options) { super(options); - this.auxWeights = new OpenHashMap<Object, IWeightValue>(size); + this.auxWeights = new Object2ObjectOpenHashMap<Object, IWeightValue>(size); } @Override @@ -128,11 +131,11 @@ public final class SparseOptimizerFactory { static final class Adam extends Optimizer.Adam { @Nonnull - private final OpenHashMap<Object, IWeightValue> auxWeights; + private final Object2ObjectMap<Object, IWeightValue> auxWeights; - public Adam(int size, Map<String, String> options) { + public Adam(@Nonnegative int size, @Nonnull Map<String, String> options) { super(options); - this.auxWeights = new OpenHashMap<Object, IWeightValue>(size); + this.auxWeights = new Object2ObjectOpenHashMap<Object, IWeightValue>(size); } @Override @@ -153,12 +156,12 @@ public final class SparseOptimizerFactory { static final class AdagradRDA extends Optimizer.AdagradRDA { @Nonnull - private final OpenHashMap<Object, IWeightValue> auxWeights; + private final Object2ObjectMap<Object, IWeightValue> auxWeights; - public AdagradRDA(int size, @Nonnull Optimizer.AdaGrad optimizerImpl, + public AdagradRDA(@Nonnegative int size, @Nonnull Optimizer.AdaGrad optimizerImpl, @Nonnull Map<String, String> options) { super(optimizerImpl, options); - this.auxWeights = new OpenHashMap<Object, IWeightValue>(size); + this.auxWeights = new Object2ObjectOpenHashMap<Object, IWeightValue>(size); } @Override http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/ad15923a/core/src/main/java/hivemall/recommend/SlimUDTF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/recommend/SlimUDTF.java b/core/src/main/java/hivemall/recommend/SlimUDTF.java index c6363f1..395221f 100644 --- a/core/src/main/java/hivemall/recommend/SlimUDTF.java +++ b/core/src/main/java/hivemall/recommend/SlimUDTF.java @@ -24,9 +24,6 @@ import hivemall.common.ConversionState; import hivemall.math.matrix.FloatMatrix; import hivemall.math.matrix.sparse.floats.DoKFloatMatrix; import hivemall.math.vector.VectorProcedure; -import hivemall.utils.collections.maps.Int2FloatOpenHashTable; -import hivemall.utils.collections.maps.IntOpenHashTable; -import hivemall.utils.collections.maps.IntOpenHashTable.IMapIterator; import hivemall.utils.hadoop.HiveUtils; import hivemall.utils.io.FileUtils; import hivemall.utils.io.NioStatefullSegment; @@ -36,16 +33,22 @@ import hivemall.utils.lang.SizeOf; import hivemall.utils.lang.mutable.MutableDouble; import hivemall.utils.lang.mutable.MutableInt; import hivemall.utils.lang.mutable.MutableObject; +import it.unimi.dsi.fastutil.ints.Int2FloatMap; +import it.unimi.dsi.fastutil.ints.Int2FloatMaps; +import it.unimi.dsi.fastutil.ints.Int2FloatOpenHashMap; +import it.unimi.dsi.fastutil.ints.Int2ObjectMap; +import it.unimi.dsi.fastutil.ints.Int2ObjectMaps; +import it.unimi.dsi.fastutil.ints.Int2ObjectOpenHashMap; +import it.unimi.dsi.fastutil.ints.IntOpenHashSet; +import it.unimi.dsi.fastutil.ints.IntSet; import java.io.File; import java.io.IOException; import java.nio.ByteBuffer; import java.util.ArrayList; import java.util.Arrays; -import java.util.HashSet; import java.util.List; import java.util.Map; -import java.util.Set; import javax.annotation.Nonnull; import javax.annotation.Nullable; @@ -122,9 +125,9 @@ public class SlimUDTF extends UDTFWithOptions { private int _previousItemId; @Nullable - private transient Int2FloatOpenHashTable _ri; + private transient Int2FloatMap _ri; @Nullable - private transient IntOpenHashTable<Int2FloatOpenHashTable> _kNNi; + private transient Int2ObjectMap<Int2FloatMap> _kNNi; /** The number of elements in kNNi */ @Nullable private transient MutableInt _nnzKNNi; @@ -290,15 +293,14 @@ public class SlimUDTF extends UDTFWithOptions { } int itemJ = PrimitiveObjectInspectorUtils.getInt(args[3], itemJOI); - Int2FloatOpenHashTable rj = int2floatMap(itemJ, rjOI.getMap(args[4]), rjKeyOI, rjValueOI, - _dataMatrix); + Int2FloatMap rj = int2floatMap(itemJ, rjOI.getMap(args[4]), rjKeyOI, rjValueOI, _dataMatrix); train(itemI, _ri, _kNNi, itemJ, rj); _observedTrainingExamples++; } private void recordTrainingInput(final int itemI, - @Nonnull final IntOpenHashTable<Int2FloatOpenHashTable> knnItems, final int numKNNItems) + @Nonnull final Int2ObjectMap<Int2FloatMap> knnItems, final int numKNNItems) throws HiveException { ByteBuffer buf = this._inputBuf; NioStatefullSegment dst = this._fileIO; @@ -334,17 +336,15 @@ public class SlimUDTF extends UDTFWithOptions { buf.putInt(itemI); buf.putInt(knnItems.size()); - final IMapIterator<Int2FloatOpenHashTable> entries = knnItems.entries(); - while (entries.next() != -1) { - int user = entries.getKey(); + for (Int2ObjectMap.Entry<Int2FloatMap> e1 : Int2ObjectMaps.fastIterable(knnItems)) { + int user = e1.getIntKey(); buf.putInt(user); - Int2FloatOpenHashTable ru = entries.getValue(); + Int2FloatMap ru = e1.getValue(); buf.putInt(ru.size()); - final Int2FloatOpenHashTable.IMapIterator itor = ru.entries(); - while (itor.next() != -1) { - buf.putInt(itor.getKey()); - buf.putFloat(itor.getValue()); + for (Int2FloatMap.Entry e2 : Int2FloatMaps.fastIterable(ru)) { + buf.putInt(e2.getIntKey()); + buf.putFloat(e2.getFloatValue()); } } } @@ -360,9 +360,9 @@ public class SlimUDTF extends UDTFWithOptions { srcBuf.clear(); } - private void train(final int itemI, @Nonnull final Int2FloatOpenHashTable ri, - @Nonnull final IntOpenHashTable<Int2FloatOpenHashTable> kNNi, final int itemJ, - @Nonnull final Int2FloatOpenHashTable rj) { + private void train(final int itemI, @Nonnull final Int2FloatMap ri, + @Nonnull final Int2ObjectMap<Int2FloatMap> kNNi, final int itemJ, + @Nonnull final Int2FloatMap rj) { final FloatMatrix W = _weightMatrix; final int N = rj.size(); @@ -374,11 +374,10 @@ public class SlimUDTF extends UDTFWithOptions { double rateSum = 0.d; double lossSum = 0.d; - final Int2FloatOpenHashTable.IMapIterator itor = rj.entries(); - while (itor.next() != -1) { - int user = itor.getKey(); - double ruj = itor.getValue(); - double rui = ri.get(user, 0.f); + for (Int2FloatMap.Entry e : Int2FloatMaps.fastIterable(rj)) { + int user = e.getIntKey(); + double ruj = e.getFloatValue(); + double rui = ri.getOrDefault(user, 0.f); double eui = rui - predict(user, itemI, kNNi, itemJ, W); gradSum += ruj * eui; @@ -396,8 +395,8 @@ public class SlimUDTF extends UDTFWithOptions { W.set(itemI, itemJ, getUpdateTerm(gradSum, rateSum, l1, l2)); } - private void train(final int itemI, - @Nonnull final IntOpenHashTable<Int2FloatOpenHashTable> knnItems, final int itemJ) { + private void train(final int itemI, @Nonnull final Int2ObjectMap<Int2FloatMap> knnItems, + final int itemJ) { final FloatMatrix A = _dataMatrix; final FloatMatrix W = _weightMatrix; @@ -433,21 +432,20 @@ public class SlimUDTF extends UDTFWithOptions { } private static double predict(final int user, final int itemI, - @Nonnull final IntOpenHashTable<Int2FloatOpenHashTable> knnItems, - final int excludeIndex, @Nonnull final FloatMatrix weightMatrix) { - final Int2FloatOpenHashTable kNNu = knnItems.get(user); + @Nonnull final Int2ObjectMap<Int2FloatMap> knnItems, final int excludeIndex, + @Nonnull final FloatMatrix weightMatrix) { + final Int2FloatMap kNNu = knnItems.get(user); if (kNNu == null) { return 0.d; } double pred = 0.d; - final Int2FloatOpenHashTable.IMapIterator itor = kNNu.entries(); - while (itor.next() != -1) { - final int itemK = itor.getKey(); + for (Int2FloatMap.Entry e : Int2FloatMaps.fastIterable(kNNu)) { + final int itemK = e.getIntKey(); if (itemK == excludeIndex) { continue; } - float ruk = itor.getValue(); + float ruk = e.getFloatValue(); pred += ruk * weightMatrix.get(itemI, itemK, 0.d); } return pred; @@ -624,12 +622,12 @@ public class SlimUDTF extends UDTFWithOptions { final int itemI = buf.getInt(); final int knnSize = buf.getInt(); - final IntOpenHashTable<Int2FloatOpenHashTable> knnItems = new IntOpenHashTable<>(1024); - final Set<Integer> pairItems = new HashSet<>(); + final Int2ObjectMap<Int2FloatMap> knnItems = new Int2ObjectOpenHashMap<>(1024); + final IntSet pairItems = new IntOpenHashSet(); for (int i = 0; i < knnSize; i++) { int user = buf.getInt(); int ruSize = buf.getInt(); - Int2FloatOpenHashTable ru = new Int2FloatOpenHashTable(ruSize); + Int2FloatMap ru = new Int2FloatOpenHashMap(ruSize); ru.defaultReturnValue(0.f); for (int j = 0; j < ruSize; j++) { @@ -677,16 +675,15 @@ public class SlimUDTF extends UDTFWithOptions { } @Nonnull - private static IntOpenHashTable<Int2FloatOpenHashTable> kNNentries( - @Nonnull final Object kNNiObj, @Nonnull final MapObjectInspector knnItemsOI, + private static Int2ObjectMap<Int2FloatMap> kNNentries(@Nonnull final Object kNNiObj, + @Nonnull final MapObjectInspector knnItemsOI, @Nonnull final PrimitiveObjectInspector knnItemsKeyOI, @Nonnull final MapObjectInspector knnItemsValueOI, @Nonnull final PrimitiveObjectInspector knnItemsValueKeyOI, @Nonnull final PrimitiveObjectInspector knnItemsValueValueOI, - @Nullable IntOpenHashTable<Int2FloatOpenHashTable> knnItems, - @Nonnull final MutableInt nnzKNNi) { + @Nullable Int2ObjectMap<Int2FloatMap> knnItems, @Nonnull final MutableInt nnzKNNi) { if (knnItems == null) { - knnItems = new IntOpenHashTable<>(1024); + knnItems = new Int2ObjectOpenHashMap<>(1024); } else { knnItems.clear(); } @@ -694,7 +691,7 @@ public class SlimUDTF extends UDTFWithOptions { int numElementOfKNNItems = 0; for (Map.Entry<?, ?> entry : knnItemsOI.getMap(kNNiObj).entrySet()) { int user = PrimitiveObjectInspectorUtils.getInt(entry.getKey(), knnItemsKeyOI); - Int2FloatOpenHashTable ru = int2floatMap(knnItemsValueOI.getMap(entry.getValue()), + Int2FloatMap ru = int2floatMap(knnItemsValueOI.getMap(entry.getValue()), knnItemsValueKeyOI, knnItemsValueValueOI); knnItems.put(user, ru); numElementOfKNNItems += ru.size(); @@ -705,10 +702,10 @@ public class SlimUDTF extends UDTFWithOptions { } @Nonnull - private static Int2FloatOpenHashTable int2floatMap(@Nonnull final Map<?, ?> map, + private static Int2FloatMap int2floatMap(@Nonnull final Map<?, ?> map, @Nonnull final PrimitiveObjectInspector keyOI, @Nonnull final PrimitiveObjectInspector valueOI) { - final Int2FloatOpenHashTable result = new Int2FloatOpenHashTable(map.size()); + final Int2FloatMap result = new Int2FloatOpenHashMap(map.size()); result.defaultReturnValue(0.f); for (Map.Entry<?, ?> entry : map.entrySet()) { @@ -724,19 +721,19 @@ public class SlimUDTF extends UDTFWithOptions { } @Nonnull - private static Int2FloatOpenHashTable int2floatMap(final int item, - @Nonnull final Map<?, ?> map, @Nonnull final PrimitiveObjectInspector keyOI, + private static Int2FloatMap int2floatMap(final int item, @Nonnull final Map<?, ?> map, + @Nonnull final PrimitiveObjectInspector keyOI, @Nonnull final PrimitiveObjectInspector valueOI, @Nullable final FloatMatrix dataMatrix) { return int2floatMap(item, map, keyOI, valueOI, dataMatrix, null); } @Nonnull - private static Int2FloatOpenHashTable int2floatMap(final int item, - @Nonnull final Map<?, ?> map, @Nonnull final PrimitiveObjectInspector keyOI, + private static Int2FloatMap int2floatMap(final int item, @Nonnull final Map<?, ?> map, + @Nonnull final PrimitiveObjectInspector keyOI, @Nonnull final PrimitiveObjectInspector valueOI, - @Nullable final FloatMatrix dataMatrix, @Nullable Int2FloatOpenHashTable dst) { + @Nullable final FloatMatrix dataMatrix, @Nullable Int2FloatMap dst) { if (dst == null) { - dst = new Int2FloatOpenHashTable(map.size()); + dst = new Int2FloatOpenHashMap(map.size()); dst.defaultReturnValue(0.f); } else { dst.clear(); http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/ad15923a/core/src/main/java/hivemall/tools/mapred/DistributedCacheLookupUDF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/tools/mapred/DistributedCacheLookupUDF.java b/core/src/main/java/hivemall/tools/mapred/DistributedCacheLookupUDF.java index 366b74b..2794476 100644 --- a/core/src/main/java/hivemall/tools/mapred/DistributedCacheLookupUDF.java +++ b/core/src/main/java/hivemall/tools/mapred/DistributedCacheLookupUDF.java @@ -19,10 +19,11 @@ package hivemall.tools.mapred; import hivemall.ftvec.ExtractFeatureUDF; -import hivemall.utils.collections.maps.OpenHashMap; import hivemall.utils.hadoop.HadoopUtils; import hivemall.utils.hadoop.HiveUtils; import hivemall.utils.io.IOUtils; +import it.unimi.dsi.fastutil.objects.Object2ObjectMap; +import it.unimi.dsi.fastutil.objects.Object2ObjectOpenHashMap; import java.io.BufferedReader; import java.io.File; @@ -66,7 +67,7 @@ public final class DistributedCacheLookupUDF extends GenericUDF { private ListObjectInspector keysInputOI; private ListObjectInspector valuesInputOI; - private OpenHashMap<Object, Object> cache; + private Object2ObjectMap<Object, Object> cache; @Override public ObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException { @@ -125,7 +126,8 @@ public final class DistributedCacheLookupUDF extends GenericUDF { "parseKey=true is only available for string typed key(s)"); } - final OpenHashMap<Object, Object> map = new OpenHashMap<Object, Object>(8192); + final Object2ObjectMap<Object, Object> map = new Object2ObjectOpenHashMap<Object, Object>( + 8192); try { loadValues(map, new File(filepath), keyInputOI, valueInputOI); this.cache = map; @@ -138,7 +140,7 @@ public final class DistributedCacheLookupUDF extends GenericUDF { return outputOI; } - private static void loadValues(OpenHashMap<Object, Object> map, File file, + private static void loadValues(Object2ObjectMap<Object, Object> map, File file, PrimitiveObjectInspector keyOI, PrimitiveObjectInspector valueOI) throws IOException, SerDeException { if (!file.exists()) {
