Github user takuti commented on a diff in the pull request:
https://github.com/apache/incubator-hivemall/pull/121#discussion_r144199528
--- Diff: core/src/main/java/hivemall/math/matrix/MatrixUtils.java ---
@@ -70,4 +77,259 @@ public void apply(int i, int value) {
return which.getValue();
}
+ /**
+ * @param data non-zero entries
+ */
+ @Nonnull
+ public static CSRMatrix coo2csr(@Nonnull final int[] rows, @Nonnull
final int[] cols,
+ @Nonnull final double[] data, @Nonnegative final int numRows,
+ @Nonnegative final int numCols, final boolean sortColumns) {
+ final int nnz = data.length;
+ Preconditions.checkArgument(rows.length == nnz);
+ Preconditions.checkArgument(cols.length == nnz);
+
+ final int[] rowPointers = new int[numRows + 1];
+ final int[] colIndicies = new int[nnz];
+ final double[] values = new double[nnz];
+
+ coo2csr(rows, cols, data, rowPointers, colIndicies, values,
numRows, numCols, nnz);
+
+ if (sortColumns) {
+ sortIndicies(rowPointers, colIndicies, values);
+ }
+ return new CSRMatrix(rowPointers, colIndicies, values, numCols);
+ }
+
+ /**
+ * @param data non-zero entries
+ */
+ @Nonnull
+ public static CSRFloatMatrix coo2csr(@Nonnull final int[] rows,
@Nonnull final int[] cols,
+ @Nonnull final float[] data, @Nonnegative final int numRows,
+ @Nonnegative final int numCols, final boolean sortColumns) {
+ final int nnz = data.length;
+ Preconditions.checkArgument(rows.length == nnz);
+ Preconditions.checkArgument(cols.length == nnz);
+
+ final int[] rowPointers = new int[numRows + 1];
+ final int[] colIndicies = new int[nnz];
+ final float[] values = new float[nnz];
+
+ coo2csr(rows, cols, data, rowPointers, colIndicies, values,
numRows, numCols, nnz);
+
+ if (sortColumns) {
+ sortIndicies(rowPointers, colIndicies, values);
+ }
+ return new CSRFloatMatrix(rowPointers, colIndicies, values,
numCols);
+ }
+
+ @Nonnull
+ public static CSCMatrix coo2csc(@Nonnull final int[] rows, @Nonnull
final int[] cols,
+ @Nonnull final double[] data, @Nonnegative final int numRows,
+ @Nonnegative final int numCols, final boolean sortRows) {
+ final int nnz = data.length;
+ Preconditions.checkArgument(rows.length == nnz);
+ Preconditions.checkArgument(cols.length == nnz);
+
+ final int[] columnPointers = new int[numCols + 1];
+ final int[] rowIndicies = new int[nnz];
+ final double[] values = new double[nnz];
+
+ coo2csr(cols, rows, data, columnPointers, rowIndicies, values,
numCols, numRows, nnz);
+
+ if (sortRows) {
+ sortIndicies(columnPointers, rowIndicies, values);
+ }
+ return new CSCMatrix(columnPointers, rowIndicies, values, numRows,
numCols);
+ }
+
+ @Nonnull
+ public static CSCFloatMatrix coo2csc(@Nonnull final int[] rows,
@Nonnull final int[] cols,
+ @Nonnull final float[] data, @Nonnegative final int numRows,
+ @Nonnegative final int numCols, final boolean sortRows) {
+ final int nnz = data.length;
+ Preconditions.checkArgument(rows.length == nnz);
+ Preconditions.checkArgument(cols.length == nnz);
+
+ final int[] columnPointers = new int[numCols + 1];
+ final int[] rowIndicies = new int[nnz];
+ final float[] values = new float[nnz];
+
+ coo2csr(cols, rows, data, columnPointers, rowIndicies, values,
numCols, numRows, nnz);
+
+ if (sortRows) {
+ sortIndicies(columnPointers, rowIndicies, values);
+ }
+
+ return new CSCFloatMatrix(columnPointers, rowIndicies, values,
numRows, numCols);
+ }
+
+ private static void coo2csr(@Nonnull final int[] rows, @Nonnull final
int[] cols,
+ @Nonnull final double[] data, @Nonnull final int[] rowPointers,
+ @Nonnull final int[] colIndicies, @Nonnull final double[]
values,
+ @Nonnegative final int numRows, @Nonnegative final int
numCols, final int nnz) {
+ // compute nnz per for each row to get rowPointers
+ for (int n = 0; n < nnz; n++) {
+ rowPointers[rows[n]]++;
+ }
+ for (int i = 0, sum = 0; i < numRows; i++) {
+ int curr = rowPointers[i];
+ rowPointers[i] = sum;
+ sum += curr;
+ }
+ rowPointers[numRows] = nnz;
+
+ // copy cols, data to colIndicies, csrValues
+ for (int n = 0; n < nnz; n++) {
+ int row = rows[n];
+ int dst = rowPointers[row];
+
+ colIndicies[dst] = cols[n];
+ values[dst] = data[n];
+
+ rowPointers[row]++;
+ }
+
+ for (int i = 0, last = 0; i <= numRows; i++) {
+ int tmp = rowPointers[i];
+ rowPointers[i] = last;
+ last = tmp;
+ }
+ }
+
+ private static void coo2csr(@Nonnull final int[] rows, @Nonnull final
int[] cols,
+ @Nonnull final float[] data, @Nonnull final int[] rowPointers,
+ @Nonnull final int[] colIndicies, @Nonnull final float[]
values,
+ @Nonnegative final int numRows, @Nonnegative final int
numCols, final int nnz) {
+ // compute nnz per for each row to get rowPointers
+ for (int n = 0; n < nnz; n++) {
+ rowPointers[rows[n]]++;
+ }
+ for (int i = 0, sum = 0; i < numRows; i++) {
+ int curr = rowPointers[i];
+ rowPointers[i] = sum;
+ sum += curr;
+ }
+ rowPointers[numRows] = nnz;
+
+ // copy cols, data to colIndicies, csrValues
+ for (int n = 0; n < nnz; n++) {
+ int row = rows[n];
+ int dst = rowPointers[row];
+
+ colIndicies[dst] = cols[n];
+ values[dst] = data[n];
+
+ rowPointers[row]++;
+ }
+
+ for (int i = 0, last = 0; i <= numRows; i++) {
+ int tmp = rowPointers[i];
+ rowPointers[i] = last;
+ last = tmp;
+ }
+ }
+
+ private static void sortIndicies(@Nonnull final int[] rowPointers,
+ @Nonnull final int[] colIndicies, @Nonnull final double[]
values) {
+ final int numRows = rowPointers.length - 1;
+ if (numRows <= 1) {
+ return;
+ }
+
+ for (int i = 0; i < numRows; i++) {
+ final int rowStart = rowPointers[i];
+ final int rowEnd = rowPointers[i + 1];
+
+ final int numCols = rowEnd - rowStart;
+ if (numCols == 0) {
+ continue;
+ } else if (numCols < 0) {
+ throw new IllegalArgumentException(
+ "numCols SHOULD be greater than zero. numCols = rowEnd
- rowStart = " + rowEnd
+ + " - " + rowStart + " = " + numCols + " at
i=" + i);
+ }
+
+ final IntDoublePair[] pairs = new IntDoublePair[numCols];
--- End diff --
Why don't you use the existing `hivemall.utils.struct.Pair` instead of the
newly introduced `IntDoublePair` (and `IntFloatPair`) classes?
We frequently use pair/tuple-like data structures, so I feel that reusing
the same `struct.Pair` interface wherever we can is a better idea.
```java
final List<Pair<Integer, Double>> pairs = new ArrayList<Pair<Integer,
Double>>();
for (int jj = rowStart; jj < rowEnd; jj++) {
pairs.add(Pair.of(colIndicies[jj], values[jj]));
}
Collections.sort(pairs, new Comparator<Pair<Integer, Double>>() {
@Override
public int compare(Pair<Integer, Double> x, Pair<Integer, Double> y) {
return Integer.compare(x.getKey(), y.getKey());
}
});
for (int jj = rowStart, n = 0; jj < rowEnd; jj++, n++) {
Pair<Integer, Double> tmp = pairs.get(n);
colIndicies[jj] = tmp.getKey();
values[jj] = tmp.getValue();
}
```
---