http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/math/matrix/ints/ColumnMajorDenseIntMatrix2d.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/math/matrix/ints/ColumnMajorDenseIntMatrix2d.java b/core/src/main/java/hivemall/math/matrix/ints/ColumnMajorDenseIntMatrix2d.java new file mode 100644 index 0000000..d028d47 --- /dev/null +++ b/core/src/main/java/hivemall/math/matrix/ints/ColumnMajorDenseIntMatrix2d.java @@ -0,0 +1,172 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package hivemall.math.matrix.ints; + +import hivemall.math.vector.VectorProcedure; + +import javax.annotation.Nonnegative; +import javax.annotation.Nonnull; + +public final class ColumnMajorDenseIntMatrix2d extends ColumnMajorIntMatrix { + + @Nonnull + private final int[][] data; // col-row + + @Nonnegative + private final int numRows; + @Nonnegative + private final int numColumns; + + public ColumnMajorDenseIntMatrix2d(@Nonnull int[][] data, @Nonnegative int numRows) { + super(); + this.data = data; + this.numRows = numRows; + this.numColumns = data.length; + } + + @Override + public boolean isSparse() { + return false; + } + + @Override + public boolean readOnly() { + return true; + } + + @Override + public int numRows() { + return numRows; + } + + @Override + public int numColumns() { + return numColumns; + } + + @Override + public int[] getRow(final int index) { + checkRowIndex(index, numRows); + + int[] row = new int[numColumns]; + return getRow(index, row); + } + + @Override + public int[] getRow(final int index, @Nonnull final int[] dst) { + checkRowIndex(index, numRows); + + for (int j = 0; j < data.length; j++) { + final int[] col = data[j]; + if (index < col.length) { + dst[j] = col[index]; + } + } + return dst; + } + + @Override + public int get(final int row, final int col, final int defaultValue) { + checkIndex(row, col, numRows, numColumns); + + final int[] colData = data[col]; + if (row >= colData.length) { + return defaultValue; + } + return colData[row]; + } + + @Override + public int getAndSet(final int row, final int col, final int value) { + checkIndex(row, col, numRows, numColumns); + + final int[] colData = data[col]; + checkRowIndex(row, colData.length); + + final int old = colData[row]; + colData[row] = value; + return old; + } + + @Override + public void set(final int row, final int col, final int value) { + checkIndex(row, col, numRows, numColumns); + if (value == 0) { + return; + } + + final int[] colData = data[col]; + checkRowIndex(row, colData.length); + colData[row] = value; + } + + @Override + public void incr(final int row, final int col, final int delta) { + checkIndex(row, col, numRows, numColumns); + + final int[] colData = data[col]; + checkRowIndex(row, colData.length); + + colData[row] += delta; + } + + @Override + public void eachInColumn(final int col, @Nonnull final VectorProcedure procedure, + final boolean nullOutput) { + checkColIndex(col, numColumns); + + final int[] colData = data[col]; + if (colData == null) { + if (nullOutput) { + for (int i = 0; i < numRows; i++) { + procedure.apply(i, defaultValue); + } + } + return; + } + + int row = 0; + for (int len = colData.length; row < len; row++) { + procedure.apply(row, colData[row]); + } + if (nullOutput) { + for (; row < numRows; row++) { + procedure.apply(row, defaultValue); + } + } + } + + @Override + public void eachNonZeroInColumn(final int col, @Nonnull final VectorProcedure procedure) { + checkColIndex(col, numColumns); + + final int[] colData = data[col]; + if (colData == null) { + return; + } + int row = 0; + for (int len = colData.length; row < len; row++) { + final int v = colData[row]; + if (v != 0) { + procedure.apply(row, v); + } + } + } + +}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/math/matrix/ints/ColumnMajorIntMatrix.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/math/matrix/ints/ColumnMajorIntMatrix.java b/core/src/main/java/hivemall/math/matrix/ints/ColumnMajorIntMatrix.java new file mode 100644 index 0000000..e0b3b4b --- /dev/null +++ b/core/src/main/java/hivemall/math/matrix/ints/ColumnMajorIntMatrix.java @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package hivemall.math.matrix.ints; + +import hivemall.math.vector.VectorProcedure; + +public abstract class ColumnMajorIntMatrix extends AbstractIntMatrix { + + public ColumnMajorIntMatrix() { + super(); + } + + @Override + public void eachInRow(int row, VectorProcedure procedure, boolean nullOutput) { + throw new UnsupportedOperationException(); + } + + @Override + public void eachNonZeroInRow(int row, VectorProcedure procedure) { + throw new UnsupportedOperationException(); + } + +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/math/matrix/ints/DoKIntMatrix.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/math/matrix/ints/DoKIntMatrix.java b/core/src/main/java/hivemall/math/matrix/ints/DoKIntMatrix.java new file mode 100644 index 0000000..2bbd3b4 --- /dev/null +++ b/core/src/main/java/hivemall/math/matrix/ints/DoKIntMatrix.java @@ -0,0 +1,277 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package hivemall.math.matrix.ints; + +import hivemall.math.vector.VectorProcedure; +import hivemall.utils.collections.maps.Long2IntOpenHashTable; +import hivemall.utils.lang.Preconditions; +import hivemall.utils.lang.Primitives; + +import javax.annotation.Nonnegative; +import javax.annotation.Nonnull; + +/** + * Dictionary-of-Key Sparse Int Matrix. + */ +public final class DoKIntMatrix extends AbstractIntMatrix { + + @Nonnull + private final Long2IntOpenHashTable elements; + @Nonnegative + private int numRows; + @Nonnegative + private int numColumns; + + public DoKIntMatrix() { + this(0, 0); + } + + public DoKIntMatrix(@Nonnegative int numRows, @Nonnegative int numCols) { + this(numRows, numCols, 0.05f); + } + + public DoKIntMatrix(@Nonnegative int numRows, @Nonnegative int numCols, + @Nonnegative float sparsity) { + Preconditions.checkArgument(sparsity >= 0.f && sparsity <= 1.f, "Invalid Sparsity value: " + + sparsity); + int initialCapacity = Math.max(16384, Math.round(numRows * numCols * sparsity)); + this.elements = new Long2IntOpenHashTable(initialCapacity); + this.numRows = numRows; + this.numColumns = numCols; + } + + private DoKIntMatrix(@Nonnull Long2IntOpenHashTable elements, @Nonnegative int numRows, + @Nonnegative int numColumns) { + this.elements = elements; + this.numRows = numRows; + this.numColumns = numColumns; + } + + @Override + public boolean isSparse() { + return true; + } + + @Override + public boolean readOnly() { + return false; + } + + @Override + public int numRows() { + return numRows; + } + + @Override + public int numColumns() { + return numColumns; + } + + @Override + public int[] getRow(@Nonnegative final int index) { + int[] dst = row(); + return getRow(index, dst); + } + + @Override + public int[] getRow(@Nonnegative final int row, @Nonnull final int[] dst) { + checkRowIndex(row, numRows); + + final int end = Math.min(dst.length, numColumns); + for (int col = 0; col < end; col++) { + long index = index(row, col); + int v = elements.get(index, defaultValue); + dst[col] = v; + } + + return dst; + } + + @Override + public int get(@Nonnegative final int row, @Nonnegative final int col, final int defaultValue) { + checkIndex(row, col, numRows, numColumns); + + long index = index(row, col); + return elements.get(index, defaultValue); + } + + @Override + public void set(@Nonnegative final int row, @Nonnegative final int col, final int value) { + checkIndex(row, col); + + long index = index(row, col); + elements.put(index, value); + this.numRows = Math.max(numRows, row + 1); + this.numColumns = Math.max(numColumns, col + 1); + } + + @Override + public int getAndSet(@Nonnegative final int row, @Nonnegative final int col, final int value) { + checkIndex(row, col); + + long index = index(row, col); + int old = elements.put(index, value); + this.numRows = Math.max(numRows, row + 1); + this.numColumns = Math.max(numColumns, col + 1); + return old; + } + + @Override + public void incr(@Nonnegative final int row, @Nonnegative final int col, final int delta) { + checkIndex(row, col); + + long index = index(row, col); + elements.incr(index, delta); + this.numRows = Math.max(numRows, row + 1); + this.numColumns = Math.max(numColumns, col + 1); + } + + @Override + public void eachInRow(@Nonnegative final int row, @Nonnull final VectorProcedure procedure, + final boolean nullOutput) { + checkRowIndex(row, numRows); + + for (int col = 0; col < numColumns; col++) { + long i = index(row, col); + final int key = elements._findKey(i); + if (key < 0) { + if (nullOutput) { + procedure.apply(col, defaultValue); + } + } else { + int v = elements._get(key); + procedure.apply(col, v); + } + } + } + + @Override + public void eachNonZeroInRow(@Nonnegative final int row, + @Nonnull final VectorProcedure procedure) { + checkRowIndex(row, numRows); + + for (int col = 0; col < numColumns; col++) { + long i = index(row, col); + final int v = elements.get(i, 0); + if (v != 0) { + procedure.apply(col, v); + } + } + } + + @Override + public void eachInColumn(@Nonnegative final int col, @Nonnull final VectorProcedure procedure, + final boolean nullOutput) { + checkColIndex(col, numColumns); + + for (int row = 0; row < numRows; row++) { + long i = index(row, col); + final int key = elements._findKey(i); + if (key < 0) { + if (nullOutput) { + procedure.apply(row, defaultValue); + } + } else { + int v = elements._get(key); + procedure.apply(row, v); + } + } + } + + @Override + public void eachNonZeroInColumn(@Nonnegative final int col, + @Nonnull final VectorProcedure procedure) { + checkColIndex(col, numColumns); + + for (int row = 0; row < numRows; row++) { + long i = index(row, col); + final int v = elements.get(i, 0); + if (v != 0) { + procedure.apply(row, v); + } + } + } + + @Nonnegative + private static long index(@Nonnegative final int row, @Nonnegative final int col) { + return Primitives.toLong(row, col); + } + + @Nonnull + public static DoKIntMatrix build(@Nonnull final int[][] matrix, boolean rowMajorInput, + boolean nonZeroOnly) { + if (rowMajorInput) { + return buildFromRowMajorMatrix(matrix, nonZeroOnly); + } else { + return buildFromColumnMajorMatrix(matrix, nonZeroOnly); + } + } + + @Nonnull + private static DoKIntMatrix buildFromRowMajorMatrix(@Nonnull final int[][] rowMajorMatrix, + boolean nonZeroOnly) { + final Long2IntOpenHashTable elements = new Long2IntOpenHashTable(rowMajorMatrix.length * 3); + + int numRows = rowMajorMatrix.length, numColumns = 0; + for (int i = 0; i < rowMajorMatrix.length; i++) { + final int[] row = rowMajorMatrix[i]; + if (row == null) { + continue; + } + numColumns = Math.max(numColumns, row.length); + for (int col = 0; col < row.length; col++) { + int value = row[col]; + if (nonZeroOnly && value == 0) { + continue; + } + long index = index(i, col); + elements.put(index, value); + } + } + + return new DoKIntMatrix(elements, numRows, numColumns); + } + + @Nonnull + private static DoKIntMatrix buildFromColumnMajorMatrix( + @Nonnull final int[][] columnMajorMatrix, boolean nonZeroOnly) { + final Long2IntOpenHashTable elements = new Long2IntOpenHashTable( + columnMajorMatrix.length * 3); + + int numRows = 0, numColumns = columnMajorMatrix.length; + for (int j = 0; j < columnMajorMatrix.length; j++) { + final int[] col = columnMajorMatrix[j]; + if (col == null) { + continue; + } + numRows = Math.max(numRows, col.length); + for (int row = 0; row < col.length; row++) { + int value = col[row]; + if (nonZeroOnly && value == 0) { + continue; + } + long index = index(row, j); + elements.put(index, value); + } + } + + return new DoKIntMatrix(elements, numRows, numColumns); + } + +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/math/matrix/ints/IntMatrix.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/math/matrix/ints/IntMatrix.java b/core/src/main/java/hivemall/math/matrix/ints/IntMatrix.java new file mode 100644 index 0000000..bcc954e --- /dev/null +++ b/core/src/main/java/hivemall/math/matrix/ints/IntMatrix.java @@ -0,0 +1,104 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package hivemall.math.matrix.ints; + +import hivemall.math.vector.VectorProcedure; + +import javax.annotation.Nonnegative; +import javax.annotation.Nonnull; + +public interface IntMatrix { + + public boolean isSparse(); + + public boolean readOnly(); + + public void setDefaultValue(int value); + + @Nonnegative + public int numRows(); + + @Nonnegative + public int numColumns(); + + @Nonnull + public int[] row(); + + @Nonnull + public int[] getRow(@Nonnegative int index); + + /** + * @return returns dst + */ + @Nonnull + public int[] getRow(@Nonnegative int index, @Nonnull int[] dst); + + /** + * @throws IndexOutOfBoundsException + */ + public int get(@Nonnegative int row, @Nonnegative int col); + + /** + * @throws IndexOutOfBoundsException + */ + public int get(@Nonnegative int row, @Nonnegative int col, int defaultValue); + + /** + * @throws IndexOutOfBoundsException + * @throws UnsupportedOperationException + */ + public void set(@Nonnegative int row, @Nonnegative int col, int value); + + /** + * @throws IndexOutOfBoundsException + * @throws UnsupportedOperationException + */ + public int getAndSet(@Nonnegative int row, @Nonnegative int col, int value); + + /** + * @throws IndexOutOfBoundsException + * @throws UnsupportedOperationException + */ + public void incr(@Nonnegative int row, @Nonnegative int col); + + /** + * @throws IndexOutOfBoundsException + * @throws UnsupportedOperationException + */ + public void incr(@Nonnegative int row, @Nonnegative int col, int delta); + + public void eachInRow(@Nonnegative int row, @Nonnull VectorProcedure procedure); + + public void eachInRow(@Nonnegative int row, @Nonnull VectorProcedure procedure, + boolean nullOutput); + + public void eachNonNullInRow(@Nonnegative int row, @Nonnull VectorProcedure procedure); + + public void eachNonZeroInRow(@Nonnegative int row, @Nonnull VectorProcedure procedure); + + public void eachInColumn(@Nonnegative int col, @Nonnull VectorProcedure procedure); + + public void eachInColumn(@Nonnegative int col, @Nonnull VectorProcedure procedure, + boolean nullOutput); + + public void eachNonNullInColumn(@Nonnegative int col, @Nonnull VectorProcedure procedure); + + public void eachNonZeroInColumn(@Nonnegative int col, @Nonnull VectorProcedure procedure); + +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/math/matrix/sparse/CSCMatrix.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/math/matrix/sparse/CSCMatrix.java b/core/src/main/java/hivemall/math/matrix/sparse/CSCMatrix.java new file mode 100644 index 0000000..d2232b2 --- /dev/null +++ b/core/src/main/java/hivemall/math/matrix/sparse/CSCMatrix.java @@ -0,0 +1,289 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package hivemall.math.matrix.sparse; + +import hivemall.math.matrix.ColumnMajorMatrix; +import hivemall.math.matrix.builders.CSCMatrixBuilder; +import hivemall.math.vector.Vector; +import hivemall.math.vector.VectorProcedure; +import hivemall.utils.lang.ArrayUtils; +import hivemall.utils.lang.Preconditions; + +import java.util.Arrays; + +import javax.annotation.Nonnegative; +import javax.annotation.Nonnull; + +/** + * @link http://netlib.org/linalg/html_templates/node92.html#SECTION00931200000000000000 + */ +public final class CSCMatrix extends ColumnMajorMatrix { + + @Nonnull + private final int[] columnPointers; + @Nonnull + private final int[] rowIndicies; + @Nonnull + private final double[] values; + + private final int numRows; + private final int numColumns; + private final int nnz; + + public CSCMatrix(@Nonnull int[] columnPointers, @Nonnull int[] rowIndicies, + @Nonnull double[] values, int numRows, int numColumns) { + super(); + Preconditions.checkArgument(columnPointers.length >= 1, + "rowPointers must be greather than 0: " + columnPointers.length); + Preconditions.checkArgument(rowIndicies.length == values.length, "#rowIndicies (" + + rowIndicies.length + ") must be equals to #values (" + values.length + ")"); + this.columnPointers = columnPointers; + this.rowIndicies = rowIndicies; + this.values = values; + this.numRows = numRows; + this.numColumns = numColumns; + this.nnz = values.length; + } + + @Override + public boolean isSparse() { + return true; + } + + @Override + public boolean readOnly() { + return true; + } + + @Override + public boolean swappable() { + return false; + } + + @Override + public int nnz() { + return nnz; + } + + @Override + public int numRows() { + return numRows; + } + + @Override + public int numColumns() { + return numColumns; + } + + @Override + public int numColumns(final int row) { + checkRowIndex(row, numRows); + + return ArrayUtils.count(rowIndicies, row); + } + + @Override + public double[] getRow(int index) { + checkRowIndex(index, numRows); + + final double[] row = new double[numColumns]; + + final int numCols = columnPointers.length - 1; + for (int j = 0; j < numCols; j++) { + final int k = Arrays.binarySearch(rowIndicies, columnPointers[j], + columnPointers[j + 1], index); + if (k >= 0) { + row[j] = values[k]; + } + } + + return row; + } + + @Override + public double[] getRow(final int index, @Nonnull final double[] dst) { + checkRowIndex(index, numRows); + + final int last = Math.min(dst.length, columnPointers.length - 1); + for (int j = 0; j < last; j++) { + final int k = Arrays.binarySearch(rowIndicies, columnPointers[j], + columnPointers[j + 1], index); + if (k >= 0) { + dst[j] = values[k]; + } + } + + return dst; + } + + @Override + public void getRow(final int index, @Nonnull final Vector row) { + checkRowIndex(index, numRows); + row.clear(); + + for (int j = 0, last = columnPointers.length - 1; j < last; j++) { + final int k = Arrays.binarySearch(rowIndicies, columnPointers[j], + columnPointers[j + 1], index); + if (k >= 0) { + double v = values[k]; + row.set(j, v); + } + } + } + + @Override + public double get(final int row, final int col, final double defaultValue) { + checkIndex(row, col, numRows, numColumns); + + int index = getIndex(row, col); + if (index < 0) { + return defaultValue; + } + return values[index]; + } + + @Override + public double getAndSet(final int row, final int col, final double value) { + checkIndex(row, col, numRows, numColumns); + + final int index = getIndex(row, col); + if (index < 0) { + throw new UnsupportedOperationException("Cannot update value in row " + row + ", col " + + col); + } + + double old = values[index]; + values[index] = value; + return old; + } + + @Override + public void set(final int row, final int col, final double value) { + checkIndex(row, col, numRows, numColumns); + + final int index = getIndex(row, col); + if (index < 0) { + throw new UnsupportedOperationException("Cannot update value in row " + row + ", col " + + col); + } + values[index] = value; + } + + private int getIndex(@Nonnegative final int row, @Nonnegative final int col) { + int leftIn = columnPointers[col]; + int rightEx = columnPointers[col + 1]; + final int index = Arrays.binarySearch(rowIndicies, leftIn, rightEx, row); + if (index >= 0 && index >= values.length) { + throw new IndexOutOfBoundsException("Value index " + index + " out of range " + + values.length); + } + return index; + } + + @Override + public void swap(final int row1, final int row2) { + throw new UnsupportedOperationException(); + } + + @Override + public void eachInColumn(final int col, @Nonnull final VectorProcedure procedure, + final boolean nullOutput) { + checkColIndex(col, numColumns); + + final int startIn = columnPointers[col]; + final int endEx = columnPointers[col + 1]; + + if (nullOutput) { + for (int row = 0, i = startIn; row < numRows; row++) { + if (i < endEx && row == rowIndicies[i]) { + double v = values[i++]; + procedure.apply(row, v); + } else { + procedure.apply(row, 0.d); + } + } + } else { + for (int j = startIn; j < endEx; j++) { + int row = rowIndicies[j]; + double v = values[j]; + procedure.apply(row, v); + } + } + } + + @Override + public void eachNonZeroInColumn(final int col, @Nonnull final VectorProcedure procedure) { + checkColIndex(col, numColumns); + + final int startIn = columnPointers[col]; + final int endEx = columnPointers[col + 1]; + for (int j = startIn; j < endEx; j++) { + int row = rowIndicies[j]; + final double v = values[j]; + if (v != 0.d) { + procedure.apply(row, v); + } + } + } + + @Override + public CSRMatrix toRowMajorMatrix() { + final int[] rowPointers = new int[numRows + 1]; + final int[] colIndicies = new int[nnz]; + final double[] csrValues = new double[nnz]; + + // compute nnz per for each row + for (int i = 0; i < rowIndicies.length; i++) { + rowPointers[rowIndicies[i]]++; + } + for (int i = 0, sum = 0; i < numRows; i++) { + int curr = rowPointers[i]; + rowPointers[i] = sum; + sum += curr; + } + rowPointers[numRows] = nnz; + + for (int j = 0; j < numColumns; j++) { + for (int i = columnPointers[j], last = columnPointers[j + 1]; i < last; i++) { + int col = rowIndicies[i]; + int dst = rowPointers[col]; + + colIndicies[dst] = j; + csrValues[dst] = values[i]; + + rowPointers[col]++; + } + } + + // shift column pointers + for (int i = 0, last = 0; i <= numRows; i++) { + int tmp = rowPointers[i]; + rowPointers[i] = last; + last = tmp; + } + + return new CSRMatrix(rowPointers, colIndicies, csrValues, numColumns); + } + + @Override + public CSCMatrixBuilder builder() { + return new CSCMatrixBuilder(nnz); + } + +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/math/matrix/sparse/CSRMatrix.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/math/matrix/sparse/CSRMatrix.java b/core/src/main/java/hivemall/math/matrix/sparse/CSRMatrix.java new file mode 100644 index 0000000..dd89521 --- /dev/null +++ b/core/src/main/java/hivemall/math/matrix/sparse/CSRMatrix.java @@ -0,0 +1,282 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package hivemall.math.matrix.sparse; + +import hivemall.math.matrix.RowMajorMatrix; +import hivemall.math.matrix.builders.CSRMatrixBuilder; +import hivemall.math.vector.VectorProcedure; +import hivemall.utils.lang.Preconditions; + +import java.util.Arrays; + +import javax.annotation.Nonnegative; +import javax.annotation.Nonnull; + +/** + * Read-only CSR double Matrix. + * + * @link http://netlib.org/linalg/html_templates/node91.html#SECTION00931100000000000000 + * @link http://www.cs.colostate.edu/~mcrob/toolbox/c++/sparseMatrix/sparse_matrix_compression.html + */ +public final class CSRMatrix extends RowMajorMatrix { + + @Nonnull + private final int[] rowPointers; + @Nonnull + private final int[] columnIndices; + @Nonnull + private final double[] values; + + @Nonnegative + private final int numRows; + @Nonnegative + private final int numColumns; + @Nonnegative + private final int nnz; + + public CSRMatrix(@Nonnull int[] rowPointers, @Nonnull int[] columnIndices, + @Nonnull double[] values, @Nonnegative int numColumns) { + super(); + Preconditions.checkArgument(rowPointers.length >= 1, + "rowPointers must be greather than 0: " + rowPointers.length); + Preconditions.checkArgument(columnIndices.length == values.length, "#columnIndices (" + + columnIndices.length + ") must be equals to #values (" + values.length + ")"); + this.rowPointers = rowPointers; + this.columnIndices = columnIndices; + this.values = values; + this.numRows = rowPointers.length - 1; + this.numColumns = numColumns; + this.nnz = values.length; + } + + @Override + public boolean isSparse() { + return true; + } + + @Override + public boolean readOnly() { + return true; + } + + @Override + public boolean swappable() { + return false; + } + + @Override + public int nnz() { + return nnz; + } + + @Override + public int numRows() { + return numRows; + } + + @Override + public int numColumns() { + return numColumns; + } + + @Override + public int numColumns(@Nonnegative final int row) { + checkRowIndex(row, numRows); + + int columns = rowPointers[row + 1] - rowPointers[row]; + return columns; + } + + @Override + public double[] getRow(@Nonnegative final int index) { + final double[] row = new double[numColumns]; + eachNonZeroInRow(index, new VectorProcedure() { + public void apply(int col, double value) { + row[col] = value; + } + }); + return row; + } + + @Override + public double[] getRow(@Nonnegative final int index, @Nonnull final double[] dst) { + Arrays.fill(dst, 0.d); + eachNonZeroInRow(index, new VectorProcedure() { + public void apply(int col, double value) { + checkColIndex(col, numColumns); + dst[col] = value; + } + }); + return dst; + } + + @Override + public double get(@Nonnegative final int row, @Nonnegative final int col, + final double defaultValue) { + checkIndex(row, col, numRows, numColumns); + + final int index = getIndex(row, col); + if (index < 0) { + return defaultValue; + } + return values[index]; + } + + @Override + public double getAndSet(@Nonnegative final int row, @Nonnegative final int col, + final double value) { + checkIndex(row, col, numRows, numColumns); + + final int index = getIndex(row, col); + if (index < 0) { + throw new UnsupportedOperationException("Cannot update value in row " + row + ", col " + + col); + } + + double old = values[index]; + values[index] = value; + return old; + } + + @Override + public void set(@Nonnegative final int row, @Nonnegative final int col, final double value) { + checkIndex(row, col, numRows, numColumns); + + final int index = getIndex(row, col); + if (index < 0) { + throw new UnsupportedOperationException("Cannot update value in row " + row + ", col " + + col); + } + values[index] = value; + } + + private int getIndex(@Nonnegative final int row, @Nonnegative final int col) { + int leftIn = rowPointers[row]; + int rightEx = rowPointers[row + 1]; + final int index = Arrays.binarySearch(columnIndices, leftIn, rightEx, col); + if (index >= 0 && index >= values.length) { + throw new IndexOutOfBoundsException("Value index " + index + " out of range " + + values.length); + } + return index; + } + + @Override + public void swap(int row1, int row2) { + throw new UnsupportedOperationException(); + } + + @Override + public void eachInRow(@Nonnegative final int row, @Nonnull final VectorProcedure procedure, + final boolean nullOutput) { + checkRowIndex(row, numRows); + + final int startIn = rowPointers[row]; + final int endEx = rowPointers[row + 1]; + + if (nullOutput) { + for (int col = 0, j = startIn; col < numColumns; col++) { + if (j < endEx && col == columnIndices[j]) { + double v = values[j++]; + procedure.apply(col, v); + } else { + procedure.apply(col, 0.d); + } + } + } else { + for (int i = startIn; i < endEx; i++) { + procedure.apply(columnIndices[i], values[i]); + } + } + } + + @Override + public void eachNonZeroInRow(@Nonnegative final int row, + @Nonnull final VectorProcedure procedure) { + checkRowIndex(row, numRows); + + final int startIn = rowPointers[row]; + final int endEx = rowPointers[row + 1]; + for (int i = startIn; i < endEx; i++) { + int col = columnIndices[i]; + final double v = values[i]; + if (v != 0.d) { + procedure.apply(col, v); + } + } + } + + @Override + public void eachColumnIndexInRow(@Nonnegative final int row, + @Nonnull final VectorProcedure procedure) { + checkRowIndex(row, numRows); + + final int startIn = rowPointers[row]; + final int endEx = rowPointers[row + 1]; + + for (int i = startIn; i < endEx; i++) { + procedure.apply(columnIndices[i]); + } + } + + @Nonnull + public CSCMatrix toColumnMajorMatrix() { + final int[] columnPointers = new int[numColumns + 1]; + final int[] rowIndicies = new int[nnz]; + final double[] cscValues = new double[nnz]; + + // compute nnz per for each column + for (int j = 0; j < columnIndices.length; j++) { + columnPointers[columnIndices[j]]++; + } + for (int j = 0, sum = 0; j < numColumns; j++) { + int curr = columnPointers[j]; + columnPointers[j] = sum; + sum += curr; + } + columnPointers[numColumns] = nnz; + + for (int i = 0; i < numRows; i++) { + for (int j = rowPointers[i], last = rowPointers[i + 1]; j < last; j++) { + int col = columnIndices[j]; + int dst = columnPointers[col]; + + rowIndicies[dst] = i; + cscValues[dst] = values[j]; + + columnPointers[col]++; + } + } + + // shift column pointers + for (int j = 0, last = 0; j <= numColumns; j++) { + int tmp = columnPointers[j]; + columnPointers[j] = last; + last = tmp; + } + + return new CSCMatrix(columnPointers, rowIndicies, cscValues, numRows, numColumns); + } + + @Override + public CSRMatrixBuilder builder() { + return new CSRMatrixBuilder(values.length); + } + +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/math/matrix/sparse/DoKMatrix.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/math/matrix/sparse/DoKMatrix.java b/core/src/main/java/hivemall/math/matrix/sparse/DoKMatrix.java new file mode 100644 index 0000000..bcfd152 --- /dev/null +++ b/core/src/main/java/hivemall/math/matrix/sparse/DoKMatrix.java @@ -0,0 +1,332 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package hivemall.math.matrix.sparse; + +import hivemall.annotations.Experimental; +import hivemall.math.matrix.AbstractMatrix; +import hivemall.math.matrix.ColumnMajorMatrix; +import hivemall.math.matrix.RowMajorMatrix; +import hivemall.math.matrix.builders.DoKMatrixBuilder; +import hivemall.math.vector.Vector; +import hivemall.math.vector.VectorProcedure; +import hivemall.utils.collections.maps.Long2DoubleOpenHashTable; +import hivemall.utils.lang.Preconditions; +import hivemall.utils.lang.Primitives; + +import javax.annotation.Nonnegative; +import javax.annotation.Nonnull; + +@Experimental +public final class DoKMatrix extends AbstractMatrix { + + @Nonnull + private final Long2DoubleOpenHashTable elements; + @Nonnegative + private int numRows; + @Nonnegative + private int numColumns; + @Nonnegative + private int nnz; + + public DoKMatrix() { + this(0, 0); + } + + public DoKMatrix(@Nonnegative int numRows, @Nonnegative int numCols) { + this(numRows, numCols, 0.05f); + } + + public DoKMatrix(@Nonnegative int numRows, @Nonnegative int numCols, @Nonnegative float sparsity) { + super(); + Preconditions.checkArgument(sparsity >= 0.f && sparsity <= 1.f, "Invalid Sparsity value: " + + sparsity); + int initialCapacity = Math.max(16384, Math.round(numRows * numCols * sparsity)); + this.elements = new Long2DoubleOpenHashTable(initialCapacity); + elements.defaultReturnValue(0.d); + this.numRows = numRows; + this.numColumns = numCols; + this.nnz = 0; + } + + public DoKMatrix(@Nonnegative int initSize) { + super(); + int initialCapacity = Math.max(initSize, 16384); + this.elements = new Long2DoubleOpenHashTable(initialCapacity); + elements.defaultReturnValue(0.d); + this.numRows = 0; + this.numColumns = 0; + this.nnz = 0; + } + + @Override + public boolean isSparse() { + return true; + } + + @Override + public boolean isRowMajorMatrix() { + return false; + } + + @Override + public boolean isColumnMajorMatrix() { + return false; + } + + @Override + public boolean readOnly() { + return false; + } + + @Override + public boolean swappable() { + return true; + } + + @Override + public int nnz() { + return nnz; + } + + @Override + public int numRows() { + return numRows; + } + + @Override + public int numColumns() { + return numColumns; + } + + @Override + public int numColumns(@Nonnegative final int row) { + int count = 0; + for (int j = 0; j < numColumns; j++) { + long index = index(row, j); + if (elements.containsKey(index)) { + count++; + } + } + return count; + } + + @Override + public double[] getRow(@Nonnegative final int index) { + double[] dst = row(); + return getRow(index, dst); + } + + @Override + public double[] getRow(@Nonnegative final int row, @Nonnull final double[] dst) { + checkRowIndex(row, numRows); + + final int end = Math.min(dst.length, numColumns); + for (int col = 0; col < end; col++) { + long k = index(row, col); + double v = elements.get(k); + dst[col] = v; + } + + return dst; + } + + @Override + public void getRow(@Nonnegative final int index, @Nonnull final Vector row) { + checkRowIndex(index, numRows); + row.clear(); + + for (int col = 0; col < numColumns; col++) { + long k = index(index, col); + final double v = elements.get(k, 0.d); + if (v != 0.d) { + row.set(col, v); + } + } + } + + @Override + public double get(@Nonnegative final int row, @Nonnegative final int col, + final double defaultValue) { + checkIndex(row, col, numRows, numColumns); + + long index = index(row, col); + return elements.get(index, defaultValue); + } + + @Override + public void set(@Nonnegative final int row, @Nonnegative final int col, final double value) { + checkIndex(row, col); + + if (value == 0.d) { + return; + } + + long index = index(row, col); + if (elements.put(index, value, 0.d) == 0.d) { + nnz++; + this.numRows = Math.max(numRows, row + 1); + this.numColumns = Math.max(numColumns, col + 1); + } + } + + @Override + public double getAndSet(@Nonnegative final int row, @Nonnegative final int col, + final double value) { + checkIndex(row, col); + + long index = index(row, col); + double old = elements.put(index, value, 0.d); + if (old == 0.d) { + nnz++; + this.numRows = Math.max(numRows, row + 1); + this.numColumns = Math.max(numColumns, col + 1); + } + return old; + } + + @Override + public void swap(@Nonnegative final int row1, @Nonnegative final int row2) { + checkRowIndex(row1, numRows); + checkRowIndex(row2, numRows); + + for (int j = 0; j < numColumns; j++) { + final long i1 = index(row1, j); + final long i2 = index(row2, j); + + final int k1 = elements._findKey(i1); + final int k2 = elements._findKey(i2); + + if (k1 >= 0) { + if (k2 >= 0) { + double v1 = elements._get(k1); + double v2 = elements._set(k2, v1); + elements._set(k1, v2); + } else {// k1>=0 and k2<0 + double v1 = elements._remove(k1); + elements.put(i2, v1); + } + } else if (k2 >= 0) {// k2>=0 and k1 < 0 + double v2 = elements._remove(k2); + elements.put(i1, v2); + } else {//k1<0 and k2<0 + continue; + } + } + } + + @Override + public void eachInRow(@Nonnegative final int row, @Nonnull final VectorProcedure procedure, + final boolean nullOutput) { + checkRowIndex(row, numRows); + + for (int col = 0; col < numColumns; col++) { + long i = index(row, col); + final int key = elements._findKey(i); + if (key < 0) { + if (nullOutput) { + procedure.apply(col, 0.d); + } + } else { + double v = elements._get(key); + procedure.apply(col, v); + } + } + } + + @Override + public void eachNonZeroInRow(@Nonnegative final int row, + @Nonnull final VectorProcedure procedure) { + checkRowIndex(row, numRows); + + for (int col = 0; col < numColumns; col++) { + long i = index(row, col); + final double v = elements.get(i, 0.d); + if (v != 0.d) { + procedure.apply(col, v); + } + } + } + + @Override + public void eachColumnIndexInRow(int row, VectorProcedure procedure) { + checkRowIndex(row, numRows); + + for (int col = 0; col < numColumns; col++) { + long i = index(row, col); + final int key = elements._findKey(i); + if (key != -1) { + procedure.apply(col); + } + } + } + + @Override + public void eachInColumn(@Nonnegative final int col, @Nonnull final VectorProcedure procedure, + final boolean nullOutput) { + checkColIndex(col, numColumns); + + for (int row = 0; row < numRows; row++) { + long i = index(row, col); + final int key = elements._findKey(i); + if (key < 0) { + if (nullOutput) { + procedure.apply(row, 0.d); + } + } else { + double v = elements._get(key); + procedure.apply(row, v); + } + } + } + + @Override + public void eachNonZeroInColumn(@Nonnegative final int col, + @Nonnull final VectorProcedure procedure) { + checkColIndex(col, numColumns); + + for (int row = 0; row < numRows; row++) { + long i = index(row, col); + final double v = elements.get(i, 0.d); + if (v != 0.d) { + procedure.apply(row, v); + } + } + } + + @Override + public RowMajorMatrix toRowMajorMatrix() { + throw new UnsupportedOperationException("Not yet supported"); + } + + @Override + public ColumnMajorMatrix toColumnMajorMatrix() { + throw new UnsupportedOperationException("Not yet supported"); + } + + @Override + public DoKMatrixBuilder builder() { + return new DoKMatrixBuilder(elements.size()); + } + + @Nonnegative + private static long index(@Nonnegative final int row, @Nonnegative final int col) { + return Primitives.toLong(row, col); + } + +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/math/random/CommonsMathRandom.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/math/random/CommonsMathRandom.java b/core/src/main/java/hivemall/math/random/CommonsMathRandom.java new file mode 100644 index 0000000..e0b7554 --- /dev/null +++ b/core/src/main/java/hivemall/math/random/CommonsMathRandom.java @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package hivemall.math.random; + +import javax.annotation.Nonnull; + +import org.apache.commons.math3.random.MersenneTwister; +import org.apache.commons.math3.random.RandomGenerator; + +public final class CommonsMathRandom implements PRNG { + + @Nonnull + private final RandomGenerator rng; + + public CommonsMathRandom() { + this.rng = new MersenneTwister(); + } + + public CommonsMathRandom(long seed) { + this.rng = new MersenneTwister(seed); + } + + public CommonsMathRandom(@Nonnull RandomGenerator rng) { + this.rng = rng; + } + + @Override + public int nextInt(final int n) { + return rng.nextInt(n); + } + + @Override + public int nextInt() { + return rng.nextInt(); + } + + @Override + public long nextLong() { + return rng.nextLong(); + } + + @Override + public double nextDouble() { + return rng.nextDouble(); + } + +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/math/random/JavaRandom.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/math/random/JavaRandom.java b/core/src/main/java/hivemall/math/random/JavaRandom.java new file mode 100644 index 0000000..f0ed4c7 --- /dev/null +++ b/core/src/main/java/hivemall/math/random/JavaRandom.java @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package hivemall.math.random; + +import java.util.Random; + +import javax.annotation.Nonnull; + +public final class JavaRandom implements PRNG { + + private final Random rand; + + public JavaRandom() { + this.rand = new Random(); + } + + public JavaRandom(long seed) { + this.rand = new Random(seed); + } + + public JavaRandom(@Nonnull Random rand) { + this.rand = rand; + } + + @Override + public int nextInt(int n) { + return rand.nextInt(n); + } + + @Override + public int nextInt() { + return rand.nextInt(); + } + + @Override + public long nextLong() { + return rand.nextLong(); + } + + @Override + public double nextDouble() { + return rand.nextDouble(); + } + +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/math/random/PRNG.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/math/random/PRNG.java b/core/src/main/java/hivemall/math/random/PRNG.java new file mode 100644 index 0000000..d42dcfb --- /dev/null +++ b/core/src/main/java/hivemall/math/random/PRNG.java @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package hivemall.math.random; + +import javax.annotation.Nonnegative; + +/** + * @link https://en.wikipedia.org/wiki/Pseudorandom_number_generator + */ +public interface PRNG { + + /** + * Returns a random integer in [0, n). + */ + public int nextInt(@Nonnegative int n); + + public int nextInt(); + + public long nextLong(); + + public double nextDouble(); + +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/math/random/RandomNumberGeneratorFactory.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/math/random/RandomNumberGeneratorFactory.java b/core/src/main/java/hivemall/math/random/RandomNumberGeneratorFactory.java new file mode 100644 index 0000000..8843f7e --- /dev/null +++ b/core/src/main/java/hivemall/math/random/RandomNumberGeneratorFactory.java @@ -0,0 +1,103 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package hivemall.math.random; + +import hivemall.utils.lang.Primitives; + +import java.security.SecureRandom; + +import javax.annotation.Nonnull; + +public final class RandomNumberGeneratorFactory { + + private RandomNumberGeneratorFactory() {} + + @Nonnull + public static PRNG createPRNG() { + return createPRNG(PRNGType.smile); + } + + @Nonnull + public static PRNG createPRNG(long seed) { + return createPRNG(PRNGType.smile, seed); + } + + @Nonnull + public static PRNG createPRNG(@Nonnull PRNGType type) { + final PRNG rng; + switch (type) { + case java: + rng = new JavaRandom(); + break; + case secure: + rng = new JavaRandom(new SecureRandom()); + break; + case smile: + rng = new SmileRandom(); + break; + case smileMT: + rng = new SmileRandom(new smile.math.random.MersenneTwister()); + break; + case smileMT64: + rng = new SmileRandom(new smile.math.random.MersenneTwister64()); + break; + case commonsMath3MT: + rng = new CommonsMathRandom(new org.apache.commons.math3.random.MersenneTwister()); + break; + default: + throw new IllegalStateException("Unexpected type: " + type); + } + return rng; + } + + @Nonnull + public static PRNG createPRNG(@Nonnull PRNGType type, long seed) { + final PRNG rng; + switch (type) { + case java: + rng = new JavaRandom(seed); + break; + case secure: + rng = new JavaRandom(new SecureRandom(Primitives.toBytes(seed))); + break; + case smile: + rng = new SmileRandom(seed); + break; + case smileMT: + rng = new SmileRandom(new smile.math.random.MersenneTwister( + Primitives.hashCode(seed))); + break; + case smileMT64: + rng = new SmileRandom(new smile.math.random.MersenneTwister64(seed)); + break; + case commonsMath3MT: + rng = new CommonsMathRandom(new org.apache.commons.math3.random.MersenneTwister( + seed)); + break; + default: + throw new IllegalStateException("Unexpected type: " + type); + } + return rng; + } + + public enum PRNGType { + java, secure, smile, smileMT, smileMT64, commonsMath3MT; + } + +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/math/random/SmileRandom.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/math/random/SmileRandom.java b/core/src/main/java/hivemall/math/random/SmileRandom.java new file mode 100644 index 0000000..1edc56c --- /dev/null +++ b/core/src/main/java/hivemall/math/random/SmileRandom.java @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package hivemall.math.random; + +import javax.annotation.Nonnull; + +import smile.math.random.RandomNumberGenerator; +import smile.math.random.UniversalGenerator; + +public final class SmileRandom implements PRNG { + + @Nonnull + private RandomNumberGenerator rng; + + public SmileRandom() { + this.rng = new UniversalGenerator(); + } + + public SmileRandom(long seed) { + this.rng = new UniversalGenerator(seed); + } + + public SmileRandom(@Nonnull RandomNumberGenerator rng) { + this.rng = rng; + } + + @Override + public int nextInt(int n) { + return rng.nextInt(n); + } + + @Override + public int nextInt() { + return rng.nextInt(); + } + + @Override + public long nextLong() { + return rng.nextLong(); + } + + @Override + public double nextDouble() { + return rng.nextDouble(); + } + +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/math/vector/AbstractVector.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/math/vector/AbstractVector.java b/core/src/main/java/hivemall/math/vector/AbstractVector.java new file mode 100644 index 0000000..88bed7b --- /dev/null +++ b/core/src/main/java/hivemall/math/vector/AbstractVector.java @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package hivemall.math.vector; + +import javax.annotation.Nonnegative; + +public abstract class AbstractVector implements Vector { + + public AbstractVector() {} + + @Override + public double get(@Nonnegative final int index) { + return get(index, 0.d); + } + + protected static final void checkIndex(final int index) { + if (index < 0) { + throw new IndexOutOfBoundsException("Invalid index " + index); + } + } + + protected static final void checkIndex(final int index, final int size) { + if (index < 0 || index >= size) { + throw new IndexOutOfBoundsException("Index " + index + " out of bounds " + size); + } + } + +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/math/vector/DenseVector.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/math/vector/DenseVector.java b/core/src/main/java/hivemall/math/vector/DenseVector.java new file mode 100644 index 0000000..bd39af1 --- /dev/null +++ b/core/src/main/java/hivemall/math/vector/DenseVector.java @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package hivemall.math.vector; + +import java.util.Arrays; + +import javax.annotation.Nonnegative; +import javax.annotation.Nonnull; + +public final class DenseVector extends AbstractVector { + + @Nonnull + private final double[] values; + private final int size; + + public DenseVector(@Nonnegative int size) { + super(); + this.values = new double[size]; + this.size = size; + } + + public DenseVector(@Nonnull double[] values) { + super(); + this.values = values; + this.size = values.length; + } + + @Override + public double get(@Nonnegative final int index, final double defaultValue) { + checkIndex(index); + if (index >= size) { + return defaultValue; + } + + return values[index]; + } + + @Override + public void set(@Nonnegative final int index, final double value) { + checkIndex(index, size); + + values[index] = value; + } + + @Override + public void incr(@Nonnegative final int index, final double delta) { + checkIndex(index, size); + + values[index] += delta; + } + + @Override + public void each(@Nonnull final VectorProcedure procedure) { + for (int i = 0; i < values.length; i++) { + procedure.apply(i, values[i]); + } + } + + @Override + public int size() { + return size; + } + + @Override + public void clear() { + Arrays.fill(values, 0.d); + } + + @Override + public double[] toArray() { + return values; + } + +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/math/vector/SparseVector.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/math/vector/SparseVector.java b/core/src/main/java/hivemall/math/vector/SparseVector.java new file mode 100644 index 0000000..072b544 --- /dev/null +++ b/core/src/main/java/hivemall/math/vector/SparseVector.java @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package hivemall.math.vector; + +import hivemall.utils.collections.arrays.SparseDoubleArray; + +import javax.annotation.Nonnegative; +import javax.annotation.Nonnull; + +public final class SparseVector extends AbstractVector { + + @Nonnull + private final SparseDoubleArray values; + + public SparseVector() { + super(); + this.values = new SparseDoubleArray(); + } + + public SparseVector(@Nonnull SparseDoubleArray values) { + super(); + this.values = values; + } + + @Override + public double get(@Nonnegative final int index, final double defaultValue) { + return values.get(index, defaultValue); + } + + @Override + public void set(@Nonnegative final int index, final double value) { + values.put(index, value); + } + + @Override + public void incr(@Nonnegative final int index, final double delta) { + values.increment(index, delta); + } + + @Override + public void each(@Nonnull final VectorProcedure procedure) { + values.each(procedure); + } + + @Override + public int size() { + return values.size(); + } + + @Override + public void clear() { + values.clear(); + } + + @Override + public double[] toArray() { + return values.toArray(); + } + +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/math/vector/Vector.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/math/vector/Vector.java b/core/src/main/java/hivemall/math/vector/Vector.java new file mode 100644 index 0000000..2e5107d --- /dev/null +++ b/core/src/main/java/hivemall/math/vector/Vector.java @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package hivemall.math.vector; + +import javax.annotation.Nonnegative; +import javax.annotation.Nonnull; + +public interface Vector { + + public double get(@Nonnegative int index); + + public double get(@Nonnegative int index, double defaultValue); + + /** + * @throws UnsupportedOperationException + */ + public void set(@Nonnegative int index, double value); + + public void incr(@Nonnegative int index, double delta); + + public void each(@Nonnull VectorProcedure procedure); + + public int size(); + + public void clear(); + + @Nonnull + public double[] toArray(); + +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/math/vector/VectorProcedure.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/math/vector/VectorProcedure.java b/core/src/main/java/hivemall/math/vector/VectorProcedure.java new file mode 100644 index 0000000..266c531 --- /dev/null +++ b/core/src/main/java/hivemall/math/vector/VectorProcedure.java @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package hivemall.math.vector; + +import javax.annotation.Nonnegative; + +public abstract class VectorProcedure { + + public VectorProcedure() {} + + public void apply(@Nonnegative int i, double value) {} + + public void apply(@Nonnegative int i, int value) {} + + public void apply(@Nonnegative int i) {} + +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/matrix/CSRMatrixBuilder.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/matrix/CSRMatrixBuilder.java b/core/src/main/java/hivemall/matrix/CSRMatrixBuilder.java deleted file mode 100644 index d2deda1..0000000 --- a/core/src/main/java/hivemall/matrix/CSRMatrixBuilder.java +++ /dev/null @@ -1,83 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -package hivemall.matrix; - -import hivemall.utils.collections.DoubleArrayList; -import hivemall.utils.collections.IntArrayList; - -import javax.annotation.Nonnegative; -import javax.annotation.Nonnull; - -/** - * Compressed Sparse Row Matrix. - * - * @link http://netlib.org/linalg/html_templates/node91.html#SECTION00931100000000000000 - * @link http://www.cs.colostate.edu/~mcrob/toolbox/c++/sparseMatrix/sparse_matrix_compression.html - */ -public final class CSRMatrixBuilder extends MatrixBuilder { - - @Nonnull - private final IntArrayList rowPointers; - @Nonnull - private final IntArrayList columnIndices; - @Nonnull - private final DoubleArrayList values; - - private int maxNumColumns; - - public CSRMatrixBuilder(int initSize) { - super(); - this.rowPointers = new IntArrayList(initSize + 1); - rowPointers.add(0); - this.columnIndices = new IntArrayList(initSize); - this.values = new DoubleArrayList(initSize); - this.maxNumColumns = 0; - } - - @Override - public CSRMatrixBuilder nextRow() { - int ptr = values.size(); - rowPointers.add(ptr); - return this; - } - - @Override - public CSRMatrixBuilder nextColumn(@Nonnegative int col, double value) { - if (value == 0.d) { - return this; - } - - columnIndices.add(col); - values.add(value); - this.maxNumColumns = Math.max(col + 1, maxNumColumns); - return this; - } - - @Override - public Matrix buildMatrix(boolean readOnly) { - if (!readOnly) { - throw new UnsupportedOperationException("Only readOnly matrix is supported"); - } - - ReadOnlyCSRMatrix matrix = new ReadOnlyCSRMatrix(rowPointers.toArray(true), - columnIndices.toArray(true), values.toArray(true), maxNumColumns); - return matrix; - } - -} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/matrix/DenseMatrixBuilder.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/matrix/DenseMatrixBuilder.java b/core/src/main/java/hivemall/matrix/DenseMatrixBuilder.java deleted file mode 100644 index f70616e..0000000 --- a/core/src/main/java/hivemall/matrix/DenseMatrixBuilder.java +++ /dev/null @@ -1,79 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -package hivemall.matrix; - -import hivemall.utils.collections.SparseDoubleArray; - -import java.util.ArrayList; -import java.util.List; - -import javax.annotation.Nonnegative; -import javax.annotation.Nonnull; - -public final class DenseMatrixBuilder extends MatrixBuilder { - - @Nonnull - private final List<double[]> rows; - private int maxNumColumns; - - @Nonnull - private final SparseDoubleArray rowProbe; - - public DenseMatrixBuilder(int initSize) { - super(); - this.rows = new ArrayList<double[]>(initSize); - this.maxNumColumns = 0; - this.rowProbe = new SparseDoubleArray(32); - } - - @Override - public MatrixBuilder nextColumn(@Nonnegative final int col, final double value) { - if (value == 0.d) { - return this; - } - rowProbe.put(col, value); - return this; - } - - @Override - public MatrixBuilder nextRow() { - double[] row = rowProbe.toArray(); - rowProbe.clear(); - nextRow(row); - return this; - } - - @Override - public void nextRow(@Nonnull double[] row) { - rows.add(row); - this.maxNumColumns = Math.max(row.length, maxNumColumns); - } - - @Override - public Matrix buildMatrix(boolean readOnly) { - if (!readOnly) { - throw new UnsupportedOperationException("Only readOnly matrix is supported"); - } - - int numRows = rows.size(); - double[][] data = rows.toArray(new double[numRows][]); - return new ReadOnlyDenseMatrix2d(data, maxNumColumns); - } - -} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/matrix/Matrix.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/matrix/Matrix.java b/core/src/main/java/hivemall/matrix/Matrix.java deleted file mode 100644 index 8bbb6c5..0000000 --- a/core/src/main/java/hivemall/matrix/Matrix.java +++ /dev/null @@ -1,92 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -package hivemall.matrix; - -import javax.annotation.Nonnegative; - -public abstract class Matrix { - - private double defaultValue; - - public Matrix() { - this.defaultValue = 0.d; - } - - public abstract boolean readOnly(); - - public void setDefaultValue(double value) { - this.defaultValue = value; - } - - @Nonnegative - public abstract int numRows(); - - @Nonnegative - public abstract int numColumns(); - - @Nonnegative - public abstract int numColumns(@Nonnegative int row); - - /** - * @throws IndexOutOfBoundsException - */ - public final double get(@Nonnegative final int row, @Nonnegative final int col) { - return get(row, col, defaultValue); - } - - /** - * @throws IndexOutOfBoundsException - */ - public abstract double get(@Nonnegative int row, @Nonnegative int col, double defaultValue); - - /** - * @throws IndexOutOfBoundsException - * @throws UnsupportedOperationException - */ - public abstract void set(@Nonnegative int row, @Nonnegative int col, double value); - - /** - * @throws IndexOutOfBoundsException - * @throws UnsupportedOperationException - */ - public abstract double getAndSet(@Nonnegative int row, @Nonnegative final int col, double value); - - protected static final void checkRowIndex(final int row, final int numRows) { - if (row < 0 || row >= numRows) { - throw new IndexOutOfBoundsException("Row index " + row + " out of bounds " + numRows); - } - } - - protected static final void checkColIndex(final int col, final int numColumns) { - if (col < 0 || col >= numColumns) { - throw new IndexOutOfBoundsException("Col index " + col + " out of bounds " + numColumns); - } - } - - protected static final void checkIndex(final int row, final int col, final int numRows, - final int numColumns) { - if (row < 0 || row >= numRows) { - throw new IndexOutOfBoundsException("Row index " + row + " out of bounds " + numRows); - } - if (col < 0 || col >= numColumns) { - throw new IndexOutOfBoundsException("Col index " + col + " out of bounds " + numColumns); - } - } - -} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/matrix/MatrixBuilder.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/matrix/MatrixBuilder.java b/core/src/main/java/hivemall/matrix/MatrixBuilder.java deleted file mode 100644 index e4d6233..0000000 --- a/core/src/main/java/hivemall/matrix/MatrixBuilder.java +++ /dev/null @@ -1,89 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -package hivemall.matrix; - -import javax.annotation.Nonnegative; -import javax.annotation.Nonnull; - -public abstract class MatrixBuilder { - - public MatrixBuilder() {} - - public void nextRow(@Nonnull final double[] row) { - for (int col = 0; col < row.length; col++) { - nextColumn(col, row[col]); - } - nextRow(); - } - - public void nextRow(@Nonnull final String[] row) { - for (String col : row) { - if (col == null) { - continue; - } - nextColumn(col); - } - nextRow(); - } - - @Nonnull - public abstract MatrixBuilder nextRow(); - - @Nonnull - public abstract MatrixBuilder nextColumn(@Nonnegative int col, double value); - - /** - * @throws IllegalArgumentException - * @throws NumberFormatException - */ - @Nonnull - public MatrixBuilder nextColumn(@Nonnull final String col) { - final int pos = col.indexOf(':'); - if (pos == 0) { - throw new IllegalArgumentException("Invalid feature value representation: " + col); - } - - final String feature; - final double value; - if (pos > 0) { - feature = col.substring(0, pos); - String s2 = col.substring(pos + 1); - value = Double.parseDouble(s2); - } else { - feature = col; - value = 1.d; - } - - if (feature.indexOf(':') != -1) { - throw new IllegalArgumentException("Invaliad feature format `<index>:<value>`: " + col); - } - - int colIndex = Integer.parseInt(feature); - if (colIndex < 0) { - throw new IllegalArgumentException("Col index MUST be greather than or equals to 0: " - + colIndex); - } - - return nextColumn(colIndex, value); - } - - @Nonnull - public abstract Matrix buildMatrix(boolean readOnly); - -}