Close #51: [HIVEMALL-75] Support Sparse Vector Format as the input of RandomForest
Project: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/commit/8dc3a024 Tree: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/tree/8dc3a024 Diff: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/diff/8dc3a024 Branch: refs/heads/master Commit: 8dc3a024d9b2708f297a886a3256e7107bc276f9 Parents: 7956b5f Author: myui <m...@apache.org> Authored: Mon Apr 10 06:31:49 2017 +0900 Committer: myui <yuin...@gmail.com> Committed: Mon Apr 10 06:31:49 2017 +0900 ---------------------------------------------------------------------- core/pom.xml | 10 +- .../java/hivemall/annotations/Immutable.java | 34 + .../main/java/hivemall/annotations/Mutable.java | 36 ++ .../KernelExpansionPassiveAggressiveUDTF.java | 4 +- .../java/hivemall/common/ReservoirSampler.java | 100 --- .../java/hivemall/fm/FFMPredictionModel.java | 4 +- .../hivemall/fm/FFMStringFeatureMapModel.java | 2 +- .../java/hivemall/fm/FMIntFeatureMapModel.java | 4 +- .../hivemall/fm/FMStringFeatureMapModel.java | 2 +- .../fm/FieldAwareFactorizationMachineModel.java | 4 +- .../fm/FieldAwareFactorizationMachineUDTF.java | 4 +- .../hivemall/ftvec/ranking/BprSamplingUDTF.java | 2 +- .../ranking/PerEventPositiveOnlyFeedback.java | 2 +- .../ftvec/ranking/PositiveOnlyFeedback.java | 6 +- .../hivemall/math/matrix/AbstractMatrix.java | 105 +++ .../hivemall/math/matrix/ColumnMajorMatrix.java | 59 ++ .../main/java/hivemall/math/matrix/Matrix.java | 127 ++++ .../java/hivemall/math/matrix/MatrixUtils.java | 73 +++ .../hivemall/math/matrix/RowMajorMatrix.java | 69 ++ .../math/matrix/builders/CSCMatrixBuilder.java | 121 ++++ .../math/matrix/builders/CSRMatrixBuilder.java | 77 +++ .../builders/ColumnMajorDenseMatrixBuilder.java | 81 +++ .../math/matrix/builders/DoKMatrixBuilder.java | 56 ++ .../math/matrix/builders/MatrixBuilder.java | 91 +++ .../builders/RowMajorDenseMatrixBuilder.java | 79 +++ .../matrix/dense/ColumnMajorDenseMatrix2d.java 
| 300 +++++++++ .../matrix/dense/RowMajorDenseMatrix2d.java | 349 ++++++++++ .../math/matrix/ints/AbstractIntMatrix.java | 112 ++++ .../ints/ColumnMajorDenseIntMatrix2d.java | 172 +++++ .../math/matrix/ints/ColumnMajorIntMatrix.java | 39 ++ .../hivemall/math/matrix/ints/DoKIntMatrix.java | 277 ++++++++ .../hivemall/math/matrix/ints/IntMatrix.java | 104 +++ .../hivemall/math/matrix/sparse/CSCMatrix.java | 289 +++++++++ .../hivemall/math/matrix/sparse/CSRMatrix.java | 282 ++++++++ .../hivemall/math/matrix/sparse/DoKMatrix.java | 332 ++++++++++ .../hivemall/math/random/CommonsMathRandom.java | 63 ++ .../java/hivemall/math/random/JavaRandom.java | 61 ++ .../main/java/hivemall/math/random/PRNG.java | 39 ++ .../random/RandomNumberGeneratorFactory.java | 103 +++ .../java/hivemall/math/random/SmileRandom.java | 63 ++ .../hivemall/math/vector/AbstractVector.java | 44 ++ .../java/hivemall/math/vector/DenseVector.java | 90 +++ .../java/hivemall/math/vector/SparseVector.java | 76 +++ .../main/java/hivemall/math/vector/Vector.java | 46 ++ .../hivemall/math/vector/VectorProcedure.java | 33 + .../java/hivemall/matrix/CSRMatrixBuilder.java | 83 --- .../hivemall/matrix/DenseMatrixBuilder.java | 79 --- core/src/main/java/hivemall/matrix/Matrix.java | 92 --- .../java/hivemall/matrix/MatrixBuilder.java | 89 --- .../java/hivemall/matrix/ReadOnlyCSRMatrix.java | 135 ---- .../hivemall/matrix/ReadOnlyDenseMatrix2d.java | 102 --- .../main/java/hivemall/mf/FactorizedModel.java | 2 +- .../hivemall/model/AbstractPredictionModel.java | 4 +- .../main/java/hivemall/model/SparseModel.java | 2 +- .../src/main/java/hivemall/smile/ModelType.java | 85 --- .../smile/classification/DecisionTree.java | 495 +++++++------- .../GradientTreeBoostingClassifierUDTF.java | 228 +++---- .../smile/classification/PredictionHandler.java | 27 + .../RandomForestClassifierUDTF.java | 366 +++++++---- .../java/hivemall/smile/data/Attribute.java | 44 +- .../regression/RandomForestRegressionUDTF.java | 211 +++--- 
.../smile/regression/RegressionTree.java | 377 ++++++----- .../smile/tools/RandomForestEnsembleUDAF.java | 328 +++++++--- .../hivemall/smile/tools/TreePredictUDF.java | 407 +++++------- .../hivemall/smile/utils/SmileExtUtils.java | 215 +++++-- .../main/java/hivemall/smile/vm/Operation.java | 52 -- .../java/hivemall/smile/vm/StackMachine.java | 300 --------- .../hivemall/smile/vm/VMRuntimeException.java | 32 - .../tools/mapred/DistributedCacheLookupUDF.java | 2 +- .../hivemall/utils/collections/DoubleArray.java | 43 -- .../utils/collections/DoubleArray3D.java | 147 ----- .../utils/collections/DoubleArrayList.java | 168 ----- .../utils/collections/FixedIntArray.java | 87 --- .../utils/collections/FloatArrayList.java | 152 ----- .../collections/Int2FloatOpenHashTable.java | 418 ------------ .../utils/collections/Int2IntOpenHashTable.java | 414 ------------ .../collections/Int2LongOpenHashTable.java | 500 -------------- .../hivemall/utils/collections/IntArray.java | 43 -- .../utils/collections/IntArrayList.java | 183 ------ .../utils/collections/IntOpenHashMap.java | 467 -------------- .../utils/collections/IntOpenHashTable.java | 338 ---------- .../java/hivemall/utils/collections/LRUMap.java | 41 -- .../hivemall/utils/collections/OpenHashMap.java | 350 ---------- .../utils/collections/OpenHashTable.java | 412 ------------ .../utils/collections/SparseDoubleArray.java | 213 ------ .../utils/collections/SparseIntArray.java | 210 ------ .../collections/arrays/DenseDoubleArray.java | 92 +++ .../utils/collections/arrays/DenseIntArray.java | 92 +++ .../utils/collections/arrays/DoubleArray.java | 45 ++ .../utils/collections/arrays/DoubleArray3D.java | 147 +++++ .../utils/collections/arrays/FloatArray.java | 45 ++ .../utils/collections/arrays/IntArray.java | 45 ++ .../collections/arrays/SparseDoubleArray.java | 223 +++++++ .../collections/arrays/SparseFloatArray.java | 210 ++++++ .../collections/arrays/SparseIntArray.java | 211 ++++++ .../collections/lists/DoubleArrayList.java 
| 164 +++++ .../utils/collections/lists/FloatArrayList.java | 162 +++++ .../utils/collections/lists/IntArrayList.java | 179 ++++++ .../utils/collections/lists/LongArrayList.java | 166 +++++ .../maps/Int2FloatOpenHashTable.java | 418 ++++++++++++ .../collections/maps/Int2IntOpenHashTable.java | 414 ++++++++++++ .../collections/maps/Int2LongOpenHashTable.java | 500 ++++++++++++++ .../utils/collections/maps/IntOpenHashMap.java | 467 ++++++++++++++ .../collections/maps/IntOpenHashTable.java | 404 ++++++++++++ .../hivemall/utils/collections/maps/LRUMap.java | 41 ++ .../maps/Long2DoubleOpenHashTable.java | 445 +++++++++++++ .../maps/Long2FloatOpenHashTable.java | 429 ++++++++++++ .../collections/maps/Long2IntOpenHashTable.java | 473 ++++++++++++++ .../utils/collections/maps/OpenHashMap.java | 351 ++++++++++ .../utils/collections/maps/OpenHashTable.java | 413 ++++++++++++ .../utils/collections/sets/IntArraySet.java | 88 +++ .../hivemall/utils/collections/sets/IntSet.java | 38 ++ .../java/hivemall/utils/hadoop/HiveUtils.java | 69 +- .../java/hivemall/utils/lang/ArrayUtils.java | 407 +++++++++++- .../java/hivemall/utils/lang/Primitives.java | 28 + .../java/hivemall/utils/math/MathUtils.java | 22 +- .../java/hivemall/utils/math/MatrixUtils.java | 2 +- .../utils/sampling/IntReservoirSampler.java | 99 +++ .../utils/sampling/ReservoirSampler.java | 100 +++ .../java/hivemall/utils/stream/IntIterator.java | 27 + .../java/hivemall/utils/stream/IntStream.java | 28 + .../java/hivemall/utils/stream/StreamUtils.java | 180 ++++++ .../hivemall/fm/FFMPredictionModelTest.java | 2 +- .../hivemall/math/matrix/MatrixBuilderTest.java | 644 +++++++++++++++++++ .../math/matrix/ints/IntMatrixTest.java | 361 +++++++++++ .../java/hivemall/matrix/MatrixBuilderTest.java | 329 ---------- .../smile/classification/DecisionTreeTest.java | 249 +++---- .../RandomForestClassifierUDTFTest.java | 286 +++++++- .../smile/regression/RegressionTreeTest.java | 81 ++- .../smile/tools/TreePredictUDFTest.java | 61 
+- .../hivemall/smile/vm/StackMachineTest.java | 88 --- .../utils/collections/DoubleArray3DTest.java | 147 ----- .../utils/collections/DoubleArrayTest.java | 60 -- .../collections/Int2FloatOpenHashMapTest.java | 96 --- .../collections/Int2LongOpenHashMapTest.java | 105 --- .../utils/collections/IntArrayTest.java | 76 --- .../utils/collections/IntOpenHashMapTest.java | 73 --- .../utils/collections/IntOpenHashTableTest.java | 50 -- .../utils/collections/OpenHashMapTest.java | 91 --- .../utils/collections/OpenHashTableTest.java | 138 ---- .../utils/collections/SparseIntArrayTest.java | 61 -- .../collections/arrays/DoubleArray3DTest.java | 149 +++++ .../collections/arrays/DoubleArrayTest.java | 62 ++ .../utils/collections/arrays/IntArrayTest.java | 79 +++ .../collections/arrays/SparseIntArrayTest.java | 64 ++ .../collections/lists/LongArrayListTest.java | 43 ++ .../maps/Int2FloatOpenHashMapTest.java | 98 +++ .../maps/Int2LongOpenHashMapTest.java | 106 +++ .../collections/maps/IntOpenHashMapTest.java | 75 +++ .../collections/maps/IntOpenHashTableTest.java | 52 ++ .../maps/Long2IntOpenHashMapTest.java | 115 ++++ .../utils/collections/maps/OpenHashMapTest.java | 93 +++ .../collections/maps/OpenHashTableTest.java | 140 ++++ .../hivemall/utils/stream/StreamUtilsTest.java | 86 +++ .../hivemall/classifier/news20-multiclass.gz | Bin 0 -> 396138 bytes .../apache/spark/sql/hive/GroupedDataEx.scala | 2 +- .../spark/sql/hive/HivemallOpsSuite.scala | 21 +- .../spark/sql/hive/HivemallGroupedDataset.scala | 2 +- .../spark/sql/hive/HivemallOpsSuite.scala | 10 +- .../spark/sql/hive/HivemallGroupedDataset.scala | 2 +- .../spark/sql/hive/HivemallOpsSuite.scala | 10 +- 161 files changed, 15287 insertions(+), 8113 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/pom.xml ---------------------------------------------------------------------- diff --git a/core/pom.xml b/core/pom.xml 
index bf931ac..d7655f4 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -109,7 +109,7 @@ <dependency> <groupId>com.github.haifengl</groupId> <artifactId>smile-core</artifactId> - <version>1.0.3</version> + <version>1.0.4</version> <scope>compile</scope> <exclusions> <exclusion> @@ -130,6 +130,12 @@ <version>3.6.1</version> <scope>compile</scope> </dependency> + <dependency> + <groupId>org.roaringbitmap</groupId> + <artifactId>RoaringBitmap</artifactId> + <version>[0.6,)</version> + <scope>compile</scope> + </dependency> <!-- test scope --> <dependency> @@ -198,6 +204,8 @@ <include>com.github.haifengl:smile-math</include> <include>com.github.haifengl:smile-data</include> <include>org.tukaani:xz</include> + <include>org.apache.commons:commons-math3</include> + <include>org.roaringbitmap:RoaringBitmap</include> </includes> </artifactSet> <transformers> http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/annotations/Immutable.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/annotations/Immutable.java b/core/src/main/java/hivemall/annotations/Immutable.java new file mode 100644 index 0000000..941fa5d --- /dev/null +++ b/core/src/main/java/hivemall/annotations/Immutable.java @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package hivemall.annotations; + +import java.lang.annotation.Documented; +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; + +/** + * The class to which this annotation is applied is Immutable. + */ +@Documented +@Target(ElementType.TYPE) +@Retention(RetentionPolicy.CLASS) +public @interface Immutable { +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/annotations/Mutable.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/annotations/Mutable.java b/core/src/main/java/hivemall/annotations/Mutable.java new file mode 100644 index 0000000..bdac5d9 --- /dev/null +++ b/core/src/main/java/hivemall/annotations/Mutable.java @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package hivemall.annotations; + +import java.lang.annotation.Documented; +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; + +/** + * The class to which this annotation is applied is Mutable. + * + * @see javax.annotation.concurrent.Immutable + */ +@Documented +@Target(ElementType.TYPE) +@Retention(RetentionPolicy.CLASS) +public @interface Mutable { +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/classifier/KernelExpansionPassiveAggressiveUDTF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/classifier/KernelExpansionPassiveAggressiveUDTF.java b/core/src/main/java/hivemall/classifier/KernelExpansionPassiveAggressiveUDTF.java index 7cb7a58..8534231 100644 --- a/core/src/main/java/hivemall/classifier/KernelExpansionPassiveAggressiveUDTF.java +++ b/core/src/main/java/hivemall/classifier/KernelExpansionPassiveAggressiveUDTF.java @@ -24,8 +24,8 @@ import hivemall.common.LossFunctions; import hivemall.model.FeatureValue; import hivemall.model.PredictionModel; import hivemall.model.PredictionResult; -import hivemall.utils.collections.Int2FloatOpenHashTable; -import hivemall.utils.collections.Int2FloatOpenHashTable.IMapIterator; +import hivemall.utils.collections.maps.Int2FloatOpenHashTable; +import hivemall.utils.collections.maps.Int2FloatOpenHashTable.IMapIterator; import hivemall.utils.hashing.HashFunction; import 
hivemall.utils.lang.Preconditions; http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/common/ReservoirSampler.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/common/ReservoirSampler.java b/core/src/main/java/hivemall/common/ReservoirSampler.java deleted file mode 100644 index 8846ac1..0000000 --- a/core/src/main/java/hivemall/common/ReservoirSampler.java +++ /dev/null @@ -1,100 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -package hivemall.common; - -import java.util.Arrays; -import java.util.List; -import java.util.Random; - -/** - * Vitter's reservoir sampling implementation that randomly chooses k items from a list containing n - * items. 
- * - * @link http://en.wikipedia.org/wiki/Reservoir_sampling - * @link http://portal.acm.org/citation.cfm?id=3165 - */ -public final class ReservoirSampler<T> { - - private final T[] samples; - private final int numSamples; - private int position; - - private final Random rand; - - @SuppressWarnings("unchecked") - public ReservoirSampler(int sampleSize) { - if (sampleSize <= 0) { - throw new IllegalArgumentException("sampleSize must be greater than 1: " + sampleSize); - } - this.samples = (T[]) new Object[sampleSize]; - this.numSamples = sampleSize; - this.position = 0; - this.rand = new Random(); - } - - @SuppressWarnings("unchecked") - public ReservoirSampler(int sampleSize, long seed) { - this.samples = (T[]) new Object[sampleSize]; - this.numSamples = sampleSize; - this.position = 0; - this.rand = new Random(seed); - } - - public ReservoirSampler(T[] samples) { - this.samples = samples; - this.numSamples = samples.length; - this.position = 0; - this.rand = new Random(); - } - - public ReservoirSampler(T[] samples, long seed) { - this.samples = samples; - this.numSamples = samples.length; - this.position = 0; - this.rand = new Random(seed); - } - - public T[] getSample() { - return samples; - } - - public List<T> getSamplesAsList() { - return Arrays.asList(samples); - } - - public void add(T item) { - if (item == null) { - return; - } - if (position < numSamples) {// reservoir not yet full, just append - samples[position] = item; - } else {// find a item to replace - int replaceIndex = rand.nextInt(position + 1); - if (replaceIndex < numSamples) { - samples[replaceIndex] = item; - } - } - position++; - } - - public void clear() { - Arrays.fill(samples, null); - this.position = 0; - } -} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/fm/FFMPredictionModel.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/fm/FFMPredictionModel.java 
b/core/src/main/java/hivemall/fm/FFMPredictionModel.java index 6969d05..befbec9 100644 --- a/core/src/main/java/hivemall/fm/FFMPredictionModel.java +++ b/core/src/main/java/hivemall/fm/FFMPredictionModel.java @@ -21,8 +21,8 @@ package hivemall.fm; import hivemall.utils.buffer.HeapBuffer; import hivemall.utils.codec.VariableByteCodec; import hivemall.utils.codec.ZigZagLEB128Codec; -import hivemall.utils.collections.Int2LongOpenHashTable; -import hivemall.utils.collections.IntOpenHashTable; +import hivemall.utils.collections.maps.Int2LongOpenHashTable; +import hivemall.utils.collections.maps.IntOpenHashTable; import hivemall.utils.io.CompressionStreamFactory.CompressionAlgorithm; import hivemall.utils.io.IOUtils; import hivemall.utils.lang.ArrayUtils; http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/fm/FFMStringFeatureMapModel.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/fm/FFMStringFeatureMapModel.java b/core/src/main/java/hivemall/fm/FFMStringFeatureMapModel.java index 4009326..4f445fa 100644 --- a/core/src/main/java/hivemall/fm/FFMStringFeatureMapModel.java +++ b/core/src/main/java/hivemall/fm/FFMStringFeatureMapModel.java @@ -22,7 +22,7 @@ import hivemall.fm.Entry.AdaGradEntry; import hivemall.fm.Entry.FTRLEntry; import hivemall.fm.FMHyperParameters.FFMHyperParameters; import hivemall.utils.buffer.HeapBuffer; -import hivemall.utils.collections.Int2LongOpenHashTable; +import hivemall.utils.collections.maps.Int2LongOpenHashTable; import hivemall.utils.lang.NumberUtils; import hivemall.utils.math.MathUtils; http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/fm/FMIntFeatureMapModel.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/fm/FMIntFeatureMapModel.java b/core/src/main/java/hivemall/fm/FMIntFeatureMapModel.java index 
d2a5ed6..19ac287 100644 --- a/core/src/main/java/hivemall/fm/FMIntFeatureMapModel.java +++ b/core/src/main/java/hivemall/fm/FMIntFeatureMapModel.java @@ -18,8 +18,8 @@ */ package hivemall.fm; -import hivemall.utils.collections.Int2FloatOpenHashTable; -import hivemall.utils.collections.IntOpenHashMap; +import hivemall.utils.collections.maps.Int2FloatOpenHashTable; +import hivemall.utils.collections.maps.IntOpenHashMap; import java.util.Arrays; http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/fm/FMStringFeatureMapModel.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/fm/FMStringFeatureMapModel.java b/core/src/main/java/hivemall/fm/FMStringFeatureMapModel.java index 10ffaae..cd99046 100644 --- a/core/src/main/java/hivemall/fm/FMStringFeatureMapModel.java +++ b/core/src/main/java/hivemall/fm/FMStringFeatureMapModel.java @@ -19,7 +19,7 @@ package hivemall.fm; import hivemall.utils.collections.IMapIterator; -import hivemall.utils.collections.OpenHashTable; +import hivemall.utils.collections.maps.OpenHashTable; import javax.annotation.Nonnull; http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/fm/FieldAwareFactorizationMachineModel.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/fm/FieldAwareFactorizationMachineModel.java b/core/src/main/java/hivemall/fm/FieldAwareFactorizationMachineModel.java index e63797c..76bead8 100644 --- a/core/src/main/java/hivemall/fm/FieldAwareFactorizationMachineModel.java +++ b/core/src/main/java/hivemall/fm/FieldAwareFactorizationMachineModel.java @@ -19,8 +19,8 @@ package hivemall.fm; import hivemall.fm.FMHyperParameters.FFMHyperParameters; -import hivemall.utils.collections.DoubleArray3D; -import hivemall.utils.collections.IntArrayList; +import hivemall.utils.collections.arrays.DoubleArray3D; +import 
hivemall.utils.collections.lists.IntArrayList; import hivemall.utils.lang.NumberUtils; import java.util.Arrays; http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/fm/FieldAwareFactorizationMachineUDTF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/fm/FieldAwareFactorizationMachineUDTF.java b/core/src/main/java/hivemall/fm/FieldAwareFactorizationMachineUDTF.java index fe27269..67dbf87 100644 --- a/core/src/main/java/hivemall/fm/FieldAwareFactorizationMachineUDTF.java +++ b/core/src/main/java/hivemall/fm/FieldAwareFactorizationMachineUDTF.java @@ -19,8 +19,8 @@ package hivemall.fm; import hivemall.fm.FMHyperParameters.FFMHyperParameters; -import hivemall.utils.collections.DoubleArray3D; -import hivemall.utils.collections.IntArrayList; +import hivemall.utils.collections.arrays.DoubleArray3D; +import hivemall.utils.collections.lists.IntArrayList; import hivemall.utils.hadoop.HadoopUtils; import hivemall.utils.hadoop.Text3; import hivemall.utils.lang.NumberUtils; http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/ftvec/ranking/BprSamplingUDTF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/ftvec/ranking/BprSamplingUDTF.java b/core/src/main/java/hivemall/ftvec/ranking/BprSamplingUDTF.java index 8e84bd8..ab418ed 100644 --- a/core/src/main/java/hivemall/ftvec/ranking/BprSamplingUDTF.java +++ b/core/src/main/java/hivemall/ftvec/ranking/BprSamplingUDTF.java @@ -19,7 +19,7 @@ package hivemall.ftvec.ranking; import hivemall.UDTFWithOptions; -import hivemall.utils.collections.IntArrayList; +import hivemall.utils.collections.lists.IntArrayList; import hivemall.utils.hadoop.HiveUtils; import hivemall.utils.lang.BitUtils; import hivemall.utils.lang.Primitives; 
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/ftvec/ranking/PerEventPositiveOnlyFeedback.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/ftvec/ranking/PerEventPositiveOnlyFeedback.java b/core/src/main/java/hivemall/ftvec/ranking/PerEventPositiveOnlyFeedback.java index b5afb99..94bb697 100644 --- a/core/src/main/java/hivemall/ftvec/ranking/PerEventPositiveOnlyFeedback.java +++ b/core/src/main/java/hivemall/ftvec/ranking/PerEventPositiveOnlyFeedback.java @@ -18,7 +18,7 @@ */ package hivemall.ftvec.ranking; -import hivemall.utils.collections.IntArrayList; +import hivemall.utils.collections.lists.IntArrayList; import hivemall.utils.lang.ArrayUtils; import java.util.Random; http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/ftvec/ranking/PositiveOnlyFeedback.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/ftvec/ranking/PositiveOnlyFeedback.java b/core/src/main/java/hivemall/ftvec/ranking/PositiveOnlyFeedback.java index 908a0b7..5e9f797 100644 --- a/core/src/main/java/hivemall/ftvec/ranking/PositiveOnlyFeedback.java +++ b/core/src/main/java/hivemall/ftvec/ranking/PositiveOnlyFeedback.java @@ -18,9 +18,9 @@ */ package hivemall.ftvec.ranking; -import hivemall.utils.collections.IntArrayList; -import hivemall.utils.collections.IntOpenHashMap; -import hivemall.utils.collections.IntOpenHashMap.IMapIterator; +import hivemall.utils.collections.lists.IntArrayList; +import hivemall.utils.collections.maps.IntOpenHashMap; +import hivemall.utils.collections.maps.IntOpenHashMap.IMapIterator; import java.util.BitSet; http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/math/matrix/AbstractMatrix.java ---------------------------------------------------------------------- diff --git 
a/core/src/main/java/hivemall/math/matrix/AbstractMatrix.java b/core/src/main/java/hivemall/math/matrix/AbstractMatrix.java new file mode 100644 index 0000000..2ee27f7 --- /dev/null +++ b/core/src/main/java/hivemall/math/matrix/AbstractMatrix.java @@ -0,0 +1,105 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package hivemall.math.matrix; + +import hivemall.math.vector.SparseVector; +import hivemall.math.vector.Vector; +import hivemall.math.vector.VectorProcedure; + +import javax.annotation.Nonnegative; +import javax.annotation.Nonnull; + +public abstract class AbstractMatrix implements Matrix { + + public AbstractMatrix() {} + + @Override + public double[] row() { + int cols = numColumns(); + return new double[cols]; + } + + @Override + public Vector rowVector() { + return new SparseVector(); + } + + @Override + public final double get(@Nonnegative final int row, @Nonnegative final int col) { + return get(row, col, 0.d); + } + + protected static final void checkRowIndex(final int row, final int numRows) { + if (row < 0 || row >= numRows) { + throw new IndexOutOfBoundsException("Row index " + row + " out of bounds " + numRows); + } + } + + protected static final void checkColIndex(final int col, final int numColumns) { + if (col < 0 || col >= numColumns) { + throw new IndexOutOfBoundsException("Col index " + col + " out of bounds " + numColumns); + } + } + + protected static final void checkIndex(final int index) { + if (index < 0) { + throw new IndexOutOfBoundsException("Invalid index " + index); + } + } + + protected static final void checkIndex(final int row, final int col) { + if (row < 0) { + throw new IndexOutOfBoundsException("Invalid row index " + row); + } + if (col < 0) { + throw new IndexOutOfBoundsException("Invalid col index " + col); + } + } + + protected static final void checkIndex(final int row, final int col, final int numRows, + final int numColumns) { + if (row < 0 || row >= numRows) { + throw new IndexOutOfBoundsException("Row index " + row + " out of bounds " + numRows); + } + if (col < 0 || col >= numColumns) { + throw new IndexOutOfBoundsException("Col index " + col + " out of bounds " + numColumns); + } + } + + @Override + public void eachInRow(final int row, @Nonnull final VectorProcedure procedure) { + eachInRow(row, procedure, true); + 
} + + @Override + public void eachInColumn(final int col, @Nonnull final VectorProcedure procedure) { + eachInColumn(col, procedure, true); + } + + @Override + public void eachNonNullInRow(final int row, @Nonnull final VectorProcedure procedure) { + eachInRow(row, procedure, false); + } + + @Override + public void eachNonNullInColumn(final int col, @Nonnull final VectorProcedure procedure) { + eachInColumn(col, procedure, false); + } + +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/math/matrix/ColumnMajorMatrix.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/math/matrix/ColumnMajorMatrix.java b/core/src/main/java/hivemall/math/matrix/ColumnMajorMatrix.java new file mode 100644 index 0000000..51c80aa --- /dev/null +++ b/core/src/main/java/hivemall/math/matrix/ColumnMajorMatrix.java @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package hivemall.math.matrix; + +import hivemall.math.vector.VectorProcedure; + +public abstract class ColumnMajorMatrix extends AbstractMatrix { + + public ColumnMajorMatrix() { + super(); + } + + @Override + public boolean isRowMajorMatrix() { + return false; + } + + @Override + public boolean isColumnMajorMatrix() { + return true; + } + + @Override + public void eachInRow(int row, VectorProcedure procedure, boolean nullOutput) { + throw new UnsupportedOperationException(); + } + + @Override + public void eachColumnIndexInRow(int row, VectorProcedure procedure) { + throw new UnsupportedOperationException(); + } + + @Override + public void eachNonZeroInRow(int row, VectorProcedure procedure) { + throw new UnsupportedOperationException(); + } + + @Override + public ColumnMajorMatrix toColumnMajorMatrix() { + return this; + } + +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/math/matrix/Matrix.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/math/matrix/Matrix.java b/core/src/main/java/hivemall/math/matrix/Matrix.java new file mode 100644 index 0000000..8a4782a --- /dev/null +++ b/core/src/main/java/hivemall/math/matrix/Matrix.java @@ -0,0 +1,127 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package hivemall.math.matrix; + +import hivemall.math.matrix.builders.MatrixBuilder; +import hivemall.math.vector.Vector; +import hivemall.math.vector.VectorProcedure; + +import javax.annotation.Nonnegative; +import javax.annotation.Nonnull; +import javax.annotation.concurrent.NotThreadSafe; + +/** + * Double matrix. + */ +@NotThreadSafe +public interface Matrix { + + public boolean isSparse(); + + public boolean isRowMajorMatrix(); + + public boolean isColumnMajorMatrix(); + + public boolean readOnly(); + + public boolean swappable(); + + /** The Number of Non-Zeros */ + public int nnz(); + + @Nonnegative + public int numRows(); + + @Nonnegative + public int numColumns(); + + @Nonnegative + public int numColumns(@Nonnegative int row); + + @Nonnull + public double[] row(); + + @Nonnull + public Vector rowVector(); + + @Nonnull + public double[] getRow(@Nonnegative int index); + + /** + * @return returns dst + */ + @Nonnull + public double[] getRow(@Nonnegative int index, @Nonnull double[] dst); + + public void getRow(@Nonnegative int index, @Nonnull Vector row); + + /** + * @throws IndexOutOfBoundsException + */ + public double get(@Nonnegative int row, @Nonnegative int col); + + /** + * @throws IndexOutOfBoundsException + */ + public double get(@Nonnegative int row, @Nonnegative int col, double defaultValue); + + /** + * @throws IndexOutOfBoundsException + * @throws UnsupportedOperationException + */ + public void set(@Nonnegative int row, @Nonnegative int col, double value); + + /** + * @throws IndexOutOfBoundsException + * @throws 
UnsupportedOperationException + */ + public double getAndSet(@Nonnegative int row, @Nonnegative int col, double value); + + public void swap(@Nonnegative int row1, @Nonnegative int row2); + + public void eachInRow(@Nonnegative int row, @Nonnull VectorProcedure procedure); + + public void eachInRow(@Nonnegative int row, @Nonnull VectorProcedure procedure, + boolean nullOutput); + + public void eachNonNullInRow(@Nonnegative int row, @Nonnull VectorProcedure procedure); + + public void eachNonZeroInRow(@Nonnegative int row, @Nonnull VectorProcedure procedure); + + public void eachColumnIndexInRow(@Nonnegative int row, @Nonnull VectorProcedure procedure); + + public void eachInColumn(@Nonnegative int col, @Nonnull VectorProcedure procedure); + + public void eachInColumn(@Nonnegative int col, @Nonnull VectorProcedure procedure, + boolean nullOutput); + + public void eachNonNullInColumn(@Nonnegative int col, @Nonnull VectorProcedure procedure); + + public void eachNonZeroInColumn(@Nonnegative int col, @Nonnull VectorProcedure procedure); + + @Nonnull + public RowMajorMatrix toRowMajorMatrix(); + + @Nonnull + public ColumnMajorMatrix toColumnMajorMatrix(); + + @Nonnull + public MatrixBuilder builder(); + +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/math/matrix/MatrixUtils.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/math/matrix/MatrixUtils.java b/core/src/main/java/hivemall/math/matrix/MatrixUtils.java new file mode 100644 index 0000000..90ce78f --- /dev/null +++ b/core/src/main/java/hivemall/math/matrix/MatrixUtils.java @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package hivemall.math.matrix; + +import hivemall.math.matrix.builders.MatrixBuilder; +import hivemall.math.matrix.ints.IntMatrix; +import hivemall.math.vector.VectorProcedure; +import hivemall.utils.lang.Preconditions; +import hivemall.utils.lang.mutable.MutableInt; + +import javax.annotation.Nonnegative; +import javax.annotation.Nonnull; + +public final class MatrixUtils { + + private MatrixUtils() {} + + @Nonnull + public static Matrix shuffle(@Nonnull final Matrix m, @Nonnull final int[] indices) { + Preconditions.checkArgument(m.numRows() <= indices.length, "m.numRow() `" + m.numRows() + + "` MUST be equals to or less than |swapIndicies| `" + indices.length + "`"); + + final MatrixBuilder builder = m.builder(); + final VectorProcedure proc = new VectorProcedure() { + public void apply(int col, double value) { + builder.nextColumn(col, value); + } + }; + for (int i = 0; i < indices.length; i++) { + int idx = indices[i]; + m.eachNonNullInRow(idx, proc); + builder.nextRow(); + } + return builder.buildMatrix(); + } + + /** + * Returns the index of maximum value of an array. 
+ * + * @return -1 if there are no columns + */ + public static int whichMax(@Nonnull final IntMatrix matrix, @Nonnegative final int row) { + final MutableInt m = new MutableInt(Integer.MIN_VALUE); + final MutableInt which = new MutableInt(-1); + matrix.eachInRow(row, new VectorProcedure() { + @Override + public void apply(int i, int value) { + if (value > m.getValue()) { + m.setValue(value); + which.setValue(i); + } + } + }, false); + return which.getValue(); + } + +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/math/matrix/RowMajorMatrix.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/math/matrix/RowMajorMatrix.java b/core/src/main/java/hivemall/math/matrix/RowMajorMatrix.java new file mode 100644 index 0000000..2c611bd --- /dev/null +++ b/core/src/main/java/hivemall/math/matrix/RowMajorMatrix.java @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package hivemall.math.matrix; + +import hivemall.math.vector.Vector; +import hivemall.math.vector.VectorProcedure; + +import javax.annotation.Nonnegative; +import javax.annotation.Nonnull; + +public abstract class RowMajorMatrix extends AbstractMatrix { + + public RowMajorMatrix() { + super(); + } + + @Override + public boolean isRowMajorMatrix() { + return true; + } + + @Override + public boolean isColumnMajorMatrix() { + return false; + } + + @Override + public void getRow(@Nonnegative final int index, @Nonnull final Vector row) { + row.clear(); + eachNonNullInRow(index, new VectorProcedure() { + @Override + public void apply(final int i, final double value) { + row.set(i, value); + } + }); + } + + @Override + public void eachInColumn(int col, VectorProcedure procedure, boolean nullOutput) { + throw new UnsupportedOperationException(); + } + + @Override + public void eachNonZeroInColumn(int col, VectorProcedure procedure) { + throw new UnsupportedOperationException(); + } + + @Override + public RowMajorMatrix toRowMajorMatrix() { + return this; + } + +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/math/matrix/builders/CSCMatrixBuilder.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/math/matrix/builders/CSCMatrixBuilder.java b/core/src/main/java/hivemall/math/matrix/builders/CSCMatrixBuilder.java new file mode 100644 index 0000000..df2bff7 --- /dev/null +++ b/core/src/main/java/hivemall/math/matrix/builders/CSCMatrixBuilder.java @@ -0,0 +1,121 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package hivemall.math.matrix.builders; + +import hivemall.math.matrix.sparse.CSCMatrix; +import hivemall.utils.collections.lists.DoubleArrayList; +import hivemall.utils.collections.lists.IntArrayList; + +import java.util.Arrays; + +import javax.annotation.Nonnegative; +import javax.annotation.Nonnull; + +public final class CSCMatrixBuilder extends MatrixBuilder { + + @Nonnull + private final IntArrayList rows; + @Nonnull + private final IntArrayList cols; + @Nonnull + private final DoubleArrayList values; + + private int row; + private int maxNumColumns; + + public CSCMatrixBuilder(int initSize) { + super(); + this.rows = new IntArrayList(initSize); + this.cols = new IntArrayList(initSize); + this.values = new DoubleArrayList(initSize); + this.row = 0; + this.maxNumColumns = 0; + } + + @Override + public CSCMatrixBuilder nextRow() { + row++; + return this; + } + + @Override + public CSCMatrixBuilder nextColumn(@Nonnegative final int col, final double value) { + rows.add(row); + cols.add(col); + values.add((float) value); + this.maxNumColumns = Math.max(col + 1, maxNumColumns); + return this; + } + + @Override + public CSCMatrix buildMatrix() { + if (rows.isEmpty() || cols.isEmpty()) { + throw new IllegalStateException("No element in the matrix"); + } + + final int[] columnIndices = cols.toArray(true); + final int[] rowsIndicies = rows.toArray(true); + final double[] valuesArray = values.toArray(true); + + // convert to column major + final int nnz = valuesArray.length; + SortObj[] sortObjs = new SortObj[nnz]; + for (int i = 0; i < nnz; i++) { 
+ sortObjs[i] = new SortObj(columnIndices[i], rowsIndicies[i], valuesArray[i]); + } + Arrays.sort(sortObjs); + for (int i = 0; i < nnz; i++) { + columnIndices[i] = sortObjs[i].columnIndex; + rowsIndicies[i] = sortObjs[i].rowsIndex; + valuesArray[i] = sortObjs[i].value; + } + sortObjs = null; + + final int[] columnPointers = new int[maxNumColumns + 1]; + int prevCol = -1; + for (int j = 0; j < columnIndices.length; j++) { + int currCol = columnIndices[j]; + if (currCol != prevCol) { + columnPointers[currCol] = j; + prevCol = currCol; + } + } + columnPointers[maxNumColumns] = nnz; // nnz + + return new CSCMatrix(columnPointers, rowsIndicies, valuesArray, row, maxNumColumns); + } + + private static final class SortObj implements Comparable<SortObj> { + final int columnIndex; + final int rowsIndex; + final double value; + + SortObj(int columnIndex, int rowsIndex, double value) { + this.columnIndex = columnIndex; + this.rowsIndex = rowsIndex; + this.value = value; + } + + @Override + public int compareTo(SortObj o) { + return Integer.compare(columnIndex, o.columnIndex); + } + } + +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/math/matrix/builders/CSRMatrixBuilder.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/math/matrix/builders/CSRMatrixBuilder.java b/core/src/main/java/hivemall/math/matrix/builders/CSRMatrixBuilder.java new file mode 100644 index 0000000..2467056 --- /dev/null +++ b/core/src/main/java/hivemall/math/matrix/builders/CSRMatrixBuilder.java @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package hivemall.math.matrix.builders; + +import hivemall.math.matrix.sparse.CSRMatrix; +import hivemall.utils.collections.lists.DoubleArrayList; +import hivemall.utils.collections.lists.IntArrayList; + +import javax.annotation.Nonnegative; +import javax.annotation.Nonnull; + +/** + * Compressed Sparse Row Matrix builder. + */ +public final class CSRMatrixBuilder extends MatrixBuilder { + + @Nonnull + private final IntArrayList rowPointers; + @Nonnull + private final IntArrayList columnIndices; + @Nonnull + private final DoubleArrayList values; + + private int maxNumColumns; + + public CSRMatrixBuilder(@Nonnegative int initSize) { + super(); + this.rowPointers = new IntArrayList(initSize + 1); + rowPointers.add(0); + this.columnIndices = new IntArrayList(initSize); + this.values = new DoubleArrayList(initSize); + this.maxNumColumns = 0; + } + + @Override + public CSRMatrixBuilder nextRow() { + int ptr = values.size(); + rowPointers.add(ptr); + return this; + } + + @Override + public CSRMatrixBuilder nextColumn(@Nonnegative int col, double value) { + if (value == 0.d) { + return this; + } + + columnIndices.add(col); + values.add(value); + this.maxNumColumns = Math.max(col + 1, maxNumColumns); + return this; + } + + @Override + public CSRMatrix buildMatrix() { + CSRMatrix matrix = new CSRMatrix(rowPointers.toArray(true), columnIndices.toArray(true), + values.toArray(true), 
maxNumColumns); + return matrix; + } + +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/math/matrix/builders/ColumnMajorDenseMatrixBuilder.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/math/matrix/builders/ColumnMajorDenseMatrixBuilder.java b/core/src/main/java/hivemall/math/matrix/builders/ColumnMajorDenseMatrixBuilder.java new file mode 100644 index 0000000..9cae1c7 --- /dev/null +++ b/core/src/main/java/hivemall/math/matrix/builders/ColumnMajorDenseMatrixBuilder.java @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package hivemall.math.matrix.builders; + +import hivemall.math.matrix.dense.ColumnMajorDenseMatrix2d; +import hivemall.utils.collections.arrays.SparseDoubleArray; +import hivemall.utils.collections.maps.IntOpenHashTable; +import hivemall.utils.collections.maps.IntOpenHashTable.IMapIterator; + +import javax.annotation.Nonnegative; +import javax.annotation.Nonnull; + +public final class ColumnMajorDenseMatrixBuilder extends MatrixBuilder { + + @Nonnull + private final IntOpenHashTable<SparseDoubleArray> col2rows; + private int row; + private int maxNumColumns; + private int nnz; + + public ColumnMajorDenseMatrixBuilder(int initSize) { + this.col2rows = new IntOpenHashTable<SparseDoubleArray>(initSize); + this.row = 0; + this.maxNumColumns = 0; + this.nnz = 0; + } + + @Override + public ColumnMajorDenseMatrixBuilder nextRow() { + row++; + return this; + } + + @Override + public ColumnMajorDenseMatrixBuilder nextColumn(@Nonnegative final int col, final double value) { + if (value == 0.d) { + return this; + } + + SparseDoubleArray rows = col2rows.get(col); + if (rows == null) { + rows = new SparseDoubleArray(4); + col2rows.put(col, rows); + } + rows.put(row, value); + this.maxNumColumns = Math.max(col + 1, maxNumColumns); + nnz++; + return this; + } + + @Override + public ColumnMajorDenseMatrix2d buildMatrix() { + final double[][] data = new double[maxNumColumns][]; + + final IMapIterator<SparseDoubleArray> itor = col2rows.entries(); + while (itor.next() != -1) { + int col = itor.getKey(); + SparseDoubleArray rows = itor.getValue(); + data[col] = rows.toArray(); + } + + return new ColumnMajorDenseMatrix2d(data, row, nnz); + } + +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/math/matrix/builders/DoKMatrixBuilder.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/math/matrix/builders/DoKMatrixBuilder.java 
b/core/src/main/java/hivemall/math/matrix/builders/DoKMatrixBuilder.java new file mode 100644 index 0000000..556a8d8 --- /dev/null +++ b/core/src/main/java/hivemall/math/matrix/builders/DoKMatrixBuilder.java @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package hivemall.math.matrix.builders; + +import hivemall.math.matrix.sparse.DoKMatrix; + +import javax.annotation.Nonnegative; +import javax.annotation.Nonnull; + +public final class DoKMatrixBuilder extends MatrixBuilder { + + @Nonnull + private final DoKMatrix matrix; + + private int row; + + public DoKMatrixBuilder(@Nonnegative int initSize) { + super(); + this.row = 0; + this.matrix = new DoKMatrix(initSize); + } + + @Override + public DoKMatrixBuilder nextRow() { + row++; + return this; + } + + @Override + public DoKMatrixBuilder nextColumn(@Nonnegative final int col, final double value) { + matrix.set(row, col, value); + return this; + } + + @Override + public DoKMatrix buildMatrix() { + return matrix; + } + +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/math/matrix/builders/MatrixBuilder.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/math/matrix/builders/MatrixBuilder.java b/core/src/main/java/hivemall/math/matrix/builders/MatrixBuilder.java new file mode 100644 index 0000000..66bd1e2 --- /dev/null +++ b/core/src/main/java/hivemall/math/matrix/builders/MatrixBuilder.java @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package hivemall.math.matrix.builders; + +import hivemall.math.matrix.Matrix; + +import javax.annotation.Nonnegative; +import javax.annotation.Nonnull; + +public abstract class MatrixBuilder { + + public MatrixBuilder() {} + + public void nextRow(@Nonnull final double[] row) { + for (int col = 0; col < row.length; col++) { + nextColumn(col, row[col]); + } + nextRow(); + } + + public void nextRow(@Nonnull final String[] row) { + for (String col : row) { + if (col == null) { + continue; + } + nextColumn(col); + } + nextRow(); + } + + @Nonnull + public abstract MatrixBuilder nextRow(); + + @Nonnull + public abstract MatrixBuilder nextColumn(@Nonnegative int col, double value); + + /** + * @throws IllegalArgumentException + * @throws NumberFormatException + */ + @Nonnull + public MatrixBuilder nextColumn(@Nonnull final String col) { + final int pos = col.indexOf(':'); + if (pos == 0) { + throw new IllegalArgumentException("Invalid feature value representation: " + col); + } + + final String feature; + final double value; + if (pos > 0) { + feature = col.substring(0, pos); + String s2 = col.substring(pos + 1); + value = Double.parseDouble(s2); + } else { + feature = col; + value = 1.d; + } + + if (feature.indexOf(':') != -1) { + throw new IllegalArgumentException("Invaliad feature format `<index>:<value>`: " + col); + } + + int colIndex = Integer.parseInt(feature); + if (colIndex < 0) { + throw new IllegalArgumentException("Col index MUST be greather than or equals to 0: " + + colIndex); + } + + return nextColumn(colIndex, value); + } + + @Nonnull + public abstract Matrix buildMatrix(); + +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/math/matrix/builders/RowMajorDenseMatrixBuilder.java ---------------------------------------------------------------------- diff --git 
a/core/src/main/java/hivemall/math/matrix/builders/RowMajorDenseMatrixBuilder.java b/core/src/main/java/hivemall/math/matrix/builders/RowMajorDenseMatrixBuilder.java new file mode 100644 index 0000000..b6d0588 --- /dev/null +++ b/core/src/main/java/hivemall/math/matrix/builders/RowMajorDenseMatrixBuilder.java @@ -0,0 +1,79 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package hivemall.math.matrix.builders; + +import hivemall.math.matrix.dense.RowMajorDenseMatrix2d; +import hivemall.utils.collections.arrays.SparseDoubleArray; + +import java.util.ArrayList; +import java.util.List; + +import javax.annotation.Nonnegative; +import javax.annotation.Nonnull; + +public final class RowMajorDenseMatrixBuilder extends MatrixBuilder { + + @Nonnull + private final List<double[]> rows; + private int maxNumColumns; + private int nnz; + + @Nonnull + private final SparseDoubleArray rowProbe; + + public RowMajorDenseMatrixBuilder(@Nonnegative int initSize) { + super(); + this.rows = new ArrayList<double[]>(initSize); + this.maxNumColumns = 0; + this.nnz = 0; + this.rowProbe = new SparseDoubleArray(32); + } + + @Override + public RowMajorDenseMatrixBuilder nextColumn(@Nonnegative final int col, final double value) { + if (value == 0.d) { + return this; + } + rowProbe.put(col, value); + nnz++; + return this; + } + + @Override + public RowMajorDenseMatrixBuilder nextRow() { + double[] row = rowProbe.toArray(); + rowProbe.clear(); + nextRow(row); + return this; + } + + @Override + public void nextRow(@Nonnull double[] row) { + rows.add(row); + this.maxNumColumns = Math.max(row.length, maxNumColumns); + } + + @Override + public RowMajorDenseMatrix2d buildMatrix() { + int numRows = rows.size(); + double[][] data = rows.toArray(new double[numRows][]); + return new RowMajorDenseMatrix2d(data, maxNumColumns, nnz); + } + +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/math/matrix/dense/ColumnMajorDenseMatrix2d.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/math/matrix/dense/ColumnMajorDenseMatrix2d.java b/core/src/main/java/hivemall/math/matrix/dense/ColumnMajorDenseMatrix2d.java new file mode 100644 index 0000000..2c5fd45 --- /dev/null +++ b/core/src/main/java/hivemall/math/matrix/dense/ColumnMajorDenseMatrix2d.java @@ -0,0 
+1,300 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package hivemall.math.matrix.dense; + +import hivemall.math.matrix.ColumnMajorMatrix; +import hivemall.math.matrix.builders.ColumnMajorDenseMatrixBuilder; +import hivemall.math.vector.Vector; +import hivemall.math.vector.VectorProcedure; +import hivemall.utils.lang.Preconditions; + +import javax.annotation.Nonnegative; +import javax.annotation.Nonnull; + +/** + * Fixed-size Dense 2-d double Matrix. 
+ */ +public final class ColumnMajorDenseMatrix2d extends ColumnMajorMatrix { + + @Nonnull + private final double[][] data; // col-row + + @Nonnegative + private final int numRows; + @Nonnegative + private final int numColumns; + @Nonnegative + private int nnz; + + public ColumnMajorDenseMatrix2d(@Nonnull double[][] data, @Nonnegative int numRows) { + this(data, numRows, nnz(data)); + } + + public ColumnMajorDenseMatrix2d(@Nonnull double[][] data, @Nonnegative int numRows, + @Nonnegative int nnz) { + super(); + this.data = data; + this.numRows = numRows; + this.numColumns = data.length; + this.nnz = nnz; + } + + @Override + public boolean isSparse() { + return false; + } + + @Override + public boolean readOnly() { + return true; + } + + @Override + public boolean swappable() { + return false; + } + + @Override + public int nnz() { + return nnz; + } + + @Override + public int numRows() { + return numRows; + } + + @Override + public int numColumns() { + return numColumns; + } + + @Override + public int numColumns(final int row) { + checkRowIndex(row, numRows); + + int numColumns = 0; + for (int j = 0; j < data.length; j++) { + final double[] col = data[j]; + if (col == null) { + continue; + } + if (row < col.length && col[row] != 0.d) { + numColumns++; + } + } + return numColumns; + } + + @Override + public double[] getRow(final int index) { + checkRowIndex(index, numRows); + + double[] row = new double[numColumns]; + return getRow(index, row); + } + + @Override + public double[] getRow(final int index, @Nonnull final double[] dst) { + checkRowIndex(index, numRows); + + for (int j = 0; j < data.length; j++) { + final double[] col = data[j]; + if (col == null) { + continue; + } + if (index < col.length) { + dst[j] = col[index]; + } + } + return dst; + } + + @Override + public void getRow(final int index, @Nonnull final Vector row) { + checkRowIndex(index, numRows); + row.clear(); + + for (int j = 0; j < data.length; j++) { + final double[] col = data[j]; + if (col == 
null) { + continue; + } + if (index < col.length) { + double v = col[index]; + row.set(j, v); + } + } + } + + @Override + public double get(final int row, final int col, final double defaultValue) { + checkIndex(row, col, numRows, numColumns); + + final double[] colData = data[col]; + if (colData == null || row >= colData.length) { + return defaultValue; + } + return colData[row]; + } + + @Override + public double getAndSet(final int row, final int col, final double value) { + checkIndex(row, col, numRows, numColumns); + + final double[] colData = data[col]; + Preconditions.checkNotNull(colData, "col does not exists: " + col); + checkRowIndex(row, colData.length); + + final double old = colData[row]; + colData[row] = value; + if (old == 0.d && value != 0.d) { + ++nnz; + } + return old; + } + + @Override + public void set(final int row, final int col, final double value) { + checkIndex(row, col, numRows, numColumns); + if (value == 0.d) { + return; + } + + final double[] colData = data[col]; + Preconditions.checkNotNull(colData, "col does not exists: " + col); + checkRowIndex(row, colData.length); + + if (colData[row] == 0.d) { + ++nnz; + } + colData[row] = value; + } + + @Override + public void swap(int row1, int row2) { + throw new UnsupportedOperationException(); + } + + @Override + public void eachInColumn(final int col, @Nonnull final VectorProcedure procedure, + final boolean nullOutput) { + checkColIndex(col, numColumns); + + final double[] colData = data[col]; + if (colData == null) { + if (nullOutput) { + for (int i = 0; i < numRows; i++) { + procedure.apply(i, 0.d); + } + } + return; + } + + int row = 0; + for (int len = colData.length; row < len; row++) { + procedure.apply(row, colData[row]); + } + if (nullOutput) { + for (; row < numRows; row++) { + procedure.apply(row, 0.d); + } + } + } + + @Override + public void eachNonZeroInColumn(final int col, @Nonnull final VectorProcedure procedure) { + checkColIndex(col, numColumns); + + final double[] colData = 
data[col]; + if (colData == null) { + return; + } + int row = 0; + for (int len = colData.length; row < len; row++) { + final double v = colData[row]; + if (v != 0.d) { + procedure.apply(row, v); + } + } + } + + @Override + public RowMajorDenseMatrix2d toRowMajorMatrix() { + final double[][] rowcol = new double[numRows][numColumns]; + int nnz = 0; + for (int j = 0; j < data.length; j++) { + final double[] colData = data[j]; + if (colData == null) { + continue; + } + for (int i = 0; i < colData.length; i++) { + final double v = colData[i]; + if (v == 0.d) { + continue; + } + rowcol[i][j] = v; + nnz++; + } + } + for (int i = 0; i < rowcol.length; i++) { + final double[] row = rowcol[i]; + final int last = numColumns - 1; + int maxj = last; + for (; maxj >= 0; maxj--) { + if (row[maxj] != 0.d) { + break; + } + } + if (maxj == last) { + continue; + } else if (maxj < 0) { + rowcol[i] = null; + continue; + } + final double[] dstRow = new double[maxj + 1]; + System.arraycopy(row, 0, dstRow, 0, dstRow.length); + rowcol[i] = dstRow; + } + + return new RowMajorDenseMatrix2d(rowcol, numColumns, nnz); + } + + @Override + public ColumnMajorDenseMatrixBuilder builder() { + return new ColumnMajorDenseMatrixBuilder(numColumns); + } + + private static int nnz(@Nonnull final double[][] data) { + int count = 0; + for (int j = 0; j < data.length; j++) { + final double[] col = data[j]; + if (col == null) { + continue; + } + for (int i = 0; i < col.length; i++) { + if (col[i] != 0.d) { + ++count; + } + } + } + return count; + } + +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/math/matrix/dense/RowMajorDenseMatrix2d.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/math/matrix/dense/RowMajorDenseMatrix2d.java b/core/src/main/java/hivemall/math/matrix/dense/RowMajorDenseMatrix2d.java new file mode 100644 index 0000000..54302e1 --- /dev/null +++ 
b/core/src/main/java/hivemall/math/matrix/dense/RowMajorDenseMatrix2d.java @@ -0,0 +1,349 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package hivemall.math.matrix.dense;

import hivemall.math.matrix.RowMajorMatrix;
import hivemall.math.matrix.builders.RowMajorDenseMatrixBuilder;
import hivemall.math.vector.DenseVector;
import hivemall.math.vector.VectorProcedure;
import hivemall.utils.lang.Preconditions;

import java.util.Arrays;

import javax.annotation.Nonnegative;
import javax.annotation.Nonnull;

/**
 * Fixed-size Dense 2-d double Matrix.
 *
 * <p>
 * Cells are kept in a jagged {@code double[][]} indexed as {@code data[row][col]}. A row array
 * may be {@code null} or shorter than {@link #numColumns()}; cells that are not backed by
 * storage read as the default value in {@link #get(int, int, double)}. The number of rows is
 * fixed to {@code data.length} at construction time and row arrays are never grown.
 */
public final class RowMajorDenseMatrix2d extends RowMajorMatrix {

    /** Row arrays; an entry may be {@code null} or shorter than {@link #numColumns}. */
    @Nonnull
    private final double[][] data;

    @Nonnegative
    private final int numRows;
    @Nonnegative
    private final int numColumns;
    /** Running count of non-zero cells; see {@link #getAndSet(int, int, double)}. */
    @Nonnegative
    private int nnz;

    /**
     * Wraps the given row arrays and counts their non-zero cells.
     *
     * @param data row arrays, used as is (not copied)
     * @param numColumns logical number of columns
     */
    public RowMajorDenseMatrix2d(@Nonnull double[][] data, @Nonnegative int numColumns) {
        this(data, numColumns, nnz(data));
    }

    /**
     * Wraps the given row arrays with a caller-supplied non-zero count.
     *
     * @param data row arrays, used as is (not copied)
     * @param numColumns logical number of columns
     * @param nnz number of non-zero cells in {@code data}; taken on trust, not verified
     */
    public RowMajorDenseMatrix2d(@Nonnull double[][] data, @Nonnegative int numColumns,
            @Nonnegative int nnz) {
        super();
        this.data = data;
        this.numRows = data.length;
        this.numColumns = numColumns;
        this.nnz = nnz;
    }

    /** @return always {@code false} */
    @Override
    public boolean isSparse() {
        return false;
    }

    /**
     * @return always {@code true}
     */
    // NOTE(review): set(), getAndSet() and swap() below still mutate the matrix although this
    // reports read-only; confirm what the Matrix contract expects from readOnly().
    @Override
    public boolean readOnly() {
        return true;
    }

    /** @return always {@code true}; see {@link #swap(int, int)} */
    @Override
    public boolean swappable() {
        return true;
    }

    /** @return the tracked number of non-zero cells */
    @Override
    public int nnz() {
        return nnz;
    }

    @Override
    public int numRows() {
        return numRows;
    }

    @Override
    public int numColumns() {
        return numColumns;
    }

    /**
     * Returns the number of cells physically stored for a row.
     *
     * @param row row index
     * @return the length of the row array, or {@code 0} when the row is {@code null}
     */
    @Override
    public int numColumns(@Nonnegative final int row) {
        checkRowIndex(row, numRows);

        final double[] r = data[row];
        if (r == null) {
            return 0;
        }
        return r.length;
    }

    /** @return a new {@link DenseVector} constructed with {@code numColumns} */
    @Override
    public DenseVector rowVector() {
        return new DenseVector(numColumns);
    }

    /**
     * Returns the cells of one row.
     *
     * <p>
     * A row stored at full width is returned directly (the internal array, not a copy). A
     * shorter row is copied into a new array of length {@link #numColumns()} whose tail is
     * zero. A {@code null} row yields an empty array.
     *
     * @param index row index
     * @return the row values as described above
     */
    @Override
    public double[] getRow(@Nonnegative final int index) {
        checkRowIndex(index, numRows);

        final double[] row = data[index];
        if (row == null) {
            return new double[0];
        } else if (row.length == numColumns) {
            // a row is numColumns wide; comparing against numRows only worked for square matrices
            return row;
        }

        // pad the short row with trailing zeros up to the logical row width
        final double[] result = new double[numColumns];
        System.arraycopy(row, 0, result, 0, row.length);
        return result;
    }

    /**
     * Copies the cells of one row into a caller-supplied buffer.
     *
     * <p>
     * The stored cells are copied to the front of {@code dst} and the remainder of {@code dst}
     * is filled with zeros. {@code dst} must be at least as long as the stored row.
     *
     * @param index row index
     * @param dst destination buffer
     * @return {@code dst}, or a new empty array when the row is {@code null}
     */
    // NOTE(review): for a null row dst is left untouched and an empty array is returned instead;
    // callers reusing dst across rows should use the return value - confirm this is intended.
    @Override
    public double[] getRow(@Nonnegative final int index, @Nonnull final double[] dst) {
        checkRowIndex(index, numRows);

        final double[] row = data[index];
        if (row == null) {
            return new double[0];
        }

        System.arraycopy(row, 0, dst, 0, row.length);
        if (dst.length > row.length) {// zerofill
            Arrays.fill(dst, row.length, dst.length, 0.d);
        }
        return dst;
    }

    /**
     * Reads one cell of the matrix.
     *
     * @param row row index
     * @param col column index
     * @param defaultValue value returned when the row array is {@code null} or is shorter than
     *        {@code col + 1}
     * @return the stored value, or {@code defaultValue} when the cell is not backed by storage
     */
    @Override
    public double get(@Nonnegative final int row, @Nonnegative final int col,
            final double defaultValue) {
        checkIndex(row, col, numRows, numColumns);

        final double[] rowData = data[row];
        if (rowData == null || col >= rowData.length) {
            return defaultValue;
        }
        return rowData[col];
    }

    /**
     * Overwrites one cell and returns the value it held before.
     *
     * <p>
     * The row array must already exist and be long enough to hold {@code col}. The non-zero
     * counter is incremented when a zero cell receives a non-zero value; it is not decremented
     * when a non-zero cell is overwritten with zero.
     *
     * @param row row index
     * @param col column index
     * @param value new cell value
     * @return the previous cell value
     */
    @Override
    public double getAndSet(@Nonnegative final int row, @Nonnegative final int col,
            final double value) {
        checkIndex(row, col, numRows, numColumns);

        final double[] rowData = data[row];
        Preconditions.checkNotNull(rowData, "row does not exists: " + row);
        checkColIndex(col, rowData.length);

        double old = rowData[col];
        rowData[col] = value;
        if (old == 0.d && value != 0.d) {
            ++nnz;
        }
        return old;
    }

    /**
     * Stores a non-zero value into one cell.
     *
     * <p>
     * A {@code value} of zero is ignored after the index check, so this method cannot be used
     * to clear a cell. For non-zero values the row array must already exist and cover
     * {@code col}.
     *
     * @param row row index
     * @param col column index
     * @param value value to store; zero is a no-op
     */
    @Override
    public void set(@Nonnegative final int row, @Nonnegative final int col, final double value) {
        checkIndex(row, col, numRows, numColumns);
        if (value == 0.d) {
            return;
        }

        final double[] rowData = data[row];
        Preconditions.checkNotNull(rowData, "row does not exists: " + row);
        checkColIndex(col, rowData.length);

        // count the cell only when it turns from zero into non-zero
        if (rowData[col] == 0.d) {
            ++nnz;
        }
        rowData[col] = value;
    }

    /**
     * Exchanges two rows by swapping their array references; no cell is copied.
     *
     * @param row1 first row index
     * @param row2 second row index
     */
    @Override
    public void swap(@Nonnegative final int row1, @Nonnegative final int row2) {
        checkRowIndex(row1, numRows);
        checkRowIndex(row2, numRows);

        double[] oldRow1 = data[row1];
        data[row1] = data[row2];
        data[row2] = oldRow1;
    }

    /**
     * Visits the cells of one row in ascending column order.
     *
     * <p>
     * Every stored cell is passed to {@code procedure.apply(col, value)}, zeros included. When
     * {@code nullOutput} is set, the columns that are not backed by storage (a {@code null} row
     * or the columns past the end of a short row array) are reported with the value {@code 0}.
     *
     * @param row row index
     * @param procedure callback receiving {@code (col, value)}
     * @param nullOutput whether columns without storage are reported as zeros
     */
    @Override
    public void eachInRow(@Nonnegative final int row, @Nonnull final VectorProcedure procedure,
            final boolean nullOutput) {
        checkRowIndex(row, numRows);

        final double[] rowData = data[row];
        if (rowData == null) {
            if (nullOutput) {
                for (int j = 0; j < numColumns; j++) {
                    procedure.apply(j, 0.d);
                }
            }
            return;
        }

        int col = 0;
        for (int len = rowData.length; col < len; col++) {
            procedure.apply(col, rowData[col]);
        }
        if (nullOutput) {
            // continue from the end of the stored cells up to the logical row width
            for (; col < numColumns; col++) {
                procedure.apply(col, 0.d);
            }
        }
    }

    /**
     * Visits only the non-zero stored cells of one row in ascending column order.
     *
     * @param row row index
     * @param procedure callback receiving {@code (col, value)}
     */
    @Override
    public void eachNonZeroInRow(@Nonnegative final int row,
            @Nonnull final VectorProcedure procedure) {
        checkRowIndex(row, numRows);

        final double[] rowData = data[row];
        if (rowData == null) {
            return;
        }
        for (int col = 0, len = rowData.length; col < len; col++) {
            final double v = rowData[col];
            if (v != 0.d) {
                procedure.apply(col, v);
            }
        }
    }

    /**
     * Passes every stored column index of one row to {@code procedure.apply(col)}, regardless
     * of the cell value. Nothing is reported for a {@code null} row.
     *
     * @param row row index
     * @param procedure callback receiving the column index
     */
    @Override
    public void eachColumnIndexInRow(@Nonnegative final int row,
            @Nonnull final VectorProcedure procedure) {
        checkRowIndex(row, numRows);

        final double[] rowData = data[row];
        if (rowData == null) {
            return;
        }
        for (int col = 0, len = rowData.length; col < len; col++) {
            procedure.apply(col);
        }
    }

    /**
     * Visits the cells of one column in ascending row order.
     *
     * <p>
     * Rows whose array covers {@code col} report their stored value, zeros included. Rows that
     * are {@code null} or too short are reported with the value {@code 0} only when
     * {@code nullOutput} is set, and are skipped otherwise.
     *
     * @param col column index
     * @param procedure callback receiving {@code (row, value)}
     * @param nullOutput whether rows without storage for {@code col} are reported as zeros
     */
    @Override
    public void eachInColumn(@Nonnegative final int col, @Nonnull final VectorProcedure procedure,
            final boolean nullOutput) {
        checkColIndex(col, numColumns);

        for (int row = 0; row < numRows; row++) {
            final double[] rowData = data[row];
            if (rowData != null && col < rowData.length) {
                procedure.apply(row, rowData[col]);
            } else {
                if (nullOutput) {
                    procedure.apply(row, 0.d);
                }
            }
        }
    }

    /**
     * Visits only the non-zero stored cells of one column in ascending row order.
     *
     * @param col column index
     * @param procedure callback receiving {@code (row, value)}
     */
    @Override
    public void eachNonZeroInColumn(@Nonnegative final int col,
            @Nonnull final VectorProcedure procedure) {
        checkColIndex(col, numColumns);

        for (int row = 0; row < numRows; row++) {
            final double[] rowData = data[row];
            if (rowData == null) {
                continue;
            }
            if (col < rowData.length) {
                final double v = rowData[col];
                if (v != 0.d) {
                    procedure.apply(row, v);
                }
            }
        }
    }

    /**
     * Builds a column-major copy of this matrix.
     *
     * <p>
     * Non-zero cells are transposed into a full {@code numColumns x numRows} buffer, then each
     * column is trimmed of its trailing zeros: a column without any non-zero cell becomes
     * {@code null}, a column whose last cell is non-zero is kept as is, and any other column is
     * replaced by a shorter copy.
     *
     * @return a new {@link ColumnMajorDenseMatrix2d} holding the same non-zero cells
     */
    @Override
    public ColumnMajorDenseMatrix2d toColumnMajorMatrix() {
        final double[][] colrow = new double[numColumns][numRows];
        int nnz = 0;
        // pass 1: transpose the non-zero cells and count them
        for (int i = 0; i < data.length; i++) {
            final double[] rowData = data[i];
            if (rowData == null) {
                continue;
            }
            for (int j = 0; j < rowData.length; j++) {
                final double v = rowData[j];
                if (v == 0.d) {
                    continue;
                }
                colrow[j][i] = v;
                nnz++;
            }
        }
        // pass 2: drop trailing zeros of every column
        for (int j = 0; j < colrow.length; j++) {
            final double[] col = colrow[j];
            final int last = numRows - 1;
            int maxi = last; // index of the last non-zero cell, or -1 when there is none
            for (; maxi >= 0; maxi--) {
                if (col[maxi] != 0.d) {
                    break;
                }
            }
            if (maxi == last) {
                continue; // nothing to trim
            } else if (maxi < 0) {
                colrow[j] = null; // column holds no non-zero cell
                continue;
            }
            final double[] dstCol = new double[maxi + 1];
            System.arraycopy(col, 0, dstCol, 0, dstCol.length);
            colrow[j] = dstCol;
        }

        return new ColumnMajorDenseMatrix2d(colrow, numRows, nnz);
    }

    /**
     * Creates a builder for this matrix layout.
     *
     * @return a new {@link RowMajorDenseMatrixBuilder} constructed with {@code numRows}
     */
    @Override
    public RowMajorDenseMatrixBuilder builder() {
        return new RowMajorDenseMatrixBuilder(numRows);
    }

    /**
     * Counts the non-zero cells of a jagged row array, skipping {@code null} rows.
     *
     * @param data row arrays; individual entries may be {@code null}
     * @return number of cells whose value is not zero
     */
    private static int nnz(@Nonnull final double[][] data) {
        int count = 0;
        for (int i = 0; i < data.length; i++) {
            final double[] row = data[i];
            if (row == null) {
                continue;
            }
            for (int j = 0; j < row.length; j++) {
                if (row[j] != 0.d) {
                    ++count;
                }
            }
        }
        return count;
    }

}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/math/matrix/ints/AbstractIntMatrix.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/math/matrix/ints/AbstractIntMatrix.java b/core/src/main/java/hivemall/math/matrix/ints/AbstractIntMatrix.java new file mode 100644 index 0000000..0431310 --- /dev/null +++ b/core/src/main/java/hivemall/math/matrix/ints/AbstractIntMatrix.java @@ -0,0 +1,112 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */
package hivemall.math.matrix.ints;

import hivemall.math.vector.VectorProcedure;

import javax.annotation.Nonnegative;
import javax.annotation.Nonnull;

/**
 * Skeleton implementation of {@link IntMatrix}.
 *
 * <p>
 * It holds the configurable default cell value, supplies the convenience overloads that
 * delegate to the fuller variants, and offers static index-validation helpers to subclasses.
 */
public abstract class AbstractIntMatrix implements IntMatrix {

    /** Value handed to {@code get(row, col, defaultValue)} by {@link #get(int, int)}. */
    protected int defaultValue = 0;

    public AbstractIntMatrix() {}

    /**
     * Replaces the default cell value.
     *
     * @param value new default value
     */
    @Override
    public void setDefaultValue(int value) {
        this.defaultValue = value;
    }

    /**
     * Allocates a fresh, zero-filled buffer.
     *
     * @return a new {@code int[]} of length {@code numRows()}
     */
    // NOTE(review): a buffer for one row would normally be numColumns() wide; the length used
    // here is numRows() - confirm against the IntMatrix contract before relying on it.
    @Override
    public int[] row() {
        return new int[numRows()];
    }

    /**
     * Reads one cell, using the configured default value as the fallback argument.
     *
     * @param row row index
     * @param col column index
     * @return the result of {@code get(row, col, defaultValue)}
     */
    @Override
    public final int get(@Nonnegative final int row, @Nonnegative final int col) {
        return get(row, col, defaultValue);
    }

    /**
     * Increments one cell by one via {@code incr(row, col, 1)}.
     *
     * @param row row index
     * @param col column index
     */
    @Override
    public void incr(@Nonnegative final int row, @Nonnegative final int col) {
        incr(row, col, 1);
    }

    /**
     * Verifies that {@code 0 <= row < numRows}.
     *
     * @throws IndexOutOfBoundsException if the row index is outside that range
     */
    protected static final void checkRowIndex(final int row, final int numRows) {
        final boolean inRange = (0 <= row) && (row < numRows);
        if (!inRange) {
            throw new IndexOutOfBoundsException("Row index " + row + " out of bounds " + numRows);
        }
    }

    /**
     * Verifies that {@code 0 <= col < numColumns}.
     *
     * @throws IndexOutOfBoundsException if the column index is outside that range
     */
    protected static final void checkColIndex(final int col, final int numColumns) {
        final boolean inRange = (0 <= col) && (col < numColumns);
        if (!inRange) {
            throw new IndexOutOfBoundsException("Col index " + col + " out of bounds " + numColumns);
        }
    }

    /**
     * Verifies that an index is not negative.
     *
     * @throws IllegalArgumentException if {@code index < 0}
     */
    protected static final void checkIndex(final int index) {
        if (index >= 0) {
            return;
        }
        throw new IllegalArgumentException("Invalid index: " + index);
    }

    /**
     * Verifies that neither index is negative; the row is checked first.
     *
     * @throws IllegalArgumentException if {@code row < 0} or {@code col < 0}
     */
    protected static final void checkIndex(final int row, final int col) {
        if (row < 0) {
            throw new IllegalArgumentException("Invalid row index: " + row);
        } else if (col < 0) {
            throw new IllegalArgumentException("Invalid col index: " + col);
        }
    }

    /**
     * Verifies both indices against the matrix dimensions; the row is checked first.
     *
     * @throws IndexOutOfBoundsException if either index is negative or not below its bound
     */
    protected static final void checkIndex(final int row, final int col, final int numRows,
            final int numColumns) {
        checkRowIndex(row, numRows);
        checkColIndex(col, numColumns);
    }

    /** Delegates to {@code eachInRow(row, procedure, true)}. */
    @Override
    public void eachInRow(final int row, @Nonnull final VectorProcedure procedure) {
        eachInRow(row, procedure, true);
    }

    /** Delegates to {@code eachInColumn(col, procedure, true)}. */
    @Override
    public void eachInColumn(final int col, @Nonnull final VectorProcedure procedure) {
        eachInColumn(col, procedure, true);
    }

    /** Delegates to {@code eachInRow(row, procedure, false)}. */
    @Override
    public void eachNonNullInRow(final int row, @Nonnull final VectorProcedure procedure) {
        eachInRow(row, procedure, false);
    }

    /** Delegates to {@code eachInColumn(col, procedure, false)}. */
    @Override
    public void eachNonNullInColumn(final int col, @Nonnull final VectorProcedure procedure) {
        eachInColumn(col, procedure, false);
    }

}