Close #121: [HIVEMALL-151] Support Matrix conversion from DoK to CSR/CSC matrix
Project: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/commit/fdf70214 Tree: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/tree/fdf70214 Diff: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/diff/fdf70214 Branch: refs/heads/master Commit: fdf70214359f3ce2b1371edf630be89ba9499745 Parents: d4f4ab9 Author: Makoto Yui <[email protected]> Authored: Mon Oct 16 20:35:00 2017 +0900 Committer: Makoto Yui <[email protected]> Committed: Mon Oct 16 20:35:00 2017 +0900 ---------------------------------------------------------------------- .travis.yml | 3 +- .../anomaly/SingularSpectrumTransform.java | 6 +- .../java/hivemall/ftvec/AddFeatureIndexUDF.java | 2 +- .../java/hivemall/ftvec/FeatureIndexUDF.java | 2 +- .../ftvec/trans/AddFieldIndicesUDF.java | 6 +- .../hivemall/ftvec/trans/FFMFeaturesUDF.java | 10 +- .../hivemall/math/matrix/AbstractMatrix.java | 5 + .../math/matrix/ColumnMajorFloatMatrix.java | 32 ++ .../java/hivemall/math/matrix/FloatMatrix.java | 73 ++++ .../main/java/hivemall/math/matrix/Matrix.java | 2 + .../java/hivemall/math/matrix/MatrixUtils.java | 264 +++++++++++- .../math/matrix/RowMajorFloatMatrix.java | 32 ++ .../math/matrix/builders/CSCMatrixBuilder.java | 8 +- .../hivemall/math/matrix/sparse/CSCMatrix.java | 47 ++- .../hivemall/math/matrix/sparse/CSRMatrix.java | 6 +- .../math/matrix/sparse/DoKFloatMatrix.java | 368 ----------------- .../hivemall/math/matrix/sparse/DoKMatrix.java | 37 +- .../matrix/sparse/floats/CSCFloatMatrix.java | 317 +++++++++++++++ .../matrix/sparse/floats/CSRFloatMatrix.java | 293 ++++++++++++++ .../matrix/sparse/floats/DoKFloatMatrix.java | 401 +++++++++++++++++++ .../hivemall/math/vector/AbstractVector.java | 10 + .../hivemall/math/vector/DenseFloatVector.java | 107 +++++ .../hivemall/math/vector/SparseFloatVector.java | 86 ++++ .../main/java/hivemall/math/vector/Vector.java | 7 + .../hivemall/math/vector/VectorProcedure.java 
| 4 + .../main/java/hivemall/recommend/SlimUDTF.java | 18 +- .../hivemall/smile/utils/SmileExtUtils.java | 12 +- .../collections/arrays/SparseFloatArray.java | 9 + .../hivemall/math/matrix/MatrixUtilsTest.java | 132 ++++++ .../math/matrix/sparse/DoKFloatMatrixTest.java | 43 -- .../math/matrix/sparse/DoKMatrixTest.java | 43 ++ .../sparse/floats/DoKFloatMatrixTest.java | 60 +++ docs/gitbook/getting_started/installation.md | 2 +- pom.xml | 4 +- resources/ddl/define-all-as-permanent.hive | 4 + resources/ddl/define-all.hive | 4 + resources/ddl/define-all.spark | 4 + resources/ddl/define-udfs.td.hql | 3 +- .../ftvec/AddFeatureIndexUDFWrapper.java | 2 +- 39 files changed, 1992 insertions(+), 476 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/fdf70214/.travis.yml ---------------------------------------------------------------------- diff --git a/.travis.yml b/.travis.yml index c64c5ff..c98fe0c 100644 --- a/.travis.yml +++ b/.travis.yml @@ -18,9 +18,10 @@ env: language: java jdk: - - openjdk7 +# - openjdk7 # - oraclejdk7 - oraclejdk8 +# - oraclejdk9 branches: only: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/fdf70214/core/src/main/java/hivemall/anomaly/SingularSpectrumTransform.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/anomaly/SingularSpectrumTransform.java b/core/src/main/java/hivemall/anomaly/SingularSpectrumTransform.java index 34d85aa..1936da4 100644 --- a/core/src/main/java/hivemall/anomaly/SingularSpectrumTransform.java +++ b/core/src/main/java/hivemall/anomaly/SingularSpectrumTransform.java @@ -186,14 +186,14 @@ final class SingularSpectrumTransform implements SingularSpectrumTransformInterf for (int i = 0; i < k; i++) { map.put(eigvals[i], i); } - Iterator<Integer> indicies = map.values().iterator(); + Iterator<Integer> indices = map.values().iterator(); double s = 0.d; for 
(int i = 0; i < r; i++) { - if (!indicies.hasNext()) { + if (!indices.hasNext()) { throw new IllegalStateException("Should not happen"); } - double v = eigvecs.getEntry(0, indicies.next().intValue()); + double v = eigvecs.getEntry(0, indices.next().intValue()); s += v * v; } return 1.d - Math.sqrt(s); http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/fdf70214/core/src/main/java/hivemall/ftvec/AddFeatureIndexUDF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/ftvec/AddFeatureIndexUDF.java b/core/src/main/java/hivemall/ftvec/AddFeatureIndexUDF.java index 105dd2a..21b3514 100644 --- a/core/src/main/java/hivemall/ftvec/AddFeatureIndexUDF.java +++ b/core/src/main/java/hivemall/ftvec/AddFeatureIndexUDF.java @@ -37,7 +37,7 @@ import org.apache.hadoop.io.Text; */ @Description( name = "add_feature_index", - value = "_FUNC_(ARRAY[DOUBLE]: dense feature vector) - Returns a feature vector with feature indicies") + value = "_FUNC_(ARRAY[DOUBLE]: dense feature vector) - Returns a feature vector with feature indices") @UDFType(deterministic = true, stateful = false) public final class AddFeatureIndexUDF extends UDF { http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/fdf70214/core/src/main/java/hivemall/ftvec/FeatureIndexUDF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/ftvec/FeatureIndexUDF.java b/core/src/main/java/hivemall/ftvec/FeatureIndexUDF.java index 9ffe6c6..9fdbc01 100644 --- a/core/src/main/java/hivemall/ftvec/FeatureIndexUDF.java +++ b/core/src/main/java/hivemall/ftvec/FeatureIndexUDF.java @@ -32,7 +32,7 @@ import org.apache.hadoop.io.IntWritable; @Description( name = "feature_index", - value = "_FUNC_(feature_vector in array<string>) - Returns feature indicies in array<index>") + value = "_FUNC_(feature_vector in array<string>) - Returns feature indices in array<index>") @UDFType(deterministic = true, 
stateful = false) public final class FeatureIndexUDF extends UDF { http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/fdf70214/core/src/main/java/hivemall/ftvec/trans/AddFieldIndicesUDF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/ftvec/trans/AddFieldIndicesUDF.java b/core/src/main/java/hivemall/ftvec/trans/AddFieldIndicesUDF.java index 53b998c..99cf785 100644 --- a/core/src/main/java/hivemall/ftvec/trans/AddFieldIndicesUDF.java +++ b/core/src/main/java/hivemall/ftvec/trans/AddFieldIndicesUDF.java @@ -37,8 +37,8 @@ import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory; import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; -@Description(name = "add_field_indicies", value = "_FUNC_(array<string> features) " - + "- Returns arrays of string that field indicies (<field>:<feature>)* are argumented") +@Description(name = "add_field_indices", value = "_FUNC_(array<string> features) " + + "- Returns arrays of string that field indices (<field>:<feature>)* are argumented") @UDFType(deterministic = true, stateful = false) public final class AddFieldIndicesUDF extends GenericUDF { @@ -82,7 +82,7 @@ public final class AddFieldIndicesUDF extends GenericUDF { @Override public String getDisplayString(String[] args) { - return "add_field_indicies( " + Arrays.toString(args) + " )"; + return "add_field_indices( " + Arrays.toString(args) + " )"; } http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/fdf70214/core/src/main/java/hivemall/ftvec/trans/FFMFeaturesUDF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/ftvec/trans/FFMFeaturesUDF.java b/core/src/main/java/hivemall/ftvec/trans/FFMFeaturesUDF.java index eead738..a0acd36 100644 --- 
a/core/src/main/java/hivemall/ftvec/trans/FFMFeaturesUDF.java +++ b/core/src/main/java/hivemall/ftvec/trans/FFMFeaturesUDF.java @@ -60,7 +60,7 @@ public final class FFMFeaturesUDF extends UDFWithOptions { private boolean _mhash = true; private int _numFeatures = Feature.DEFAULT_NUM_FEATURES; private int _numFields = Feature.DEFAULT_NUM_FIELDS; - private boolean _emitIndicies = false; + private boolean _emitIndices = false; @Override protected Options getOptions() { @@ -72,7 +72,7 @@ public final class FFMFeaturesUDF extends UDFWithOptions { opts.addOption("hash", "feature_hashing", true, "The number of bits for feature hashing in range [18,31] [default:21]"); opts.addOption("fields", "num_fields", true, "The number of fields [default:1024]"); - opts.addOption("emit_indicies", false, "Emit indicies for fields [default: false]"); + opts.addOption("emit_indices", false, "Emit indices for fields [default: false]"); return opts; } @@ -100,7 +100,7 @@ public final class FFMFeaturesUDF extends UDFWithOptions { } this._numFields = numFields; - this._emitIndicies = cl.hasOption("emit_indicies"); + this._emitIndices = cl.hasOption("emit_indices"); return cl; } @@ -189,14 +189,14 @@ public final class FFMFeaturesUDF extends UDFWithOptions { // categorical feature representation final String fv; if (_mhash) { - int field = _emitIndicies ? i : MurmurHash3.murmurhash3(_featureNames[i], + int field = _emitIndices ? 
i : MurmurHash3.murmurhash3(_featureNames[i], _numFields); // +NUM_FIELD to avoid conflict to quantitative features int index = MurmurHash3.murmurhash3(feature, _numFeatures) + _numFields; fv = builder.append(field).append(':').append(index).append(":1").toString(); StringUtils.clear(builder); } else { - if (_emitIndicies) { + if (_emitIndices) { builder.append(i); } else { builder.append(featureName); http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/fdf70214/core/src/main/java/hivemall/math/matrix/AbstractMatrix.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/math/matrix/AbstractMatrix.java b/core/src/main/java/hivemall/math/matrix/AbstractMatrix.java index 2ee27f7..fe3c543 100644 --- a/core/src/main/java/hivemall/math/matrix/AbstractMatrix.java +++ b/core/src/main/java/hivemall/math/matrix/AbstractMatrix.java @@ -102,4 +102,9 @@ public abstract class AbstractMatrix implements Matrix { eachInColumn(col, procedure, false); } + @Override + public void eachNonZeroCell(VectorProcedure procedure) { + throw new UnsupportedOperationException("Not yet supported"); + } + } http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/fdf70214/core/src/main/java/hivemall/math/matrix/ColumnMajorFloatMatrix.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/math/matrix/ColumnMajorFloatMatrix.java b/core/src/main/java/hivemall/math/matrix/ColumnMajorFloatMatrix.java new file mode 100644 index 0000000..6067ed3 --- /dev/null +++ b/core/src/main/java/hivemall/math/matrix/ColumnMajorFloatMatrix.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package hivemall.math.matrix; + +public abstract class ColumnMajorFloatMatrix extends ColumnMajorMatrix implements FloatMatrix { + + public ColumnMajorFloatMatrix() { + super(); + } + + @Override + public ColumnMajorFloatMatrix toColumnMajorMatrix() { + return this; + } + +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/fdf70214/core/src/main/java/hivemall/math/matrix/FloatMatrix.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/math/matrix/FloatMatrix.java b/core/src/main/java/hivemall/math/matrix/FloatMatrix.java new file mode 100644 index 0000000..f1af65f --- /dev/null +++ b/core/src/main/java/hivemall/math/matrix/FloatMatrix.java @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package hivemall.math.matrix; + +import javax.annotation.Nonnegative; +import javax.annotation.Nonnull; + +public interface FloatMatrix extends Matrix { + + /** + * @throws IndexOutOfBoundsException + */ + public float get(@Nonnegative final int row, @Nonnegative final int col, + final float defaultValue); + + /** + * @throws IndexOutOfBoundsException + * @throws UnsupportedOperationException + */ + public void set(@Nonnegative final int row, @Nonnegative final int col, final float value); + + /** + * @throws IndexOutOfBoundsException + * @throws UnsupportedOperationException + */ + public float getAndSet(@Nonnegative final int row, @Nonnegative final int col, final float value); + + /** + * @return returns dst + */ + @Nonnull + public float[] getRow(@Nonnegative int index, @Nonnull float[] dst); + + @Override + default double get(@Nonnegative final int row, @Nonnegative final int col, + final double defaultValue) { + return get(row, col, (float) defaultValue); + } + + @Override + default void set(@Nonnegative final int row, @Nonnegative final int col, final double value) { + set(row, col, (float) value); + } + + @Override + default double getAndSet(@Nonnegative final int row, @Nonnegative final int col, + final double value) { + return getAndSet(row, col, (float) value); + } + + @Override + public RowMajorFloatMatrix toRowMajorMatrix(); + + @Override + public ColumnMajorFloatMatrix toColumnMajorMatrix(); + +} \ No newline at end of file 
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/fdf70214/core/src/main/java/hivemall/math/matrix/Matrix.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/math/matrix/Matrix.java b/core/src/main/java/hivemall/math/matrix/Matrix.java index 8a4782a..338a4c2 100644 --- a/core/src/main/java/hivemall/math/matrix/Matrix.java +++ b/core/src/main/java/hivemall/math/matrix/Matrix.java @@ -115,6 +115,8 @@ public interface Matrix { public void eachNonZeroInColumn(@Nonnegative int col, @Nonnull VectorProcedure procedure); + public void eachNonZeroCell(@Nonnull final VectorProcedure procedure); + @Nonnull public RowMajorMatrix toRowMajorMatrix(); http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/fdf70214/core/src/main/java/hivemall/math/matrix/MatrixUtils.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/math/matrix/MatrixUtils.java b/core/src/main/java/hivemall/math/matrix/MatrixUtils.java index 90ce78f..cd137ed 100644 --- a/core/src/main/java/hivemall/math/matrix/MatrixUtils.java +++ b/core/src/main/java/hivemall/math/matrix/MatrixUtils.java @@ -20,10 +20,17 @@ package hivemall.math.matrix; import hivemall.math.matrix.builders.MatrixBuilder; import hivemall.math.matrix.ints.IntMatrix; +import hivemall.math.matrix.sparse.CSCMatrix; +import hivemall.math.matrix.sparse.CSRMatrix; +import hivemall.math.matrix.sparse.floats.CSCFloatMatrix; +import hivemall.math.matrix.sparse.floats.CSRFloatMatrix; import hivemall.math.vector.VectorProcedure; import hivemall.utils.lang.Preconditions; import hivemall.utils.lang.mutable.MutableInt; +import java.util.Arrays; +import java.util.Comparator; + import javax.annotation.Nonnegative; import javax.annotation.Nonnull; @@ -34,7 +41,7 @@ public final class MatrixUtils { @Nonnull public static Matrix shuffle(@Nonnull final Matrix m, @Nonnull final int[] indices) { 
Preconditions.checkArgument(m.numRows() <= indices.length, "m.numRow() `" + m.numRows() - + "` MUST be equals to or less than |swapIndicies| `" + indices.length + "`"); + + "` MUST be equals to or less than |swapIndices| `" + indices.length + "`"); final MatrixBuilder builder = m.builder(); final VectorProcedure proc = new VectorProcedure() { @@ -70,4 +77,259 @@ public final class MatrixUtils { return which.getValue(); } + /** + * @param data non-zero entries + */ + @Nonnull + public static CSRMatrix coo2csr(@Nonnull final int[] rows, @Nonnull final int[] cols, + @Nonnull final double[] data, @Nonnegative final int numRows, + @Nonnegative final int numCols, final boolean sortColumns) { + final int nnz = data.length; + Preconditions.checkArgument(rows.length == nnz); + Preconditions.checkArgument(cols.length == nnz); + + final int[] rowPointers = new int[numRows + 1]; + final int[] colIndices = new int[nnz]; + final double[] values = new double[nnz]; + + coo2csr(rows, cols, data, rowPointers, colIndices, values, numRows, numCols, nnz); + + if (sortColumns) { + sortIndices(rowPointers, colIndices, values); + } + return new CSRMatrix(rowPointers, colIndices, values, numCols); + } + + /** + * @param data non-zero entries + */ + @Nonnull + public static CSRFloatMatrix coo2csr(@Nonnull final int[] rows, @Nonnull final int[] cols, + @Nonnull final float[] data, @Nonnegative final int numRows, + @Nonnegative final int numCols, final boolean sortColumns) { + final int nnz = data.length; + Preconditions.checkArgument(rows.length == nnz); + Preconditions.checkArgument(cols.length == nnz); + + final int[] rowPointers = new int[numRows + 1]; + final int[] colIndices = new int[nnz]; + final float[] values = new float[nnz]; + + coo2csr(rows, cols, data, rowPointers, colIndices, values, numRows, numCols, nnz); + + if (sortColumns) { + sortIndices(rowPointers, colIndices, values); + } + return new CSRFloatMatrix(rowPointers, colIndices, values, numCols); + } + + @Nonnull + public 
static CSCMatrix coo2csc(@Nonnull final int[] rows, @Nonnull final int[] cols, + @Nonnull final double[] data, @Nonnegative final int numRows, + @Nonnegative final int numCols, final boolean sortRows) { + final int nnz = data.length; + Preconditions.checkArgument(rows.length == nnz); + Preconditions.checkArgument(cols.length == nnz); + + final int[] columnPointers = new int[numCols + 1]; + final int[] rowIndices = new int[nnz]; + final double[] values = new double[nnz]; + + coo2csr(cols, rows, data, columnPointers, rowIndices, values, numCols, numRows, nnz); + + if (sortRows) { + sortIndices(columnPointers, rowIndices, values); + } + return new CSCMatrix(columnPointers, rowIndices, values, numRows, numCols); + } + + @Nonnull + public static CSCFloatMatrix coo2csc(@Nonnull final int[] rows, @Nonnull final int[] cols, + @Nonnull final float[] data, @Nonnegative final int numRows, + @Nonnegative final int numCols, final boolean sortRows) { + final int nnz = data.length; + Preconditions.checkArgument(rows.length == nnz); + Preconditions.checkArgument(cols.length == nnz); + + final int[] columnPointers = new int[numCols + 1]; + final int[] rowIndices = new int[nnz]; + final float[] values = new float[nnz]; + + coo2csr(cols, rows, data, columnPointers, rowIndices, values, numCols, numRows, nnz); + + if (sortRows) { + sortIndices(columnPointers, rowIndices, values); + } + + return new CSCFloatMatrix(columnPointers, rowIndices, values, numRows, numCols); + } + + private static void coo2csr(@Nonnull final int[] rows, @Nonnull final int[] cols, + @Nonnull final double[] data, @Nonnull final int[] rowPointers, + @Nonnull final int[] colIndices, @Nonnull final double[] values, + @Nonnegative final int numRows, @Nonnegative final int numCols, final int nnz) { + // compute nnz per for each row to get rowPointers + for (int n = 0; n < nnz; n++) { + rowPointers[rows[n]]++; + } + for (int i = 0, sum = 0; i < numRows; i++) { + int curr = rowPointers[i]; + rowPointers[i] = sum; + sum 
+= curr; + } + rowPointers[numRows] = nnz; + + // copy cols, data to colIndices, csrValues + for (int n = 0; n < nnz; n++) { + int row = rows[n]; + int dst = rowPointers[row]; + + colIndices[dst] = cols[n]; + values[dst] = data[n]; + + rowPointers[row]++; + } + + for (int i = 0, last = 0; i <= numRows; i++) { + int tmp = rowPointers[i]; + rowPointers[i] = last; + last = tmp; + } + } + + private static void coo2csr(@Nonnull final int[] rows, @Nonnull final int[] cols, + @Nonnull final float[] data, @Nonnull final int[] rowPointers, + @Nonnull final int[] colIndices, @Nonnull final float[] values, + @Nonnegative final int numRows, @Nonnegative final int numCols, final int nnz) { + // compute nnz per for each row to get rowPointers + for (int n = 0; n < nnz; n++) { + rowPointers[rows[n]]++; + } + for (int i = 0, sum = 0; i < numRows; i++) { + int curr = rowPointers[i]; + rowPointers[i] = sum; + sum += curr; + } + rowPointers[numRows] = nnz; + + // copy cols, data to colIndices, csrValues + for (int n = 0; n < nnz; n++) { + int row = rows[n]; + int dst = rowPointers[row]; + + colIndices[dst] = cols[n]; + values[dst] = data[n]; + + rowPointers[row]++; + } + + for (int i = 0, last = 0; i <= numRows; i++) { + int tmp = rowPointers[i]; + rowPointers[i] = last; + last = tmp; + } + } + + private static void sortIndices(@Nonnull final int[] majorAxisPointers, + @Nonnull final int[] minorAxisIndices, @Nonnull final double[] values) { + final int numRows = majorAxisPointers.length - 1; + if (numRows <= 1) { + return; + } + + for (int i = 0; i < numRows; i++) { + final int rowStart = majorAxisPointers[i]; + final int rowEnd = majorAxisPointers[i + 1]; + + final int numCols = rowEnd - rowStart; + if (numCols == 0) { + continue; + } else if (numCols < 0) { + throw new IllegalArgumentException( + "numCols SHOULD be greater than zero. 
numCols = rowEnd - rowStart = " + rowEnd + + " - " + rowStart + " = " + numCols + " at i=" + i); + } + + final IntDoublePair[] pairs = new IntDoublePair[numCols]; + for (int jj = rowStart, n = 0; jj < rowEnd; jj++, n++) { + pairs[n] = new IntDoublePair(minorAxisIndices[jj], values[jj]); + } + + Arrays.sort(pairs, new Comparator<IntDoublePair>() { + @Override + public int compare(IntDoublePair x, IntDoublePair y) { + return Integer.compare(x.key, y.key); + } + }); + + for (int jj = rowStart, n = 0; jj < rowEnd; jj++, n++) { + IntDoublePair tmp = pairs[n]; + minorAxisIndices[jj] = tmp.key; + values[jj] = tmp.value; + } + } + } + + private static void sortIndices(@Nonnull final int[] majorAxisPointers, + @Nonnull final int[] minorAxisIndices, @Nonnull final float[] values) { + final int numRows = majorAxisPointers.length - 1; + if (numRows <= 1) { + return; + } + + for (int i = 0; i < numRows; i++) { + final int rowStart = majorAxisPointers[i]; + final int rowEnd = majorAxisPointers[i + 1]; + + final int numCols = rowEnd - rowStart; + if (numCols == 0) { + continue; + } else if (numCols < 0) { + throw new IllegalArgumentException( + "numCols SHOULD be greater than or equal to zero. 
numCols = rowEnd - rowStart = " + + rowEnd + " - " + rowStart + " = " + numCols + " at i=" + i); + } + + final IntFloatPair[] pairs = new IntFloatPair[numCols]; + for (int jj = rowStart, n = 0; jj < rowEnd; jj++, n++) { + pairs[n] = new IntFloatPair(minorAxisIndices[jj], values[jj]); + } + + Arrays.sort(pairs, new Comparator<IntFloatPair>() { + @Override + public int compare(IntFloatPair x, IntFloatPair y) { + return Integer.compare(x.key, y.key); + } + }); + + for (int jj = rowStart, n = 0; jj < rowEnd; jj++, n++) { + IntFloatPair tmp = pairs[n]; + minorAxisIndices[jj] = tmp.key; + values[jj] = tmp.value; + } + } + } + + private static final class IntDoublePair { + + final int key; + final double value; + + IntDoublePair(int key, double value) { + this.key = key; + this.value = value; + } + } + + private static final class IntFloatPair { + + final int key; + final float value; + + IntFloatPair(int key, float value) { + this.key = key; + this.value = value; + } + } + } http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/fdf70214/core/src/main/java/hivemall/math/matrix/RowMajorFloatMatrix.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/math/matrix/RowMajorFloatMatrix.java b/core/src/main/java/hivemall/math/matrix/RowMajorFloatMatrix.java new file mode 100644 index 0000000..90f7bbf --- /dev/null +++ b/core/src/main/java/hivemall/math/matrix/RowMajorFloatMatrix.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package hivemall.math.matrix; + +public abstract class RowMajorFloatMatrix extends RowMajorMatrix implements FloatMatrix { + + public RowMajorFloatMatrix() { + super(); + } + + @Override + public RowMajorFloatMatrix toRowMajorMatrix() { + return this; + } + +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/fdf70214/core/src/main/java/hivemall/math/matrix/builders/CSCMatrixBuilder.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/math/matrix/builders/CSCMatrixBuilder.java b/core/src/main/java/hivemall/math/matrix/builders/CSCMatrixBuilder.java index df2bff7..5c546d5 100644 --- a/core/src/main/java/hivemall/math/matrix/builders/CSCMatrixBuilder.java +++ b/core/src/main/java/hivemall/math/matrix/builders/CSCMatrixBuilder.java @@ -70,19 +70,19 @@ public final class CSCMatrixBuilder extends MatrixBuilder { } final int[] columnIndices = cols.toArray(true); - final int[] rowsIndicies = rows.toArray(true); + final int[] rowsIndices = rows.toArray(true); final double[] valuesArray = values.toArray(true); // convert to column major final int nnz = valuesArray.length; SortObj[] sortObjs = new SortObj[nnz]; for (int i = 0; i < nnz; i++) { - sortObjs[i] = new SortObj(columnIndices[i], rowsIndicies[i], valuesArray[i]); + sortObjs[i] = new SortObj(columnIndices[i], rowsIndices[i], valuesArray[i]); } Arrays.sort(sortObjs); for (int i = 0; i < nnz; i++) { columnIndices[i] = sortObjs[i].columnIndex; - rowsIndicies[i] = sortObjs[i].rowsIndex; + rowsIndices[i] = 
sortObjs[i].rowsIndex; valuesArray[i] = sortObjs[i].value; } sortObjs = null; @@ -98,7 +98,7 @@ public final class CSCMatrixBuilder extends MatrixBuilder { } columnPointers[maxNumColumns] = nnz; // nnz - return new CSCMatrix(columnPointers, rowsIndicies, valuesArray, row, maxNumColumns); + return new CSCMatrix(columnPointers, rowsIndices, valuesArray, row, maxNumColumns); } private static final class SortObj implements Comparable<SortObj> { http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/fdf70214/core/src/main/java/hivemall/math/matrix/sparse/CSCMatrix.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/math/matrix/sparse/CSCMatrix.java b/core/src/main/java/hivemall/math/matrix/sparse/CSCMatrix.java index f8eb02f..14bb4f9 100644 --- a/core/src/main/java/hivemall/math/matrix/sparse/CSCMatrix.java +++ b/core/src/main/java/hivemall/math/matrix/sparse/CSCMatrix.java @@ -31,7 +31,7 @@ import javax.annotation.Nonnegative; import javax.annotation.Nonnull; /** - * Compressed Sparse Column matrix optimized for colum major access. + * Compressed Sparse Column matrix optimized for column major access. 
* * @link http://netlib.org/linalg/html_templates/node92.html#SECTION00931200000000000000 */ @@ -40,7 +40,7 @@ public final class CSCMatrix extends ColumnMajorMatrix { @Nonnull private final int[] columnPointers; @Nonnull - private final int[] rowIndicies; + private final int[] rowIndices; @Nonnull private final double[] values; @@ -48,15 +48,15 @@ public final class CSCMatrix extends ColumnMajorMatrix { private final int numColumns; private final int nnz; - public CSCMatrix(@Nonnull int[] columnPointers, @Nonnull int[] rowIndicies, + public CSCMatrix(@Nonnull int[] columnPointers, @Nonnull int[] rowIndices, @Nonnull double[] values, int numRows, int numColumns) { super(); Preconditions.checkArgument(columnPointers.length >= 1, "rowPointers must be greather than 0: " + columnPointers.length); - Preconditions.checkArgument(rowIndicies.length == values.length, "#rowIndicies (" - + rowIndicies.length + ") must be equals to #values (" + values.length + ")"); + Preconditions.checkArgument(rowIndices.length == values.length, "#rowIndices (" + + rowIndices.length + ") must be equals to #values (" + values.length + ")"); this.columnPointers = columnPointers; - this.rowIndicies = rowIndicies; + this.rowIndices = rowIndices; this.values = values; this.numRows = numRows; this.numColumns = numColumns; @@ -97,18 +97,18 @@ public final class CSCMatrix extends ColumnMajorMatrix { public int numColumns(final int row) { checkRowIndex(row, numRows); - return ArrayUtils.count(rowIndicies, row); + return ArrayUtils.count(rowIndices, row); } @Override - public double[] getRow(int index) { + public double[] getRow(final int index) { checkRowIndex(index, numRows); final double[] row = new double[numColumns]; final int numCols = columnPointers.length - 1; for (int j = 0; j < numCols; j++) { - final int k = Arrays.binarySearch(rowIndicies, columnPointers[j], + final int k = Arrays.binarySearch(rowIndices, columnPointers[j], columnPointers[j + 1], index); if (k >= 0) { row[j] = values[k]; 
@@ -124,12 +124,17 @@ public final class CSCMatrix extends ColumnMajorMatrix { final int last = Math.min(dst.length, columnPointers.length - 1); for (int j = 0; j < last; j++) { - final int k = Arrays.binarySearch(rowIndicies, columnPointers[j], + final int k = Arrays.binarySearch(rowIndices, columnPointers[j], columnPointers[j + 1], index); if (k >= 0) { dst[j] = values[k]; + } else { + dst[j] = 0.d; } } + for (int j = last; j < dst.length; j++) { + dst[j] = 0.d; + } return dst; } @@ -140,7 +145,7 @@ public final class CSCMatrix extends ColumnMajorMatrix { row.clear(); for (int j = 0, last = columnPointers.length - 1; j < last; j++) { - final int k = Arrays.binarySearch(rowIndicies, columnPointers[j], + final int k = Arrays.binarySearch(rowIndices, columnPointers[j], columnPointers[j + 1], index); if (k >= 0) { double v = values[k]; @@ -190,7 +195,7 @@ public final class CSCMatrix extends ColumnMajorMatrix { private int getIndex(@Nonnegative final int row, @Nonnegative final int col) { int leftIn = columnPointers[col]; int rightEx = columnPointers[col + 1]; - final int index = Arrays.binarySearch(rowIndicies, leftIn, rightEx, row); + final int index = Arrays.binarySearch(rowIndices, leftIn, rightEx, row); if (index >= 0 && index >= values.length) { throw new IndexOutOfBoundsException("Value index " + index + " out of range " + values.length); @@ -213,7 +218,7 @@ public final class CSCMatrix extends ColumnMajorMatrix { if (nullOutput) { for (int row = 0, i = startIn; row < numRows; row++) { - if (i < endEx && row == rowIndicies[i]) { + if (i < endEx && row == rowIndices[i]) { double v = values[i++]; procedure.apply(row, v); } else { @@ -222,7 +227,7 @@ public final class CSCMatrix extends ColumnMajorMatrix { } } else { for (int j = startIn; j < endEx; j++) { - int row = rowIndicies[j]; + int row = rowIndices[j]; double v = values[j]; procedure.apply(row, v); } @@ -236,7 +241,7 @@ public final class CSCMatrix extends ColumnMajorMatrix { final int startIn = 
columnPointers[col]; final int endEx = columnPointers[col + 1]; for (int j = startIn; j < endEx; j++) { - int row = rowIndicies[j]; + int row = rowIndices[j]; final double v = values[j]; if (v != 0.d) { procedure.apply(row, v); @@ -247,12 +252,12 @@ public final class CSCMatrix extends ColumnMajorMatrix { @Override public CSRMatrix toRowMajorMatrix() { final int[] rowPointers = new int[numRows + 1]; - final int[] colIndicies = new int[nnz]; + final int[] colIndices = new int[nnz]; final double[] csrValues = new double[nnz]; // compute nnz per for each row - for (int i = 0; i < rowIndicies.length; i++) { - rowPointers[rowIndicies[i]]++; + for (int i = 0; i < rowIndices.length; i++) { + rowPointers[rowIndices[i]]++; } for (int i = 0, sum = 0; i < numRows; i++) { int curr = rowPointers[i]; @@ -263,10 +268,10 @@ public final class CSCMatrix extends ColumnMajorMatrix { for (int j = 0; j < numColumns; j++) { for (int i = columnPointers[j], last = columnPointers[j + 1]; i < last; i++) { - int col = rowIndicies[i]; + int col = rowIndices[i]; int dst = rowPointers[col]; - colIndicies[dst] = j; + colIndices[dst] = j; csrValues[dst] = values[i]; rowPointers[col]++; @@ -280,7 +285,7 @@ public final class CSCMatrix extends ColumnMajorMatrix { last = tmp; } - return new CSRMatrix(rowPointers, colIndicies, csrValues, numColumns); + return new CSRMatrix(rowPointers, colIndices, csrValues, numColumns); } @Override http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/fdf70214/core/src/main/java/hivemall/math/matrix/sparse/CSRMatrix.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/math/matrix/sparse/CSRMatrix.java b/core/src/main/java/hivemall/math/matrix/sparse/CSRMatrix.java index 805bbd1..c1fa6e4 100644 --- a/core/src/main/java/hivemall/math/matrix/sparse/CSRMatrix.java +++ b/core/src/main/java/hivemall/math/matrix/sparse/CSRMatrix.java @@ -238,7 +238,7 @@ public final class CSRMatrix extends 
RowMajorMatrix { @Nonnull public CSCMatrix toColumnMajorMatrix() { final int[] columnPointers = new int[numColumns + 1]; - final int[] rowIndicies = new int[nnz]; + final int[] rowIndices = new int[nnz]; final double[] cscValues = new double[nnz]; // compute nnz per for each column @@ -257,7 +257,7 @@ public final class CSRMatrix extends RowMajorMatrix { int col = columnIndices[j]; int dst = columnPointers[col]; - rowIndicies[dst] = i; + rowIndices[dst] = i; cscValues[dst] = values[j]; columnPointers[col]++; @@ -271,7 +271,7 @@ public final class CSRMatrix extends RowMajorMatrix { last = tmp; } - return new CSCMatrix(columnPointers, rowIndicies, cscValues, numRows, numColumns); + return new CSCMatrix(columnPointers, rowIndices, cscValues, numRows, numColumns); } @Override http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/fdf70214/core/src/main/java/hivemall/math/matrix/sparse/DoKFloatMatrix.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/math/matrix/sparse/DoKFloatMatrix.java b/core/src/main/java/hivemall/math/matrix/sparse/DoKFloatMatrix.java deleted file mode 100644 index 16b4b64..0000000 --- a/core/src/main/java/hivemall/math/matrix/sparse/DoKFloatMatrix.java +++ /dev/null @@ -1,368 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -package hivemall.math.matrix.sparse; - -import hivemall.annotations.Experimental; -import hivemall.math.matrix.AbstractMatrix; -import hivemall.math.matrix.ColumnMajorMatrix; -import hivemall.math.matrix.RowMajorMatrix; -import hivemall.math.matrix.builders.DoKMatrixBuilder; -import hivemall.math.vector.Vector; -import hivemall.math.vector.VectorProcedure; -import hivemall.utils.collections.maps.Long2FloatOpenHashTable; -import hivemall.utils.collections.maps.Long2FloatOpenHashTable.IMapIterator; -import hivemall.utils.lang.Preconditions; -import hivemall.utils.lang.Primitives; - -import javax.annotation.Nonnegative; -import javax.annotation.Nonnull; - -/** - * Dictionary Of Keys based sparse matrix. - * - * This is an efficient structure for constructing a sparse matrix incrementally. 
- */ -@Experimental -public final class DoKFloatMatrix extends AbstractMatrix { - - @Nonnull - private final Long2FloatOpenHashTable elements; - @Nonnegative - private int numRows; - @Nonnegative - private int numColumns; - @Nonnegative - private int nnz; - - public DoKFloatMatrix() { - this(0, 0); - } - - public DoKFloatMatrix(@Nonnegative int numRows, @Nonnegative int numCols) { - this(numRows, numCols, 0.05f); - } - - public DoKFloatMatrix(@Nonnegative int numRows, @Nonnegative int numCols, - @Nonnegative float sparsity) { - super(); - Preconditions.checkArgument(sparsity >= 0.f && sparsity <= 1.f, "Invalid Sparsity value: " - + sparsity); - int initialCapacity = Math.max(16384, Math.round(numRows * numCols * sparsity)); - this.elements = new Long2FloatOpenHashTable(initialCapacity); - elements.defaultReturnValue(0.f); - this.numRows = numRows; - this.numColumns = numCols; - this.nnz = 0; - } - - public DoKFloatMatrix(@Nonnegative int initSize) { - super(); - int initialCapacity = Math.max(initSize, 16384); - this.elements = new Long2FloatOpenHashTable(initialCapacity); - elements.defaultReturnValue(0.f); - this.numRows = 0; - this.numColumns = 0; - this.nnz = 0; - } - - @Override - public boolean isSparse() { - return true; - } - - @Override - public boolean isRowMajorMatrix() { - return false; - } - - @Override - public boolean isColumnMajorMatrix() { - return false; - } - - @Override - public boolean readOnly() { - return false; - } - - @Override - public boolean swappable() { - return true; - } - - @Override - public int nnz() { - return nnz; - } - - @Override - public int numRows() { - return numRows; - } - - @Override - public int numColumns() { - return numColumns; - } - - @Override - public int numColumns(@Nonnegative final int row) { - int count = 0; - for (int j = 0; j < numColumns; j++) { - long index = index(row, j); - if (elements.containsKey(index)) { - count++; - } - } - return count; - } - - @Override - public double[] getRow(@Nonnegative final 
int index) { - double[] dst = row(); - return getRow(index, dst); - } - - @Override - public double[] getRow(@Nonnegative final int row, @Nonnull final double[] dst) { - checkRowIndex(row, numRows); - - final int end = Math.min(dst.length, numColumns); - for (int col = 0; col < end; col++) { - long k = index(row, col); - float v = elements.get(k); - dst[col] = v; - } - - return dst; - } - - @Override - public void getRow(@Nonnegative final int index, @Nonnull final Vector row) { - checkRowIndex(index, numRows); - row.clear(); - - for (int col = 0; col < numColumns; col++) { - long k = index(index, col); - final float v = elements.get(k, 0.f); - if (v != 0.f) { - row.set(col, v); - } - } - } - - @Override - public double get(@Nonnegative final int row, @Nonnegative final int col, - final double defaultValue) { - return get(row, col, (float) defaultValue); - } - - public float get(@Nonnegative final int row, @Nonnegative final int col, - final float defaultValue) { - long index = index(row, col); - return elements.get(index, defaultValue); - } - - @Override - public void set(@Nonnegative final int row, @Nonnegative final int col, final double value) { - set(row, col, (float) value); - } - - public void set(@Nonnegative final int row, @Nonnegative final int col, final float value) { - checkIndex(row, col); - - final long index = index(row, col); - if (value == 0.f && elements.containsKey(index) == false) { - return; - } - - if (elements.put(index, value, 0.f) == 0.f) { - nnz++; - this.numRows = Math.max(numRows, row + 1); - this.numColumns = Math.max(numColumns, col + 1); - } - } - - @Override - public double getAndSet(@Nonnegative final int row, @Nonnegative final int col, - final double value) { - return getAndSet(row, col, (float) value); - } - - public float getAndSet(@Nonnegative final int row, @Nonnegative final int col, final float value) { - checkIndex(row, col); - - final long index = index(row, col); - if (value == 0.f && elements.containsKey(index) == 
false) { - return 0.f; - } - - final float old = elements.put(index, value, 0.f); - if (old == 0.f) { - nnz++; - this.numRows = Math.max(numRows, row + 1); - this.numColumns = Math.max(numColumns, col + 1); - } - return old; - } - - @Override - public void swap(@Nonnegative final int row1, @Nonnegative final int row2) { - checkRowIndex(row1, numRows); - checkRowIndex(row2, numRows); - - for (int j = 0; j < numColumns; j++) { - final long i1 = index(row1, j); - final long i2 = index(row2, j); - - final int k1 = elements._findKey(i1); - final int k2 = elements._findKey(i2); - - if (k1 >= 0) { - if (k2 >= 0) { - float v1 = elements._get(k1); - float v2 = elements._set(k2, v1); - elements._set(k1, v2); - } else {// k1>=0 and k2<0 - float v1 = elements._remove(k1); - elements.put(i2, v1); - } - } else if (k2 >= 0) {// k2>=0 and k1 < 0 - float v2 = elements._remove(k2); - elements.put(i1, v2); - } else {//k1<0 and k2<0 - continue; - } - } - } - - @Override - public void eachInRow(@Nonnegative final int row, @Nonnull final VectorProcedure procedure, - final boolean nullOutput) { - checkRowIndex(row, numRows); - - for (int col = 0; col < numColumns; col++) { - long i = index(row, col); - final int key = elements._findKey(i); - if (key < 0) { - if (nullOutput) { - procedure.apply(col, 0.d); - } - } else { - float v = elements._get(key); - procedure.apply(col, v); - } - } - } - - @Override - public void eachNonZeroInRow(@Nonnegative final int row, - @Nonnull final VectorProcedure procedure) { - checkRowIndex(row, numRows); - - for (int col = 0; col < numColumns; col++) { - long i = index(row, col); - final float v = elements.get(i, 0.f); - if (v != 0.f) { - procedure.apply(col, v); - } - } - } - - @Override - public void eachColumnIndexInRow(int row, VectorProcedure procedure) { - checkRowIndex(row, numRows); - - for (int col = 0; col < numColumns; col++) { - long i = index(row, col); - final int key = elements._findKey(i); - if (key != -1) { - procedure.apply(col); - } - } 
- } - - @Override - public void eachInColumn(@Nonnegative final int col, @Nonnull final VectorProcedure procedure, - final boolean nullOutput) { - checkColIndex(col, numColumns); - - for (int row = 0; row < numRows; row++) { - long i = index(row, col); - final int key = elements._findKey(i); - if (key < 0) { - if (nullOutput) { - procedure.apply(row, 0.d); - } - } else { - float v = elements._get(key); - procedure.apply(row, v); - } - } - } - - @Override - public void eachNonZeroInColumn(@Nonnegative final int col, - @Nonnull final VectorProcedure procedure) { - checkColIndex(col, numColumns); - - for (int row = 0; row < numRows; row++) { - long i = index(row, col); - final float v = elements.get(i, 0.f); - if (v != 0.f) { - procedure.apply(row, v); - } - } - } - - public void eachNonZeroCell(@Nonnull final VectorProcedure procedure) { - if (nnz == 0) { - return; - } - final IMapIterator itor = elements.entries(); - while (itor.next() != -1) { - long k = itor.getKey(); - int row = Primitives.getHigh(k); - int col = Primitives.getLow(k); - float value = itor.getValue(); - procedure.apply(row, col, value); - } - } - - @Override - public RowMajorMatrix toRowMajorMatrix() { - throw new UnsupportedOperationException("Not yet supported"); - } - - @Override - public ColumnMajorMatrix toColumnMajorMatrix() { - throw new UnsupportedOperationException("Not yet supported"); - } - - @Override - public DoKMatrixBuilder builder() { - return new DoKMatrixBuilder(elements.size()); - } - - @Nonnegative - private static long index(@Nonnegative final int row, @Nonnegative final int col) { - return Primitives.toLong(row, col); - } - -} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/fdf70214/core/src/main/java/hivemall/math/matrix/sparse/DoKMatrix.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/math/matrix/sparse/DoKMatrix.java b/core/src/main/java/hivemall/math/matrix/sparse/DoKMatrix.java index 
054d62a..6dc0502 100644 --- a/core/src/main/java/hivemall/math/matrix/sparse/DoKMatrix.java +++ b/core/src/main/java/hivemall/math/matrix/sparse/DoKMatrix.java @@ -21,6 +21,7 @@ package hivemall.math.matrix.sparse; import hivemall.annotations.Experimental; import hivemall.math.matrix.AbstractMatrix; import hivemall.math.matrix.ColumnMajorMatrix; +import hivemall.math.matrix.MatrixUtils; import hivemall.math.matrix.RowMajorMatrix; import hivemall.math.matrix.builders.DoKMatrixBuilder; import hivemall.math.vector.Vector; @@ -333,12 +334,44 @@ public final class DoKMatrix extends AbstractMatrix { @Override public RowMajorMatrix toRowMajorMatrix() { - throw new UnsupportedOperationException("Not yet supported"); + final int nnz = elements.size(); + final int[] rows = new int[nnz]; + final int[] cols = new int[nnz]; + final double[] data = new double[nnz]; + + final IMapIterator itor = elements.entries(); + for (int i = 0; i < nnz; i++) { + if (itor.next() == -1) { + throw new IllegalStateException("itor.next() returns -1 where i=" + i); + } + long k = itor.getKey(); + rows[i] = Primitives.getHigh(k); + cols[i] = Primitives.getLow(k); + data[i] = itor.getValue(); + } + + return MatrixUtils.coo2csr(rows, cols, data, numRows, numColumns, true); } @Override public ColumnMajorMatrix toColumnMajorMatrix() { - throw new UnsupportedOperationException("Not yet supported"); + final int nnz = elements.size(); + final int[] rows = new int[nnz]; + final int[] cols = new int[nnz]; + final double[] data = new double[nnz]; + + final IMapIterator itor = elements.entries(); + for (int i = 0; i < nnz; i++) { + if (itor.next() == -1) { + throw new IllegalStateException("itor.next() returns -1 where i=" + i); + } + long k = itor.getKey(); + rows[i] = Primitives.getHigh(k); + cols[i] = Primitives.getLow(k); + data[i] = itor.getValue(); + } + + return MatrixUtils.coo2csc(rows, cols, data, numRows, numColumns, true); } @Override 
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/fdf70214/core/src/main/java/hivemall/math/matrix/sparse/floats/CSCFloatMatrix.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/math/matrix/sparse/floats/CSCFloatMatrix.java b/core/src/main/java/hivemall/math/matrix/sparse/floats/CSCFloatMatrix.java new file mode 100644 index 0000000..3aa1dc9 --- /dev/null +++ b/core/src/main/java/hivemall/math/matrix/sparse/floats/CSCFloatMatrix.java @@ -0,0 +1,317 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package hivemall.math.matrix.sparse.floats; + +import hivemall.math.matrix.ColumnMajorFloatMatrix; +import hivemall.math.matrix.builders.CSCMatrixBuilder; +import hivemall.math.vector.Vector; +import hivemall.math.vector.VectorProcedure; +import hivemall.utils.lang.ArrayUtils; +import hivemall.utils.lang.Preconditions; + +import java.util.Arrays; + +import javax.annotation.Nonnegative; +import javax.annotation.Nonnull; + +/** + * Compressed Sparse Column matrix optimized for column major access. 
+ * + * @link http://netlib.org/linalg/html_templates/node92.html#SECTION00931200000000000000 + */ +public final class CSCFloatMatrix extends ColumnMajorFloatMatrix { + + @Nonnull + private final int[] columnPointers; + @Nonnull + private final int[] rowIndices; + @Nonnull + private final float[] values; + + private final int numRows; + private final int numColumns; + private final int nnz; + + public CSCFloatMatrix(@Nonnull int[] columnPointers, @Nonnull int[] rowIndices, + @Nonnull float[] values, int numRows, int numColumns) { + super(); + Preconditions.checkArgument(columnPointers.length >= 1, + "rowPointers must be greather than 0: " + columnPointers.length); + Preconditions.checkArgument(rowIndices.length == values.length, "#rowIndices (" + + rowIndices.length + ") must be equals to #values (" + values.length + ")"); + this.columnPointers = columnPointers; + this.rowIndices = rowIndices; + this.values = values; + this.numRows = numRows; + this.numColumns = numColumns; + this.nnz = values.length; + } + + @Override + public boolean isSparse() { + return true; + } + + @Override + public boolean readOnly() { + return true; + } + + @Override + public boolean swappable() { + return false; + } + + @Override + public int nnz() { + return nnz; + } + + @Override + public int numRows() { + return numRows; + } + + @Override + public int numColumns() { + return numColumns; + } + + @Override + public int numColumns(final int row) { + checkRowIndex(row, numRows); + + return ArrayUtils.count(rowIndices, row); + } + + @Override + public double[] getRow(final int index) { + checkRowIndex(index, numRows); + + final double[] row = new double[numColumns]; + + final int numCols = columnPointers.length - 1; + for (int j = 0; j < numCols; j++) { + final int k = Arrays.binarySearch(rowIndices, columnPointers[j], + columnPointers[j + 1], index); + if (k >= 0) { + row[j] = values[k]; + } + } + + return row; + } + + @Override + public double[] getRow(final int index, @Nonnull final 
double[] dst) { + checkRowIndex(index, numRows); + + final int last = Math.min(dst.length, columnPointers.length - 1); + for (int j = 0; j < last; j++) { + final int k = Arrays.binarySearch(rowIndices, columnPointers[j], + columnPointers[j + 1], index); + if (k >= 0) { + dst[j] = values[k]; + } else { + dst[j] = 0.d; + } + } + for (int j = last; j < dst.length; j++) { + dst[j] = 0.d; + } + + return dst; + } + + @Override + public float[] getRow(final int index, @Nonnull final float[] dst) { + checkRowIndex(index, numRows); + + final int last = Math.min(dst.length, columnPointers.length - 1); + for (int j = 0; j < last; j++) { + final int k = Arrays.binarySearch(rowIndices, columnPointers[j], + columnPointers[j + 1], index); + if (k >= 0) { + dst[j] = values[k]; + } else { + dst[j] = 0.f; + } + } + for (int j = last; j < dst.length; j++) { + dst[j] = 0.f; + } + + return dst; + } + + @Override + public void getRow(final int index, @Nonnull final Vector row) { + checkRowIndex(index, numRows); + row.clear(); + + for (int j = 0, last = columnPointers.length - 1; j < last; j++) { + final int k = Arrays.binarySearch(rowIndices, columnPointers[j], + columnPointers[j + 1], index); + if (k >= 0) { + float v = values[k]; + row.set(j, v); + } + } + } + + @Override + public float get(final int row, final int col, final float defaultValue) { + checkIndex(row, col, numRows, numColumns); + + int index = getIndex(row, col); + if (index < 0) { + return defaultValue; + } + return values[index]; + } + + @Override + public float getAndSet(final int row, final int col, final float value) { + checkIndex(row, col, numRows, numColumns); + + final int index = getIndex(row, col); + if (index < 0) { + throw new UnsupportedOperationException("Cannot update value in row " + row + ", col " + + col); + } + + float old = values[index]; + values[index] = value; + return old; + } + + @Override + public void set(final int row, final int col, final float value) { + checkIndex(row, col, numRows, 
numColumns); + + final int index = getIndex(row, col); + if (index < 0) { + throw new UnsupportedOperationException("Cannot update value in row " + row + ", col " + + col); + } + values[index] = value; + } + + private int getIndex(@Nonnegative final int row, @Nonnegative final int col) { + int leftIn = columnPointers[col]; + int rightEx = columnPointers[col + 1]; + final int index = Arrays.binarySearch(rowIndices, leftIn, rightEx, row); + if (index >= 0 && index >= values.length) { + throw new IndexOutOfBoundsException("Value index " + index + " out of range " + + values.length); + } + return index; + } + + @Override + public void swap(final int row1, final int row2) { + throw new UnsupportedOperationException(); + } + + @Override + public void eachInColumn(final int col, @Nonnull final VectorProcedure procedure, + final boolean nullOutput) { + checkColIndex(col, numColumns); + + final int startIn = columnPointers[col]; + final int endEx = columnPointers[col + 1]; + + if (nullOutput) { + for (int row = 0, i = startIn; row < numRows; row++) { + if (i < endEx && row == rowIndices[i]) { + float v = values[i++]; + procedure.apply(row, v); + } else { + procedure.apply(row, 0.f); + } + } + } else { + for (int j = startIn; j < endEx; j++) { + int row = rowIndices[j]; + float v = values[j]; + procedure.apply(row, v); + } + } + } + + @Override + public void eachNonZeroInColumn(final int col, @Nonnull final VectorProcedure procedure) { + checkColIndex(col, numColumns); + + final int startIn = columnPointers[col]; + final int endEx = columnPointers[col + 1]; + for (int j = startIn; j < endEx; j++) { + int row = rowIndices[j]; + final float v = values[j]; + if (v != 0.f) { + procedure.apply(row, v); + } + } + } + + @Override + public CSRFloatMatrix toRowMajorMatrix() { + final int[] rowPointers = new int[numRows + 1]; + final int[] colIndices = new int[nnz]; + final float[] csrValues = new float[nnz]; + + // compute nnz per for each row + for (int i = 0; i < rowIndices.length; 
i++) { + rowPointers[rowIndices[i]]++; + } + for (int i = 0, sum = 0; i < numRows; i++) { + int curr = rowPointers[i]; + rowPointers[i] = sum; + sum += curr; + } + rowPointers[numRows] = nnz; + + for (int j = 0; j < numColumns; j++) { + for (int i = columnPointers[j], last = columnPointers[j + 1]; i < last; i++) { + int col = rowIndices[i]; + int dst = rowPointers[col]; + + colIndices[dst] = j; + csrValues[dst] = values[i]; + + rowPointers[col]++; + } + } + + // shift column pointers + for (int i = 0, last = 0; i <= numRows; i++) { + int tmp = rowPointers[i]; + rowPointers[i] = last; + last = tmp; + } + + return new CSRFloatMatrix(rowPointers, colIndices, csrValues, numColumns); + } + + @Override + public CSCMatrixBuilder builder() { + return new CSCMatrixBuilder(nnz); + } + +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/fdf70214/core/src/main/java/hivemall/math/matrix/sparse/floats/CSRFloatMatrix.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/math/matrix/sparse/floats/CSRFloatMatrix.java b/core/src/main/java/hivemall/math/matrix/sparse/floats/CSRFloatMatrix.java new file mode 100644 index 0000000..3dd44de --- /dev/null +++ b/core/src/main/java/hivemall/math/matrix/sparse/floats/CSRFloatMatrix.java @@ -0,0 +1,293 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package hivemall.math.matrix.sparse.floats; + +import hivemall.math.matrix.RowMajorFloatMatrix; +import hivemall.math.matrix.builders.CSRMatrixBuilder; +import hivemall.math.vector.VectorProcedure; +import hivemall.utils.lang.Preconditions; + +import java.util.Arrays; + +import javax.annotation.Nonnegative; +import javax.annotation.Nonnull; + +/** + * Compressed Sparse Row Matrix optimized for row major access. + * + * @link http://netlib.org/linalg/html_templates/node91.html#SECTION00931100000000000000 + * @link http://www.cs.colostate.edu/~mcrob/toolbox/c++/sparseMatrix/sparse_matrix_compression.html + */ +public final class CSRFloatMatrix extends RowMajorFloatMatrix { + + @Nonnull + private final int[] rowPointers; + @Nonnull + private final int[] columnIndices; + @Nonnull + private final float[] values; + + @Nonnegative + private final int numRows; + @Nonnegative + private final int numColumns; + @Nonnegative + private final int nnz; + + public CSRFloatMatrix(@Nonnull int[] rowPointers, @Nonnull int[] columnIndices, + @Nonnull float[] values, @Nonnegative int numColumns) { + super(); + Preconditions.checkArgument(rowPointers.length >= 1, + "rowPointers must be greather than 0: " + rowPointers.length); + Preconditions.checkArgument(columnIndices.length == values.length, "#columnIndices (" + + columnIndices.length + ") must be equals to #values (" + values.length + ")"); + this.rowPointers = rowPointers; + this.columnIndices = columnIndices; + this.values = values; + this.numRows = rowPointers.length - 1; + this.numColumns = numColumns; + 
this.nnz = values.length; + } + + @Override + public boolean isSparse() { + return true; + } + + /* NOTE(review): reports read-only, although set() and getAndSet() below overwrite the values of already-stored entries in place. */ @Override + public boolean readOnly() { + return true; + } + + @Override + public boolean swappable() { + return false; + } + + @Override + public int nnz() { + return nnz; + } + + @Override + public int numRows() { + return numRows; + } + + @Override + public int numColumns() { + return numColumns; + } + + /** Returns the number of entries stored for {@code row} (explicitly stored zeros included). */ @Override + public int numColumns(@Nonnegative final int row) { + checkRowIndex(row, numRows); + + int columns = rowPointers[row + 1] - rowPointers[row]; + return columns; + } + + /** Returns a newly allocated dense copy (length numColumns) of row {@code index}; cells without a stored non-zero value are 0. */ @Override + public double[] getRow(@Nonnegative final int index) { + final double[] row = new double[numColumns]; + eachNonZeroInRow(index, new VectorProcedure() { + public void apply(int col, float value) { + row[col] = value; + } + }); + return row; + } + + /** Zero-fills {@code dst}, writes the non-zero entries of row {@code index} into it, and returns {@code dst}. */ @Override + public double[] getRow(@Nonnegative final int index, @Nonnull final double[] dst) { + Arrays.fill(dst, 0.d); + eachNonZeroInRow(index, new VectorProcedure() { + public void apply(int col, float value) { + checkColIndex(col, numColumns); + dst[col] = value; + } + }); + return dst; + } + + /** Float counterpart of getRow(int, double[]): zero-fills {@code dst}, writes the non-zero entries of the row into it, and returns {@code dst}. */ @Override + public float[] getRow(@Nonnegative final int index, @Nonnull final float[] dst) { + Arrays.fill(dst, 0.f); + eachNonZeroInRow(index, new VectorProcedure() { + public void apply(int col, float value) { + checkColIndex(col, numColumns); + dst[col] = value; + } + }); + return dst; + } + + /** Returns the value stored at (row, col), or {@code defaultValue} when no entry is stored for that cell. */ @Override + public float get(@Nonnegative final int row, @Nonnegative final int col, + final float defaultValue) { + checkIndex(row, col, numRows, numColumns); + + final int index = getIndex(row, col); + if (index < 0) { + return defaultValue; + } + return values[index]; + } + + /** Overwrites an already-stored entry and returns its previous value; the sparsity pattern is fixed, so a cell without a stored entry cannot be written. */ @Override + public float getAndSet(@Nonnegative final int row, @Nonnegative final int col, final float value) { + checkIndex(row, col, numRows, numColumns); + + final int index = getIndex(row, col); + if (index < 0) { + throw new UnsupportedOperationException("Cannot 
update value in row " + row + ", col " + + col); + } + + float old = values[index]; + values[index] = value; + return old; + } + + /** Overwrites the value of an already-stored entry; a cell without a stored entry cannot be written because the sparsity pattern is fixed. */ @Override + public void set(@Nonnegative final int row, @Nonnegative final int col, final float value) { + checkIndex(row, col, numRows, numColumns); + + final int index = getIndex(row, col); + if (index < 0) { + throw new UnsupportedOperationException("Cannot update value in row " + row + ", col " + + col); + } + values[index] = value; + } + + /** Returns the position of (row, col) in columnIndices/values, or a negative value when the cell is not stored. Uses Arrays.binarySearch, so the column indices of a row are assumed to be sorted ascending. */ private int getIndex(@Nonnegative final int row, @Nonnegative final int col) { + int leftIn = rowPointers[row]; + int rightEx = rowPointers[row + 1]; + final int index = Arrays.binarySearch(columnIndices, leftIn, rightEx, col); + if (index >= 0 && index >= values.length) { + throw new IndexOutOfBoundsException("Value index " + index + " out of range " + + values.length); + } + return index; + } + + /* Row swapping is not supported by this CSR implementation (swappable() returns false). */ @Override + public void swap(int row1, int row2) { + throw new UnsupportedOperationException(); + } + + /** Visits the entries of {@code row}. With nullOutput, every column 0..numColumns-1 is visited and cells without a stored entry are reported as 0 (this merge assumes the row's column indices are sorted ascending); otherwise only stored entries are visited, in storage order. */ @Override + public void eachInRow(@Nonnegative final int row, @Nonnull final VectorProcedure procedure, + final boolean nullOutput) { + checkRowIndex(row, numRows); + + final int startIn = rowPointers[row]; + final int endEx = rowPointers[row + 1]; + + if (nullOutput) { + for (int col = 0, j = startIn; col < numColumns; col++) { + if (j < endEx && col == columnIndices[j]) { + float v = values[j++]; + procedure.apply(col, v); + } else { + procedure.apply(col, 0.f); + } + } + } else { + for (int i = startIn; i < endEx; i++) { + procedure.apply(columnIndices[i], values[i]); + } + } + } + + /** Visits only the stored entries of the row whose value is non-zero. */ @Override + public void eachNonZeroInRow(@Nonnegative final int row, + @Nonnull final VectorProcedure procedure) { + checkRowIndex(row, numRows); + + final int startIn = rowPointers[row]; + final int endEx = rowPointers[row + 1]; + for (int i = startIn; i < endEx; i++) { + int col = columnIndices[i]; + final float v = values[i]; + if (v != 0.f) { + procedure.apply(col, v); + } + } + } + + /** Visits the column index of every stored entry of the row, stored zeros included. */ @Override + public void 
eachColumnIndexInRow(@Nonnegative final int row, + @Nonnull final VectorProcedure procedure) { + checkRowIndex(row, numRows); + + final int startIn = rowPointers[row]; + final int endEx = rowPointers[row + 1]; + + for (int i = startIn; i < endEx; i++) { + procedure.apply(columnIndices[i]); + } + } + + /** Converts this matrix to CSC form with a counting sort over column indices: count the entries per column, prefix-sum the counts into start offsets, scatter each entry into its column slot, then shift the pointers back by one column. Rows are scanned in ascending order, so the row indices within each column come out sorted. */ @Nonnull + public CSCFloatMatrix toColumnMajorMatrix() { + final int[] columnPointers = new int[numColumns + 1]; + final int[] rowIndices = new int[nnz]; + final float[] cscValues = new float[nnz]; + + // count the stored entries of each column + for (int j = 0; j < columnIndices.length; j++) { + columnPointers[columnIndices[j]]++; + } + for (int j = 0, sum = 0; j < numColumns; j++) { + int curr = columnPointers[j]; + columnPointers[j] = sum; + sum += curr; + } + columnPointers[numColumns] = nnz; + + for (int i = 0; i < numRows; i++) { + for (int j = rowPointers[i], last = rowPointers[i + 1]; j < last; j++) { + int col = columnIndices[j]; + int dst = columnPointers[col]; + + rowIndices[dst] = i; + cscValues[dst] = values[j]; + + columnPointers[col]++; + } + } + + // shift column pointers: after the scatter, columnPointers[j] holds the end of column j, i.e. the start of the next column + for (int j = 0, last = 0; j <= numColumns; j++) { + int tmp = columnPointers[j]; + columnPointers[j] = last; + last = tmp; + } + + return new CSCFloatMatrix(columnPointers, rowIndices, cscValues, numRows, numColumns); + } + + /** Returns a CSRMatrixBuilder presized to the number of stored values. */ @Override + public CSRMatrixBuilder builder() { + return new CSRMatrixBuilder(values.length); + } + +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/fdf70214/core/src/main/java/hivemall/math/matrix/sparse/floats/DoKFloatMatrix.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/math/matrix/sparse/floats/DoKFloatMatrix.java b/core/src/main/java/hivemall/math/matrix/sparse/floats/DoKFloatMatrix.java new file mode 100644 index 0000000..10929fb --- /dev/null +++ b/core/src/main/java/hivemall/math/matrix/sparse/floats/DoKFloatMatrix.java @@ -0,0 +1,401 @@ +/* + * Licensed to the Apache 
Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package hivemall.math.matrix.sparse.floats; + +import hivemall.annotations.Experimental; +import hivemall.math.matrix.AbstractMatrix; +import hivemall.math.matrix.FloatMatrix; +import hivemall.math.matrix.MatrixUtils; +import hivemall.math.matrix.builders.DoKMatrixBuilder; +import hivemall.math.vector.Vector; +import hivemall.math.vector.VectorProcedure; +import hivemall.utils.collections.maps.Long2FloatOpenHashTable; +import hivemall.utils.collections.maps.Long2FloatOpenHashTable.IMapIterator; +import hivemall.utils.lang.Preconditions; +import hivemall.utils.lang.Primitives; + +import javax.annotation.Nonnegative; +import javax.annotation.Nonnull; + +/** + * Dictionary Of Keys based sparse matrix. + * + * This is an efficient structure for constructing a sparse matrix incrementally. Cells are keyed by Primitives.toLong(row, col) in a Long2FloatOpenHashTable, and nnz() counts the stored cells that hold a non-zero value. 
+ */ +@Experimental +public final class DoKFloatMatrix extends AbstractMatrix implements FloatMatrix { + + @Nonnull + private final Long2FloatOpenHashTable elements; + @Nonnegative + private int numRows; + @Nonnegative + private int numColumns; + @Nonnegative + private int nnz; /* number of stored cells holding a non-zero value */ + + public DoKFloatMatrix() { + this(0, 0); + } + + /** Creates an empty matrix; 0.05 is used as the sparsity estimate when presizing storage. */ public DoKFloatMatrix(@Nonnegative int numRows, @Nonnegative int numCols) { + this(numRows, numCols, 0.05f); + } + + /** Creates an empty numRows-by-numCols matrix whose hash table is presized for about (numRows * numCols * sparsity) entries, but at least 16384; a sparsity outside [0, 1] is rejected via Preconditions.checkArgument. */ public DoKFloatMatrix(@Nonnegative int numRows, @Nonnegative int numCols, + @Nonnegative float sparsity) { + super(); + Preconditions.checkArgument(sparsity >= 0.f && sparsity <= 1.f, "Invalid Sparsity value: " + + sparsity); + int initialCapacity = Math.max(16384, (int) Math.min(Integer.MAX_VALUE, Math.round((double) numRows * numCols * sparsity))); /* widen to double first: the product of the two int dimensions overflows int for large shapes */ + this.elements = new Long2FloatOpenHashTable(initialCapacity); + elements.defaultReturnValue(0.f); + this.numRows = numRows; + this.numColumns = numCols; + this.nnz = 0; + } + + /** Creates an empty 0-by-0 matrix whose hash table is presized for max(initSize, 16384) entries. */ public DoKFloatMatrix(@Nonnegative int initSize) { + super(); + int initialCapacity = Math.max(initSize, 16384); + this.elements = new Long2FloatOpenHashTable(initialCapacity); + elements.defaultReturnValue(0.f); + this.numRows = 0; + this.numColumns = 0; + this.nnz = 0; + } + + @Override + public boolean isSparse() { + return true; + } + + @Override + public boolean isRowMajorMatrix() { + return false; + } + + @Override + public boolean isColumnMajorMatrix() { + return false; + } + + @Override + public boolean readOnly() { + return false; + } + + @Override + public boolean swappable() { + return true; + } + + @Override + public int nnz() { + return nnz; + } + + @Override + public int numRows() { + return numRows; + } + + @Override + public int numColumns() { + return numColumns; + } + + /** Counts the cells stored for {@code row} (stored zeros included) by probing every column. */ @Override + public int numColumns(@Nonnegative final int row) { + int count = 0; + for (int j = 0; j < numColumns; j++) { + long index = index(row, j); + if (elements.containsKey(index)) { + count++; + } + } + return count; + } + + /** Returns a dense copy of row {@code index}, written into the array returned by row(). */ @Override + public double[] 
getRow(@Nonnegative final int index) { + double[] dst = row(); + return getRow(index, dst); + } + + /** Writes the first min(dst.length, numColumns) cells of {@code row} into {@code dst} (0 for unstored cells) and returns {@code dst}. */ @Override + public double[] getRow(@Nonnegative final int row, @Nonnull final double[] dst) { + checkRowIndex(row, numRows); + + final int end = Math.min(dst.length, numColumns); + for (int col = 0; col < end; col++) { + long k = index(row, col); + float v = elements.get(k); + dst[col] = v; + } + + return dst; + } + + /** Float counterpart of getRow(int, double[]). */ @Override + public float[] getRow(@Nonnegative final int row, @Nonnull final float[] dst) { + checkRowIndex(row, numRows); + + final int end = Math.min(dst.length, numColumns); + for (int col = 0; col < end; col++) { + long k = index(row, col); + float v = elements.get(k); + dst[col] = v; + } + + return dst; + } + + /** Clears {@code row} and sets the non-zero cells of row {@code index} into it. */ @Override + public void getRow(@Nonnegative final int index, @Nonnull final Vector row) { + checkRowIndex(index, numRows); + row.clear(); + + for (int col = 0; col < numColumns; col++) { + long k = index(index, col); + final float v = elements.get(k, 0.f); + if (v != 0.f) { + row.set(col, v); + } + } + } + + /** Returns the value stored at (row, col), or {@code defaultValue} when the cell is not stored; no bounds check is performed. */ @Override + public float get(@Nonnegative final int row, @Nonnegative final int col, + final float defaultValue) { + long index = index(row, col); + return elements.get(index, defaultValue); + } + + /** Stores {@code value} at (row, col), growing numRows/numColumns to cover the cell. Writing 0 to an unstored cell is a no-op; writing 0 to a stored cell keeps its key but decrements nnz. */ @Override + public void set(@Nonnegative final int row, @Nonnegative final int col, final float value) { + checkIndex(row, col); + + final long index = index(row, col); + if (value == 0.f && elements.containsKey(index) == false) { + return; + } + + final float old = elements.put(index, value, 0.f); /* previous value; presumably the 0.f default when the key was new */ + if (old == 0.f && value != 0.f) { nnz++; } else if (old != 0.f && value == 0.f) { nnz--; } + this.numRows = Math.max(numRows, row + 1); + this.numColumns = Math.max(numColumns, col + 1); + /* nnz changes only when a cell switches between zero and non-zero, so rewriting a stored 0 with 0 no longer inflates it */ + } + + /** Like set(), but returns the previous value of the cell (0 when it was not stored). */ @Override + public float getAndSet(@Nonnegative final int row, @Nonnegative final int col, final float value) { + checkIndex(row, col); + + final long index = index(row, col); + if (value == 0.f && elements.containsKey(index) == false) { + return 0.f; + } + + final float old = elements.put(index, value, 0.f); + if (old == 0.f && value != 0.f) { nnz++; } else if (old != 0.f && value == 0.f) { nnz--; } + /* same nnz bookkeeping as set() */ 
+ this.numRows = Math.max(numRows, row + 1); + this.numColumns = Math.max(numColumns, col + 1); + /* no-op for already-stored cells, which always lie within the current dimensions */ + return old; + } + + /** Exchanges the cells of row1 and row2 column by column over the current columns; cells are only moved, so nnz is unchanged. */ @Override + public void swap(@Nonnegative final int row1, @Nonnegative final int row2) { + checkRowIndex(row1, numRows); + checkRowIndex(row2, numRows); + + for (int j = 0; j < numColumns; j++) { + final long i1 = index(row1, j); + final long i2 = index(row2, j); + + final int k1 = elements._findKey(i1); + final int k2 = elements._findKey(i2); + + if (k1 >= 0) { + if (k2 >= 0) { + float v1 = elements._get(k1); + float v2 = elements._set(k2, v1); + elements._set(k1, v2); + } else {// k1>=0 and k2<0 + float v1 = elements._remove(k1); + elements.put(i2, v1); + } + } else if (k2 >= 0) {// k2>=0 and k1 < 0 + float v2 = elements._remove(k2); + elements.put(i1, v2); + } else {//k1<0 and k2<0 + continue; + } + } + } + + /** Visits the cells of {@code row} in column order; unstored cells are reported as 0 only when nullOutput is true. */ @Override + public void eachInRow(@Nonnegative final int row, @Nonnull final VectorProcedure procedure, + final boolean nullOutput) { + checkRowIndex(row, numRows); + + for (int col = 0; col < numColumns; col++) { + long i = index(row, col); + final int key = elements._findKey(i); + if (key < 0) { + if (nullOutput) { + procedure.apply(col, 0.f); + } + } else { + float v = elements._get(key); + procedure.apply(col, v); + } + } + } + + /** Visits the cells of {@code row} holding a non-zero value, in column order. */ @Override + public void eachNonZeroInRow(@Nonnegative final int row, + @Nonnull final VectorProcedure procedure) { + checkRowIndex(row, numRows); + + for (int col = 0; col < numColumns; col++) { + long i = index(row, col); + final float v = elements.get(i, 0.f); + if (v != 0.f) { + procedure.apply(col, v); + } + } + } + + /** Visits the column index of every stored cell of {@code row} (stored zeros included). */ @Override + public void eachColumnIndexInRow(int row, VectorProcedure procedure) { + checkRowIndex(row, numRows); + + for (int col = 0; col < numColumns; col++) { + long i = index(row, col); + final int key = elements._findKey(i); + if (key >= 0) { /* negative means not stored, matching the other _findKey checks in this class */ + procedure.apply(col); + } + } + } + + /** Column counterpart of eachInRow: visits the cells of {@code col} in row order; unstored cells are reported as 0 only when nullOutput is true. */ @Override + public void eachInColumn(@Nonnegative final int col, @Nonnull final VectorProcedure 
procedure, + final boolean nullOutput) { + checkColIndex(col, numColumns); + + for (int row = 0; row < numRows; row++) { + long i = index(row, col); + final int key = elements._findKey(i); + if (key < 0) { + if (nullOutput) { + procedure.apply(row, 0.f); + } + } else { + float v = elements._get(key); + procedure.apply(row, v); + } + } + } + + /** Visits the cells of {@code col} holding a non-zero value, in row order. */ @Override + public void eachNonZeroInColumn(@Nonnegative final int col, + @Nonnull final VectorProcedure procedure) { + checkColIndex(col, numColumns); + + for (int row = 0; row < numRows; row++) { + long i = index(row, col); + final float v = elements.get(i, 0.f); + if (v != 0.f) { + procedure.apply(row, v); + } + } + } + + /** Visits every stored cell holding a non-zero value, in hash-table iteration order. */ @Override + public void eachNonZeroCell(@Nonnull final VectorProcedure procedure) { + if (nnz == 0) { + return; + } + final IMapIterator itor = elements.entries(); + while (itor.next() != -1) { + long k = itor.getKey(); + int row = Primitives.getHigh(k); + int col = Primitives.getLow(k); + float value = itor.getValue(); + if (value != 0.f) { procedure.apply(row, col, value); } /* set(row, col, 0) on a stored cell keeps its key, so skip stored zeros */ + } + } + + /** Converts to CSR by exporting every stored cell (stored zeros included) as COO triplets and passing them to MatrixUtils.coo2csr. */ @Override + public CSRFloatMatrix toRowMajorMatrix() { + final int nnz = elements.size(); + final int[] rows = new int[nnz]; + final int[] cols = new int[nnz]; + final float[] data = new float[nnz]; + + final IMapIterator itor = elements.entries(); + for (int i = 0; i < nnz; i++) { + if (itor.next() == -1) { + throw new IllegalStateException("itor.next() returns -1 where i=" + i); + } + long k = itor.getKey(); + rows[i] = Primitives.getHigh(k); + cols[i] = Primitives.getLow(k); + data[i] = itor.getValue(); + } + + return MatrixUtils.coo2csr(rows, cols, data, numRows, numColumns, true); + } + + /** Converts to CSC by exporting every stored cell (stored zeros included) as COO triplets and passing them to MatrixUtils.coo2csc. */ @Override + public CSCFloatMatrix toColumnMajorMatrix() { + final int nnz = elements.size(); + final int[] rows = new int[nnz]; + final int[] cols = new int[nnz]; + final float[] data = new float[nnz]; + + final IMapIterator itor = elements.entries(); + for (int i = 0; i < nnz; i++) { + if (itor.next() == -1) { + throw new 
IllegalStateException("itor.next() returns -1 where i=" + i); + } + long k = itor.getKey(); + rows[i] = Primitives.getHigh(k); + cols[i] = Primitives.getLow(k); + data[i] = itor.getValue(); + } + + return MatrixUtils.coo2csc(rows, cols, data, numRows, numColumns, true); + } + + /** Returns a DoKMatrixBuilder presized to the number of stored cells. */ @Override + public DoKMatrixBuilder builder() { + return new DoKMatrixBuilder(elements.size()); + } + + /** Packs (row, col) into one long key; Primitives.getHigh/getLow recover row and col respectively (see eachNonZeroCell). */ @Nonnegative + private static long index(@Nonnegative final int row, @Nonnegative final int col) { + return Primitives.toLong(row, col); + } + +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/fdf70214/core/src/main/java/hivemall/math/vector/AbstractVector.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/math/vector/AbstractVector.java b/core/src/main/java/hivemall/math/vector/AbstractVector.java index 88bed7b..7c4579f 100644 --- a/core/src/main/java/hivemall/math/vector/AbstractVector.java +++ b/core/src/main/java/hivemall/math/vector/AbstractVector.java @@ -29,6 +29,16 @@ public abstract class AbstractVector implements Vector { return get(index, 0.d); } + @Override + public float get(@Nonnegative final int index, final float defaultValue) { + return (float) get(index, (double) defaultValue); + } + + @Override + public void set(@Nonnegative final int index, final float value) { + set(index, (double) value); + } + protected static final void checkIndex(final int index) { if (index < 0) { throw new IndexOutOfBoundsException("Invalid index " + index);
