Repository: incubator-hivemall Updated Branches: refs/heads/master 50867c178 -> 054a697eb
Close #48: [HIVEMALL-77] Support CSRMatrix and DenseMatrix Project: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/commit/054a697e Tree: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/tree/054a697e Diff: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/diff/054a697e Branch: refs/heads/master Commit: 054a697eb27f6fa4c004228c20a85679fb2dca30 Parents: 50867c1 Author: Makoto Yui <[email protected]> Authored: Thu Feb 23 19:30:58 2017 +0900 Committer: myui <[email protected]> Committed: Thu Feb 23 19:30:58 2017 +0900 ---------------------------------------------------------------------- .../java/hivemall/matrix/CSRMatrixBuilder.java | 83 +++++ .../hivemall/matrix/DenseMatrixBuilder.java | 79 +++++ core/src/main/java/hivemall/matrix/Matrix.java | 92 ++++++ .../java/hivemall/matrix/MatrixBuilder.java | 89 +++++ .../java/hivemall/matrix/ReadOnlyCSRMatrix.java | 135 ++++++++ .../hivemall/matrix/ReadOnlyDenseMatrix2d.java | 102 ++++++ .../main/java/hivemall/model/FeatureValue.java | 9 +- .../hivemall/utils/collections/DoubleArray.java | 43 +++ .../utils/collections/DoubleArrayList.java | 20 +- .../utils/collections/FixedIntArray.java | 19 ++ .../hivemall/utils/collections/IntArray.java | 16 +- .../utils/collections/IntArrayList.java | 20 +- .../utils/collections/SparseDoubleArray.java | 213 ++++++++++++ .../utils/collections/SparseIntArray.java | 35 ++ .../java/hivemall/utils/lang/ArrayUtils.java | 95 ++++-- .../java/hivemall/matrix/MatrixBuilderTest.java | 329 +++++++++++++++++++ .../utils/collections/DoubleArrayTest.java | 60 ++++ .../utils/collections/IntArrayTest.java | 76 +++++ docs/gitbook/tips/rt_prediction.md | 4 +- 19 files changed, 1476 insertions(+), 43 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/054a697e/core/src/main/java/hivemall/matrix/CSRMatrixBuilder.java 
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/matrix/CSRMatrixBuilder.java b/core/src/main/java/hivemall/matrix/CSRMatrixBuilder.java
new file mode 100644
index 0000000..d2deda1
--- /dev/null
+++ b/core/src/main/java/hivemall/matrix/CSRMatrixBuilder.java
@@ -0,0 +1,130 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.matrix;
+
+import hivemall.utils.collections.DoubleArrayList;
+import hivemall.utils.collections.IntArrayList;
+
+import java.util.Arrays;
+
+import javax.annotation.Nonnegative;
+import javax.annotation.Nonnull;
+
+/**
+ * Compressed Sparse Row Matrix.
+ *
+ * <p>
+ * Columns may be added to a row in any order; the column indices of each row are sorted in
+ * ascending order by {@link #buildMatrix(boolean)}.
+ *
+ * @link http://netlib.org/linalg/html_templates/node91.html#SECTION00931100000000000000
+ * @link http://www.cs.colostate.edu/~mcrob/toolbox/c++/sparseMatrix/sparse_matrix_compression.html
+ */
+public final class CSRMatrixBuilder extends MatrixBuilder {
+
+    @Nonnull
+    private final IntArrayList rowPointers;
+    @Nonnull
+    private final IntArrayList columnIndices;
+    @Nonnull
+    private final DoubleArrayList values;
+
+    private int maxNumColumns;
+
+    public CSRMatrixBuilder(int initSize) {
+        super();
+        this.rowPointers = new IntArrayList(initSize + 1);
+        rowPointers.add(0);
+        this.columnIndices = new IntArrayList(initSize);
+        this.values = new DoubleArrayList(initSize);
+        this.maxNumColumns = 0;
+    }
+
+    @Override
+    public CSRMatrixBuilder nextRow() {
+        int ptr = values.size();
+        rowPointers.add(ptr);
+        return this;
+    }
+
+    @Override
+    public CSRMatrixBuilder nextColumn(@Nonnegative int col, double value) {
+        if (value == 0.d) {
+            return this;
+        }
+
+        columnIndices.add(col);
+        values.add(value);
+        this.maxNumColumns = Math.max(col + 1, maxNumColumns);
+        return this;
+    }
+
+    @Override
+    public Matrix buildMatrix(boolean readOnly) {
+        if (!readOnly) {
+            throw new UnsupportedOperationException("Only readOnly matrix is supported");
+        }
+
+        final int[] rowPtrs = rowPointers.toArray(true);
+        final int[] colIndices = columnIndices.toArray(true);
+        final double[] vals = values.toArray(true);
+        // ReadOnlyCSRMatrix looks entries up with Arrays.binarySearch, which requires the column
+        // indices of each row to be sorted, but nextColumn() accepts columns in any order.
+        sortColumnIndices(rowPtrs, colIndices, vals);
+        return new ReadOnlyCSRMatrix(rowPtrs, colIndices, vals, maxNumColumns);
+    }
+
+    /**
+     * Sorts the column indices of each row in ascending order, permuting the values accordingly.
+     */
+    private static void sortColumnIndices(@Nonnull final int[] rowPtrs,
+            @Nonnull final int[] colIndices, @Nonnull final double[] vals) {
+        final int numRows = rowPtrs.length - 1;
+        for (int row = 0; row < numRows; row++) {
+            final int start = rowPtrs[row];
+            final int end = rowPtrs[row + 1];
+            if (isSorted(colIndices, start, end)) {
+                continue; // common case: columns were added in ascending order
+            }
+            final int len = end - start;
+            // pack (column index, position in row) into a long so a primitive sort can be used
+            final long[] keys = new long[len];
+            for (int i = 0; i < len; i++) {
+                keys[i] = ((long) colIndices[start + i] << 32) | i;
+            }
+            Arrays.sort(keys);
+            final double[] rowValues = Arrays.copyOfRange(vals, start, end);
+            for (int i = 0; i < len; i++) {
+                colIndices[start + i] = (int) (keys[i] >>> 32);
+                vals[start + i] = rowValues[(int) keys[i]];
+            }
+        }
+    }
+
+    /** Returns {@code true} if {@code a[from, to)} is in non-descending order. */
+    private static boolean isSorted(@Nonnull final int[] a, final int from, final int to) {
+        for (int i = from + 1; i < to; i++) {
+            if (a[i - 1] > a[i]) {
+                return false;
+            }
+        }
+        return true;
+    }
+
+}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/054a697e/core/src/main/java/hivemall/matrix/DenseMatrixBuilder.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/matrix/DenseMatrixBuilder.java b/core/src/main/java/hivemall/matrix/DenseMatrixBuilder.java
new file mode 100644
index 0000000..f70616e
--- /dev/null
+++ b/core/src/main/java/hivemall/matrix/DenseMatrixBuilder.java
@@ -0,0 +1,79 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.matrix;
+
+import hivemall.utils.collections.SparseDoubleArray;
+
+import java.util.ArrayList;
+import java.util.List;
+
+import javax.annotation.Nonnegative;
+import javax.annotation.Nonnull;
+
+public final class DenseMatrixBuilder extends MatrixBuilder {
+
+    @Nonnull
+    private final List<double[]> rows;
+    private int maxNumColumns;
+
+    @Nonnull
+    private final SparseDoubleArray rowProbe;
+
+    public DenseMatrixBuilder(int initSize) {
+        super();
+        this.rows = new ArrayList<double[]>(initSize);
+        this.maxNumColumns = 0;
+        this.rowProbe = new SparseDoubleArray(32);
+    }
+
+    @Override
+    public MatrixBuilder nextColumn(@Nonnegative final int col, final double value) {
+        if (value == 0.d) {
+            return this;
+        }
+        rowProbe.put(col, value);
+        return this;
+    }
+
+    @Override
+    public MatrixBuilder nextRow() {
+        double[] row = rowProbe.toArray();
+        rowProbe.clear();
+        nextRow(row);
+        return this;
+    }
+
+    @Override
+    public void nextRow(@Nonnull double[] row) {
+        rows.add(row);
+        this.maxNumColumns = Math.max(row.length, maxNumColumns);
+    }
+
+    @Override
+    public Matrix buildMatrix(boolean readOnly) {
+        if (!readOnly) {
+            throw new UnsupportedOperationException("Only readOnly matrix is supported");
+        }
+
+        int numRows = rows.size();
+        double[][] data = rows.toArray(new double[numRows][]);
+        return new ReadOnlyDenseMatrix2d(data, maxNumColumns);
+    }
+
+}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/054a697e/core/src/main/java/hivemall/matrix/Matrix.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/matrix/Matrix.java b/core/src/main/java/hivemall/matrix/Matrix.java
new file mode 100644
index 0000000..8bbb6c5
--- /dev/null
+++ b/core/src/main/java/hivemall/matrix/Matrix.java
@@ -0,0 +1,92 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.matrix;
+
+import javax.annotation.Nonnegative;
+
+public abstract class Matrix {
+
+    private double defaultValue;
+
+    public Matrix() {
+        this.defaultValue = 0.d;
+    }
+
+    public abstract boolean readOnly();
+
+    public void setDefaultValue(double value) {
+        this.defaultValue = value;
+    }
+
+    @Nonnegative
+    public abstract int numRows();
+
+    @Nonnegative
+    public abstract int numColumns();
+
+    @Nonnegative
+    public abstract int numColumns(@Nonnegative int row);
+
+    /**
+     * @throws IndexOutOfBoundsException
+     */
+    public final double get(@Nonnegative final int row, @Nonnegative final int col) {
+        return get(row, col, defaultValue);
+    }
+
+    /**
+     * @throws IndexOutOfBoundsException
+     */
+    public abstract double get(@Nonnegative int row, @Nonnegative int col, double defaultValue);
+
+    /**
+     * @throws IndexOutOfBoundsException
+     * @throws UnsupportedOperationException
+     */
+    public abstract void set(@Nonnegative int row, @Nonnegative int col, double value);
+
+    /**
+     * @throws IndexOutOfBoundsException
+     * @throws UnsupportedOperationException
+     */
+    public abstract double getAndSet(@Nonnegative int row, @Nonnegative final int col, double value);
+
+    protected static final void checkRowIndex(final int row, final int numRows) {
+        if (row < 0 || row >= numRows) {
+            throw new IndexOutOfBoundsException("Row index " + row + " out of bounds " + numRows);
+        }
+    }
+
+    protected static final void checkColIndex(final int col, final int numColumns) {
+        if (col < 0 || col >= numColumns) {
+            throw new IndexOutOfBoundsException("Col index " + col + " out of bounds " + numColumns);
+        }
+    }
+
+    protected static final void checkIndex(final int row, final int col, final int numRows,
+            final int numColumns) {
+        if (row < 0 || row >= numRows) {
+            throw new IndexOutOfBoundsException("Row index " + row + " out of bounds " + numRows);
+        }
+        if (col < 0 || col >= numColumns) {
+            throw new IndexOutOfBoundsException("Col index " + col + " out of bounds " + numColumns);
+        }
+    }
+
+}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/054a697e/core/src/main/java/hivemall/matrix/MatrixBuilder.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/matrix/MatrixBuilder.java b/core/src/main/java/hivemall/matrix/MatrixBuilder.java
new file mode 100644
index 0000000..e4d6233
--- /dev/null
+++ b/core/src/main/java/hivemall/matrix/MatrixBuilder.java
@@ -0,0 +1,89 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.matrix;
+
+import javax.annotation.Nonnegative;
+import javax.annotation.Nonnull;
+
+public abstract class MatrixBuilder {
+
+    public MatrixBuilder() {}
+
+    public void nextRow(@Nonnull final double[] row) {
+        for (int col = 0; col < row.length; col++) {
+            nextColumn(col, row[col]);
+        }
+        nextRow();
+    }
+
+    public void nextRow(@Nonnull final String[] row) {
+        for (String col : row) {
+            if (col == null) {
+                continue;
+            }
+            nextColumn(col);
+        }
+        nextRow();
+    }
+
+    @Nonnull
+    public abstract MatrixBuilder nextRow();
+
+    @Nonnull
+    public abstract MatrixBuilder nextColumn(@Nonnegative int col, double value);
+
+    /**
+     * @throws IllegalArgumentException
+     * @throws NumberFormatException
+     */
+    @Nonnull
+    public MatrixBuilder nextColumn(@Nonnull final String col) {
+        final int pos = col.indexOf(':');
+        if (pos == 0) {
+            throw new IllegalArgumentException("Invalid feature value representation: " + col);
+        }
+
+        final String feature;
+        final double value;
+        if (pos > 0) {
+            feature = col.substring(0, pos);
+            String s2 = col.substring(pos + 1);
+            value = Double.parseDouble(s2);
+        } else {
+            feature = col;
+            value = 1.d;
+        }
+
+        if (feature.indexOf(':') != -1) {
+            throw new IllegalArgumentException("Invalid feature format `<index>:<value>`: " + col);
+        }
+
+        int colIndex = Integer.parseInt(feature);
+        if (colIndex < 0) {
+            throw new IllegalArgumentException("Col index MUST be greater than or equal to 0: "
+                    + colIndex);
+        }
+
+        return nextColumn(colIndex, value);
+    }
+
+    @Nonnull
+    public abstract Matrix buildMatrix(boolean readOnly);
+
+}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/054a697e/core/src/main/java/hivemall/matrix/ReadOnlyCSRMatrix.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/matrix/ReadOnlyCSRMatrix.java b/core/src/main/java/hivemall/matrix/ReadOnlyCSRMatrix.java
new file mode 100644
index 0000000..1c7a9a1
--- /dev/null
+++ b/core/src/main/java/hivemall/matrix/ReadOnlyCSRMatrix.java
@@ -0,0 +1,135 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.matrix;
+
+import hivemall.utils.lang.Preconditions;
+
+import java.util.Arrays;
+
+import javax.annotation.Nonnegative;
+import javax.annotation.Nonnull;
+
+/**
+ * Read-only CSR Matrix.
+ *
+ * @see http://netlib.org/linalg/html_templates/node91.html#SECTION00931100000000000000
+ */
+public final class ReadOnlyCSRMatrix extends Matrix {
+
+    @Nonnull
+    private final int[] rowPointers;
+    @Nonnull
+    private final int[] columnIndices;
+    @Nonnull
+    private final double[] values;
+
+    @Nonnegative
+    private final int numRows;
+    @Nonnegative
+    private final int numColumns;
+
+    public ReadOnlyCSRMatrix(@Nonnull int[] rowPointers, @Nonnull int[] columnIndices,
+            @Nonnull double[] values, @Nonnegative int numColumns) {
+        super();
+        Preconditions.checkArgument(rowPointers.length >= 1,
+            "#rowPointers must be greater than 0: " + rowPointers.length);
+        Preconditions.checkArgument(columnIndices.length == values.length, "#columnIndices ("
+                + columnIndices.length + ") must be equal to #values (" + values.length + ")");
+        this.rowPointers = rowPointers;
+        this.columnIndices = columnIndices;
+        this.values = values;
+        this.numRows = rowPointers.length - 1;
+        this.numColumns = numColumns;
+    }
+
+    @Override
+    public boolean readOnly() {
+        return true;
+    }
+
+    @Override
+    public int numRows() {
+        return numRows;
+    }
+
+    @Override
+    public int numColumns() {
+        return numColumns;
+    }
+
+    @Override
+    public int numColumns(@Nonnegative final int row) {
+        checkRowIndex(row, numRows);
+
+        int columns = rowPointers[row + 1] - rowPointers[row];
+        return columns;
+    }
+
+    @Override
+    public double get(@Nonnegative final int row, @Nonnegative final int col,
+            final double defaultValue) {
+        checkIndex(row, col, numRows, numColumns);
+
+        final int index = getIndex(row, col);
+        if (index < 0) {
+            return defaultValue;
+        }
+        return values[index];
+    }
+
+    @Override
+    public double getAndSet(@Nonnegative final int row, @Nonnegative final int col,
+            final double value) {
+        checkIndex(row, col, numRows, numColumns);
+
+        final int index = getIndex(row, col);
+        if (index < 0) {
+            throw new UnsupportedOperationException("Cannot update value in row " + row + ", col "
+                    + col);
+        }
+
+        double old = values[index];
+        values[index] = value;
+        return old;
+    }
+
+    @Override
+    public void set(@Nonnegative final int row, @Nonnegative final int col, final double value) {
+        checkIndex(row, col, numRows, numColumns);
+
+        final int index = getIndex(row, col);
+        if (index < 0) {
+            throw new UnsupportedOperationException("Cannot update value in row " + row + ", col "
+                    + col);
+        }
+        values[index] = value;
+    }
+
+    private int getIndex(@Nonnegative final int row, @Nonnegative final int col) {
+        int leftIn = rowPointers[row];
+        int rightEx = rowPointers[row + 1];
+        final int index = Arrays.binarySearch(columnIndices, leftIn, rightEx, col);
+        if (index >= 0 && index >= values.length) {
+            throw new IndexOutOfBoundsException("Value index " + index + " out of range "
+                    + values.length);
+        }
+        return index;
+    }
+
+}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/054a697e/core/src/main/java/hivemall/matrix/ReadOnlyDenseMatrix2d.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/matrix/ReadOnlyDenseMatrix2d.java b/core/src/main/java/hivemall/matrix/ReadOnlyDenseMatrix2d.java
new file mode 100644
index 0000000..040fef8
--- /dev/null
+++ b/core/src/main/java/hivemall/matrix/ReadOnlyDenseMatrix2d.java
@@ -0,0 +1,102 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.matrix;
+
+import javax.annotation.Nonnegative;
+import javax.annotation.Nonnull;
+
+public final class ReadOnlyDenseMatrix2d extends Matrix {
+
+    @Nonnull
+    private final double[][] data;
+
+    @Nonnegative
+    private final int numRows;
+    @Nonnegative
+    private final int numColumns;
+
+    public ReadOnlyDenseMatrix2d(@Nonnull double[][] data, @Nonnegative int numColumns) {
+        this.data = data;
+        this.numRows = data.length;
+        this.numColumns = numColumns;
+    }
+
+    @Override
+    public boolean readOnly() {
+        return true;
+    }
+
+    @Override
+    public void setDefaultValue(double value) {
+        throw new UnsupportedOperationException("The defaultValue of a DenseMatrix is fixed to 0.d");
+    }
+
+    @Override
+    public int numRows() {
+        return numRows;
+    }
+
+    @Override
+    public int numColumns() {
+        return numColumns;
+    }
+
+    @Override
+    public int numColumns(@Nonnegative final int row) {
+        checkRowIndex(row, numRows);
+
+        return data[row].length;
+    }
+
+    @Override
+    public double get(@Nonnegative final int row, @Nonnegative final int col,
+            final double defaultValue) {
+        checkIndex(row, col, numRows, numColumns);
+
+        final double[] rowData = data[row];
+        if (col >= rowData.length) {
+            return defaultValue;
+        }
+        return rowData[col];
+    }
+
+    @Override
+    public double getAndSet(@Nonnegative final int row, @Nonnegative final int col,
+            final double value) {
+        checkIndex(row, col, numRows, numColumns);
+
+        final double[] rowData = data[row];
+        checkColIndex(col, rowData.length);
+
+        double old = rowData[col];
+        rowData[col] = value;
+        return old;
+    }
+
+    @Override
+    public void set(@Nonnegative final int row, @Nonnegative final int col, final double value) {
+        checkIndex(row, col, numRows, numColumns);
+
+        final double[] rowData = data[row];
+        checkColIndex(col, rowData.length);
+
+        rowData[col] = value;
+    }
+
+}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/054a697e/core/src/main/java/hivemall/model/FeatureValue.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/model/FeatureValue.java b/core/src/main/java/hivemall/model/FeatureValue.java
index 7ff3383..39fadaf 100644
--- a/core/src/main/java/hivemall/model/FeatureValue.java
+++ b/core/src/main/java/hivemall/model/FeatureValue.java
@@ -47,7 +47,7 @@ public final class FeatureValue {
     public <T> T getFeature() {
         return (T) feature;
     }
-    
+
     public int getFeatureAsInt() {
         Preconditions.checkNotNull(feature);
         Preconditions.checkArgument(feature instanceof Integer);
@@ -93,7 +93,6 @@ public final class FeatureValue {
     @Nullable
     public static FeatureValue parse(@Nonnull final String s, final boolean mhash)
             throws IllegalArgumentException {
-        assert (s != null);
         final int pos = s.indexOf(':');
         if (pos == 0) {
             throw new IllegalArgumentException("Invalid feature value representation: " + s);
@@ -122,7 +121,6 @@ public final class FeatureValue {
     @Nonnull
     public static FeatureValue parseFeatureAsString(@Nonnull final String s)
             throws IllegalArgumentException {
-        assert (s != null);
         final int pos = s.indexOf(':');
         if (pos == 0) {
             throw new IllegalArgumentException("Invalid feature value representation: " + s);
@@ -142,17 +140,12 @@ public final class FeatureValue {
     }
 
     public static void parseFeatureAsString(@Nonnull final Text t, @Nonnull final FeatureValue probe) {
-        assert (t != null);
-
         String s = t.toString();
         parseFeatureAsString(s, probe);
     }
 
     public static void parseFeatureAsString(@Nonnull final String s,
             @Nonnull final FeatureValue probe) throws IllegalArgumentException {
-        assert (s != null);
-        assert (probe != null);
-
         final int pos = s.indexOf(':');
         if (pos == 0) {
             throw new IllegalArgumentException("Invalid feature value representation: " + s);
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/054a697e/core/src/main/java/hivemall/utils/collections/DoubleArray.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/utils/collections/DoubleArray.java b/core/src/main/java/hivemall/utils/collections/DoubleArray.java
new file mode 100644
index 0000000..a7dfa81
--- /dev/null
+++ b/core/src/main/java/hivemall/utils/collections/DoubleArray.java
@@ -0,0 +1,43 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.utils.collections;
+
+import java.io.Serializable;
+
+import javax.annotation.Nonnull;
+
+public interface DoubleArray extends Serializable {
+
+    public double get(int key);
+
+    public double get(int key, double valueIfKeyNotFound);
+
+    public void put(int key, double value);
+
+    public int size();
+
+    public int keyAt(int index);
+
+    @Nonnull
+    public double[] toArray();
+
+    @Nonnull
+    public double[] toArray(boolean copy);
+
+}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/054a697e/core/src/main/java/hivemall/utils/collections/DoubleArrayList.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/utils/collections/DoubleArrayList.java b/core/src/main/java/hivemall/utils/collections/DoubleArrayList.java
index 1a37845..afdc251 100644
--- a/core/src/main/java/hivemall/utils/collections/DoubleArrayList.java
+++ b/core/src/main/java/hivemall/utils/collections/DoubleArrayList.java
@@ -18,11 +18,13 @@
  */
 package hivemall.utils.collections;
 
+import java.io.Closeable;
 import java.io.Serializable;
 
-public final class DoubleArrayList implements Serializable {
-    private static final long serialVersionUID = -8155789759545975413L;
+import javax.annotation.Nonnull;
 
+public final class DoubleArrayList implements Serializable, Closeable {
+    private static final long serialVersionUID = -8155789759545975413L;
     public static final int DEFAULT_CAPACITY = 12;
 
     /** array entity */
@@ -126,9 +128,18 @@ public final class DoubleArrayList implements Serializable {
         used = 0;
     }
 
+    @Nonnull
     public double[] toArray() {
+        return toArray(false);
+    }
+
+    @Nonnull
+    public double[] toArray(boolean close) {
         final double[] newArray = new double[used];
         System.arraycopy(data, 0, newArray, 0, used);
+        if (close) {
+            close();
+        }
         return newArray;
     }
 
@@ -149,4 +160,9 @@ public final class DoubleArrayList implements Serializable {
         buf.append(']');
         return buf.toString();
     }
+
+    @Override
+    public void close() {
+        this.data = null;
+    }
 }
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/054a697e/core/src/main/java/hivemall/utils/collections/FixedIntArray.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/utils/collections/FixedIntArray.java b/core/src/main/java/hivemall/utils/collections/FixedIntArray.java
index d4acdc1..927ee83 100644
--- a/core/src/main/java/hivemall/utils/collections/FixedIntArray.java
+++ b/core/src/main/java/hivemall/utils/collections/FixedIntArray.java
@@ -18,8 +18,13 @@
  */
 package hivemall.utils.collections;
 
+import java.util.Arrays;
+
 import javax.annotation.Nonnull;
 
+/**
+ * A fixed INT array that has keys greater than or equals to 0.
+ */
 public final class FixedIntArray implements IntArray {
     private static final long serialVersionUID = -1450212841013810240L;
 
@@ -65,4 +70,18 @@ public final class FixedIntArray implements IntArray {
         return index;
     }
 
+    @Override
+    public int[] toArray() {
+        return toArray(true);
+    }
+
+    @Override
+    public int[] toArray(boolean copy) {
+        if (copy) {
+            return Arrays.copyOf(array, size);
+        } else {
+            return array;
+        }
+    }
+
 }
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/054a697e/core/src/main/java/hivemall/utils/collections/IntArray.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/utils/collections/IntArray.java b/core/src/main/java/hivemall/utils/collections/IntArray.java
index a530efe..cb6b0b8 100644
--- a/core/src/main/java/hivemall/utils/collections/IntArray.java
+++ b/core/src/main/java/hivemall/utils/collections/IntArray.java
@@ -20,16 +20,24 @@ package hivemall.utils.collections;
 
 import java.io.Serializable;
 
+import javax.annotation.Nonnull;
+
 public interface IntArray extends Serializable {
 
-    public abstract int get(int key);
+    public int get(int key);
 
-    public abstract int get(int key, int valueIfKeyNotFound);
+    public int get(int key, int valueIfKeyNotFound);
 
-    public abstract void put(int key, int value);
+    public void put(int key, int value);
 
-    public abstract int size();
+    public int size();
 
     public int keyAt(int index);
 
+    @Nonnull
+    public int[] toArray();
+
+    @Nonnull
+    public int[] toArray(boolean copy);
+
 }
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/054a697e/core/src/main/java/hivemall/utils/collections/IntArrayList.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/utils/collections/IntArrayList.java b/core/src/main/java/hivemall/utils/collections/IntArrayList.java
index c0d79de..0716ca8 100644
--- a/core/src/main/java/hivemall/utils/collections/IntArrayList.java
+++ b/core/src/main/java/hivemall/utils/collections/IntArrayList.java
@@ -20,11 +20,13 @@ package hivemall.utils.collections;
 
 import hivemall.utils.lang.ArrayUtils;
 
+import java.io.Closeable;
 import java.io.Serializable;
 
-public final class IntArrayList implements Serializable {
-    private static final long serialVersionUID = -2147675120406747488L;
+import javax.annotation.Nonnull;
 
+public final class IntArrayList implements Serializable, Closeable {
+    private static final long serialVersionUID = -2147675120406747488L;
     public static final int DEFAULT_CAPACITY = 12;
 
     /** array entity */
@@ -141,9 +143,18 @@ public final class IntArrayList implements Serializable {
         used = 0;
     }
 
+    @Nonnull
     public int[] toArray() {
+        return toArray(false);
+    }
+
+    @Nonnull
+    public int[] toArray(boolean close) {
         final int[] newArray = new int[used];
         System.arraycopy(data, 0, newArray, 0, used);
+        if (close) {
+            close();
+        }
         return newArray;
     }
 
@@ -164,4 +175,9 @@ public final class IntArrayList implements Serializable {
         buf.append(']');
         return buf.toString();
     }
+
+    @Override
+    public void close() {
+        this.data = null;
+    }
 }
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/054a697e/core/src/main/java/hivemall/utils/collections/SparseDoubleArray.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/utils/collections/SparseDoubleArray.java b/core/src/main/java/hivemall/utils/collections/SparseDoubleArray.java
new file mode 100644
index 0000000..c4dbbb5
--- /dev/null
+++ b/core/src/main/java/hivemall/utils/collections/SparseDoubleArray.java
@@ -0,0 +1,218 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.utils.collections;
+
+import hivemall.utils.lang.ArrayUtils;
+import hivemall.utils.lang.Preconditions;
+
+import java.util.Arrays;
+
+import javax.annotation.Nonnull;
+
+public final class SparseDoubleArray implements DoubleArray {
+    private static final long serialVersionUID = -2814248784231540118L;
+
+    @Nonnull
+    private int[] mKeys;
+    @Nonnull
+    private double[] mValues;
+    private int mSize;
+
+    public SparseDoubleArray() {
+        this(10);
+    }
+
+    public SparseDoubleArray(int initialCapacity) {
+        mKeys = new int[initialCapacity];
+        mValues = new double[initialCapacity];
+        mSize = 0;
+    }
+
+    private SparseDoubleArray(@Nonnull int[] mKeys, @Nonnull double[] mValues, int mSize) {
+        this.mKeys = mKeys;
+        this.mValues = mValues;
+        this.mSize = mSize;
+    }
+
+    @Nonnull
+    public SparseDoubleArray deepCopy() {
+        int[] newKeys = new int[mSize];
+        double[] newValues = new double[mSize];
+        System.arraycopy(mKeys, 0, newKeys, 0, mSize);
+        System.arraycopy(mValues, 0, newValues, 0, mSize);
+        return new SparseDoubleArray(newKeys, newValues, mSize);
+    }
+
+    @Override
+    public double get(int key) {
+        return get(key, 0);
+    }
+
+    @Override
+    public double get(int key, double valueIfKeyNotFound) {
+        int i = Arrays.binarySearch(mKeys, 0, mSize, key);
+        if (i < 0) {
+            return valueIfKeyNotFound;
+        } else {
+            return mValues[i];
+        }
+    }
+
+    public void delete(int key) {
+        int i = Arrays.binarySearch(mKeys, 0, mSize, key);
+        if (i >= 0) {
+            removeAt(i);
+        }
+    }
+
+    public void removeAt(int index) {
+        System.arraycopy(mKeys, index + 1, mKeys, index, mSize - (index + 1));
+        System.arraycopy(mValues, index + 1, mValues, index, mSize - (index + 1));
+        mSize--;
+    }
+
+    @Override
+    public void put(int key, double value) {
+        int i = Arrays.binarySearch(mKeys, 0, mSize, key);
+        if (i >= 0) {
+            mValues[i] = value;
+        } else {
+            i = ~i;
+            mKeys = ArrayUtils.insert(mKeys, mSize, i, key);
+            mValues = ArrayUtils.insert(mValues, mSize, i, value);
+            mSize++;
+        }
+    }
+
+    public void increment(int key, double value) {
+        int i = Arrays.binarySearch(mKeys, 0, mSize, key);
+        if (i >= 0) {
+            mValues[i] += value;
+        } else {
+            i = ~i;
+            mKeys = ArrayUtils.insert(mKeys, mSize, i, key);
+            mValues = ArrayUtils.insert(mValues, mSize, i, value);
+            mSize++;
+        }
+    }
+
+    @Override
+    public int size() {
+        return mSize;
+    }
+
+    @Override
+    public int keyAt(int index) {
+        return mKeys[index];
+    }
+
+    public double valueAt(int index) {
+        return mValues[index];
+    }
+
+    public void setValueAt(int index, double value) {
+        mValues[index] = value;
+    }
+
+    public int indexOfKey(int key) {
+        return Arrays.binarySearch(mKeys, 0, mSize, key);
+    }
+
+    public int indexOfValue(double value) {
+        for (int i = 0; i < mSize; i++) {
+            if (mValues[i] == value) {
+                return i;
+            }
+        }
+        return -1;
+    }
+
+    public void clear() {
+        clear(true);
+    }
+
+    public void clear(boolean zeroFill) {
+        mSize = 0;
+        if (zeroFill) {
+            Arrays.fill(mKeys, 0);
+            Arrays.fill(mValues, 0.d);
+        }
+    }
+
+    public void append(int key, double value) {
+        if (mSize != 0 && key <= mKeys[mSize - 1]) {
+            put(key, value);
+            return;
+        }
+        mKeys = ArrayUtils.append(mKeys, mSize, key);
+        mValues = ArrayUtils.append(mValues, mSize, value);
+        mSize++;
+    }
+
+    @Override
+    public double[] toArray() {
+        return toArray(true);
+    }
+
+    /**
+     * Returns a dense array indexed by key, where keys that are not present map to 0. A new array
+     * is always allocated, so {@code copy} has no effect. Negative keys are rejected through
+     * {@code Preconditions.checkArgument}.
+     */
+    @Override
+    public double[] toArray(boolean copy) {
+        if (mSize == 0) {
+            return new double[0];
+        }
+
+        // keys are kept sorted, so mKeys[0] is the smallest; validate it before sizing the array
+        Preconditions.checkArgument(mKeys[0] >= 0,
+            "Negative key is not allowed for toArray(): " + mKeys[0]);
+        int last = mKeys[mSize - 1];
+        final double[] array = new double[last + 1];
+        for (int i = 0; i < mSize; i++) {
+            array[mKeys[i]] = mValues[i];
+        }
+        return array;
+    }
+
+    @Override
+    public String toString() {
+        if (size() <= 0) {
+            return "{}";
+        }
+
+        StringBuilder buffer = new StringBuilder(mSize * 28);
+        buffer.append('{');
+        for (int i = 0; i < mSize; i++) {
+            if (i > 0) {
+                buffer.append(", ");
+            }
+            int key = keyAt(i);
+            buffer.append(key);
+            buffer.append('=');
+            double value = valueAt(i);
+            buffer.append(value);
+        }
+        buffer.append('}');
+        return buffer.toString();
+    }
+
+
+}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/054a697e/core/src/main/java/hivemall/utils/collections/SparseIntArray.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/utils/collections/SparseIntArray.java b/core/src/main/java/hivemall/utils/collections/SparseIntArray.java
index 2c60258..7a4ba69 100644
--- a/core/src/main/java/hivemall/utils/collections/SparseIntArray.java
+++ b/core/src/main/java/hivemall/utils/collections/SparseIntArray.java
@@ -19,9 +19,12 @@
 package hivemall.utils.collections;
 
 import hivemall.utils.lang.ArrayUtils;
+import hivemall.utils.lang.Preconditions;
 
 import java.util.Arrays;
 
+import javax.annotation.Nonnull;
+
 public final class SparseIntArray implements IntArray {
     private static final long serialVersionUID = -2814248784231540118L;
 
@@ -138,7 +141,15 @@ public final class SparseIntArray implements IntArray {
     }
 
     public void clear() {
+        clear(true);
+    }
+
+    public void clear(boolean zeroFill) {
         mSize = 0;
+        if (zeroFill) {
+            Arrays.fill(mKeys, 0);
+            Arrays.fill(mValues, 0);
+        }
     }
 
     public void append(int key, int value) {
@@ -151,6 +162,34 @@ public final class SparseIntArray implements IntArray {
         mSize++;
     }
 
+    @Override
+    @Nonnull
+    public int[] toArray() {
+        return toArray(true);
+    }
+
+    /**
+     * Returns a dense array indexed by key, where keys that are not present map to 0. A new array
+     * is always allocated, so {@code copy} has no effect. Negative keys are rejected through
+     * {@code Preconditions.checkArgument}.
+     */
+    @Override
+    public int[] toArray(boolean copy) {
+        if (mSize == 0) {
+            return new int[0];
+        }
+
+        // keys are kept sorted, so mKeys[0] is the smallest; validate it before sizing the array
+        Preconditions.checkArgument(mKeys[0] >= 0,
+            "Negative key is not allowed for toArray(): " + mKeys[0]);
+        int last = mKeys[mSize - 1];
+        final int[] array = new int[last + 1];
+        for (int i = 0; i < mSize; i++) {
+            array[mKeys[i]] = mValues[i];
+        }
+        return array;
+    }
+
     @Override
     public String toString() {
         if (size() <= 0) {
@@ -172,4 +211,6 @@ public final class SparseIntArray implements IntArray {
         buffer.append('}');
         return buffer.toString();
     }
+
+
 }
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/054a697e/core/src/main/java/hivemall/utils/lang/ArrayUtils.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/utils/lang/ArrayUtils.java b/core/src/main/java/hivemall/utils/lang/ArrayUtils.java index 521dcbd..24ed7fc 100644 --- a/core/src/main/java/hivemall/utils/lang/ArrayUtils.java +++ b/core/src/main/java/hivemall/utils/lang/ArrayUtils.java @@ -24,19 +24,20 @@ import java.util.List; import java.util.Random; import javax.annotation.Nonnull; +import javax.annotation.Nullable; public final class ArrayUtils { /** - * The index value when an element is not found in a list or array: <code>-1</code>. This value - * is returned by methods in this class and can also be used in comparisons with values returned - * by various method from {@link java.util.List}. + * The index value when an element is not found in a list or array: <code>-1</code>. This value + * is returned by methods in this class and can also be used in comparisons with values returned + * by various methods from {@link java.util.List}.
*/ public static final int INDEX_NOT_FOUND = -1; private ArrayUtils() {} - public static double[] set(double[] src, final int index, final double value) { + @Nonnull + public static double[] set(@Nonnull double[] src, final int index, final double value) { if (index >= src.length) { src = Arrays.copyOf(src, src.length * 2); } @@ -44,7 +45,8 @@ public final class ArrayUtils { return src; } - public static <T> T[] set(T[] src, final int index, final T value) { + @Nonnull + public static <T> T[] set(@Nonnull T[] src, final int index, final T value) { if (index >= src.length) { src = Arrays.copyOf(src, src.length * 2); } @@ -52,7 +54,8 @@ public final class ArrayUtils { return src; } - public static float[] toArray(final List<Float> lst) { + @Nonnull + public static float[] toArray(@Nonnull final List<Float> lst) { final int ndim = lst.size(); final float[] ary = new float[ndim]; int i = 0; @@ -62,7 +65,8 @@ public final class ArrayUtils { return ary; } - public static Integer[] toObject(final int[] array) { + @Nonnull + public static Integer[] toObject(@Nonnull final int[] array) { final Integer[] result = new Integer[array.length]; for (int i = 0; i < array.length; i++) { result[i] = array[i]; @@ -70,12 +74,14 @@ public final class ArrayUtils { return result; } - public static List<Integer> toList(final int[] array) { + @Nonnull + public static List<Integer> toList(@Nonnull final int[] array) { Integer[] v = toObject(array); return Arrays.asList(v); } - public static Long[] toObject(final long[] array) { + @Nonnull + public static Long[] toObject(@Nonnull final long[] array) { final Long[] result = new Long[array.length]; for (int i = 0; i < array.length; i++) { result[i] = array[i]; @@ -83,12 +89,14 @@ public final class ArrayUtils { return result; } - public static List<Long> toList(final long[] array) { + @Nonnull + public static List<Long> toList(@Nonnull final long[] array) { Long[] v = toObject(array); return Arrays.asList(v); } - public static Float[] 
toObject(final float[] array) { + @Nonnull + public static Float[] toObject(@Nonnull final float[] array) { final Float[] result = new Float[array.length]; for (int i = 0; i < array.length; i++) { result[i] = array[i]; @@ -96,12 +104,14 @@ public final class ArrayUtils { return result; } - public static List<Float> toList(final float[] array) { + @Nonnull + public static List<Float> toList(@Nonnull final float[] array) { Float[] v = toObject(array); return Arrays.asList(v); } - public static Double[] toObject(final double[] array) { + @Nonnull + public static Double[] toObject(@Nonnull final double[] array) { final Double[] result = new Double[array.length]; for (int i = 0; i < array.length; i++) { result[i] = array[i]; @@ -109,20 +119,21 @@ public final class ArrayUtils { return result; } - public static List<Double> toList(final double[] array) { + @Nonnull + public static List<Double> toList(@Nonnull final double[] array) { Double[] v = toObject(array); return Arrays.asList(v); } - public static <T> void shuffle(final T[] array) { + public static <T> void shuffle(@Nonnull final T[] array) { shuffle(array, array.length); } - public static <T> void shuffle(final T[] array, final Random rnd) { + public static <T> void shuffle(@Nonnull final T[] array, final Random rnd) { shuffle(array, array.length, rnd); } - public static <T> void shuffle(final T[] array, final int size) { + public static <T> void shuffle(@Nonnull final T[] array, final int size) { Random rnd = new Random(); shuffle(array, size, rnd); } @@ -159,7 +170,9 @@ public final class ArrayUtils { arr[j] = tmp; } - public static Object[] subarray(Object[] array, int startIndexInclusive, int endIndexExclusive) { + @Nullable + public static Object[] subarray(@Nullable final Object[] array, int startIndexInclusive, + int endIndexExclusive) { if (array == null) { return null; } @@ -179,13 +192,14 @@ public final class ArrayUtils { return subarray; } - public static void fill(final float[] a, final Random rand) 
{ + public static void fill(@Nonnull final float[] a, @Nonnull final Random rand) { for (int i = 0, len = a.length; i < len; i++) { a[i] = rand.nextFloat(); } } - public static int indexOf(final int[] array, final int valueToFind, int startIndex, int endIndex) { + /** + * Returns {@link #INDEX_NOT_FOUND} for a {@code null} array, hence the {@code @Nullable} + * parameter. + */ + public static int indexOf(@Nullable final int[] array, final int valueToFind, + final int startIndex, final int endIndex) { if (array == null) { return INDEX_NOT_FOUND; } @@ -201,20 +215,21 @@ public final class ArrayUtils { return INDEX_NOT_FOUND; } - public static byte[] copyOf(final byte[] original, final int newLength) { + @Nonnull + public static byte[] copyOf(@Nonnull final byte[] original, final int newLength) { final byte[] copy = new byte[newLength]; System.arraycopy(original, 0, copy, 0, Math.min(original.length, newLength)); return copy; } - public static int[] copyOf(final int[] src) { + @Nonnull + public static int[] copyOf(@Nonnull final int[] src) { int len = src.length; int[] dest = new int[len]; System.arraycopy(src, 0, dest, 0, len); return dest; } - public static void copy(final int[] src, final int[] dest) { + public static void copy(@Nonnull final int[] src, @Nonnull final int[] dest) { if (src.length != dest.length) { throw new IllegalArgumentException("src.legnth '" + src.length + "' != dest.length '" + dest.length + "'"); @@ -222,7 +237,8 @@ public final class ArrayUtils { System.arraycopy(src, 0, dest, 0, src.length); } - public static int[] append(int[] array, int currentSize, int element) { + @Nonnull + public static int[] append(@Nonnull int[] array, final int currentSize, final int element) { if (currentSize + 1 > array.length) { int[] newArray = new int[currentSize * 2]; System.arraycopy(array, 0, newArray, 0, currentSize); @@ -232,7 +248,21 @@ public final class ArrayUtils { return array; } - public static int[] insert(int[] array, int currentSize, int index, int element) { + /** + * Stores {@code element} at position {@code currentSize}, growing the array first if it is + * full. + * + * @return {@code array} itself, or a larger copy holding the first {@code currentSize} elements + * followed by {@code element} + */ + @Nonnull + public static double[] append(@Nonnull double[] array, final int currentSize, + final double element) 
{ + if (currentSize + 1 > array.length) { + // grow to at least currentSize + 1: plain doubling would leave an empty array empty + double[] newArray = new double[Math.max(currentSize * 2, currentSize + 1)]; + System.arraycopy(array, 0, newArray, 0, currentSize); + array = newArray; + } + array[currentSize] = element; + return array; + } + + @Nonnull + public static int[] insert(@Nonnull final int[] array, final int currentSize, final int index, + final int element) { if (currentSize + 1 <= array.length) { System.arraycopy(array, index, array, index + 1, currentSize - index); array[index] = element; @@ -245,6 +275,21 @@ public final class ArrayUtils { return newArray; } + /** + * Inserts {@code element} at {@code index}, shifting the elements in + * {@code [index, currentSize)} one slot to the right and growing the array first if it is full. + * + * @return {@code array} itself, or a larger copy if it had to grow + */ + @Nonnull + public static double[] insert(@Nonnull final double[] array, final int currentSize, + final int index, final double element) { + if (currentSize + 1 <= array.length) { + System.arraycopy(array, index, array, index + 1, currentSize - index); + array[index] = element; + return array; + } + // grow to at least currentSize + 1: plain doubling would leave an empty array empty + double[] newArray = new double[Math.max(currentSize * 2, currentSize + 1)]; + System.arraycopy(array, 0, newArray, 0, index); + newArray[index] = element; + System.arraycopy(array, index, newArray, index + 1, array.length - index); + return newArray; + } + public static boolean equals(@Nonnull final float[] array, final float value) { for (int i = 0, size = array.length; i < size; i++) { if (array[i] != value) { http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/054a697e/core/src/test/java/hivemall/matrix/MatrixBuilderTest.java ---------------------------------------------------------------------- diff --git a/core/src/test/java/hivemall/matrix/MatrixBuilderTest.java b/core/src/test/java/hivemall/matrix/MatrixBuilderTest.java new file mode 100644 index 0000000..5545631 --- /dev/null +++ b/core/src/test/java/hivemall/matrix/MatrixBuilderTest.java @@ -0,0 +1,329 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership.
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package hivemall.matrix; + +import org.junit.Assert; +import org.junit.Test; + +public class MatrixBuilderTest { + + @Test + public void testReadOnlyCSRMatrix() { + Matrix matrix = csrMatrix(); + Assert.assertEquals(6, matrix.numRows()); + Assert.assertEquals(6, matrix.numColumns()); + Assert.assertEquals(4, matrix.numColumns(0)); + Assert.assertEquals(2, matrix.numColumns(1)); + Assert.assertEquals(4, matrix.numColumns(2)); + Assert.assertEquals(2, matrix.numColumns(3)); + Assert.assertEquals(1, matrix.numColumns(4)); + Assert.assertEquals(1, matrix.numColumns(5)); + + Assert.assertEquals(11d, matrix.get(0, 0), 0.d); + Assert.assertEquals(12d, matrix.get(0, 1), 0.d); + Assert.assertEquals(13d, matrix.get(0, 2), 0.d); + Assert.assertEquals(14d, matrix.get(0, 3), 0.d); + Assert.assertEquals(22d, matrix.get(1, 1), 0.d); + Assert.assertEquals(23d, matrix.get(1, 2), 0.d); + Assert.assertEquals(33d, matrix.get(2, 2), 0.d); + Assert.assertEquals(34d, matrix.get(2, 3), 0.d); + Assert.assertEquals(35d, matrix.get(2, 4), 0.d); + Assert.assertEquals(36d, matrix.get(2, 5), 0.d); + Assert.assertEquals(44d, matrix.get(3, 3), 0.d); + Assert.assertEquals(45d, matrix.get(3, 4), 0.d); + Assert.assertEquals(56d, matrix.get(4, 5), 0.d); + Assert.assertEquals(66d, matrix.get(5, 5), 0.d); + + Assert.assertEquals(0.d, matrix.get(5, 4), 0.d); + Assert.assertEquals(-1.d, matrix.get(5, 4, -1.d), 0.d); + 
+ matrix.setDefaultValue(Double.NaN); + Assert.assertEquals(Double.NaN, matrix.get(5, 4), 0.d); + } + + @Test + public void testReadOnlyCSRMatrixFromLibSVM() { + Matrix matrix = csrMatrixFromLibSVM(); + Assert.assertEquals(6, matrix.numRows()); + Assert.assertEquals(6, matrix.numColumns()); + Assert.assertEquals(4, matrix.numColumns(0)); + Assert.assertEquals(2, matrix.numColumns(1)); + Assert.assertEquals(4, matrix.numColumns(2)); + Assert.assertEquals(2, matrix.numColumns(3)); + Assert.assertEquals(1, matrix.numColumns(4)); + Assert.assertEquals(1, matrix.numColumns(5)); + + Assert.assertEquals(11d, matrix.get(0, 0), 0.d); + Assert.assertEquals(12d, matrix.get(0, 1), 0.d); + Assert.assertEquals(13d, matrix.get(0, 2), 0.d); + Assert.assertEquals(14d, matrix.get(0, 3), 0.d); + Assert.assertEquals(22d, matrix.get(1, 1), 0.d); + Assert.assertEquals(23d, matrix.get(1, 2), 0.d); + Assert.assertEquals(33d, matrix.get(2, 2), 0.d); + Assert.assertEquals(34d, matrix.get(2, 3), 0.d); + Assert.assertEquals(35d, matrix.get(2, 4), 0.d); + Assert.assertEquals(36d, matrix.get(2, 5), 0.d); + Assert.assertEquals(44d, matrix.get(3, 3), 0.d); + Assert.assertEquals(45d, matrix.get(3, 4), 0.d); + Assert.assertEquals(56d, matrix.get(4, 5), 0.d); + Assert.assertEquals(66d, matrix.get(5, 5), 0.d); + + Assert.assertEquals(0.d, matrix.get(5, 4), 0.d); + Assert.assertEquals(-1.d, matrix.get(5, 4, -1.d), 0.d); + + matrix.setDefaultValue(Double.NaN); + Assert.assertEquals(Double.NaN, matrix.get(5, 4), 0.d); + } + + + @Test + public void testReadOnlyCSRMatrixNoRow() { + CSRMatrixBuilder builder = new CSRMatrixBuilder(1024); + Matrix matrix = builder.buildMatrix(true); + Assert.assertEquals(0, matrix.numRows()); + Assert.assertEquals(0, matrix.numColumns()); + } + + @Test(expected = IndexOutOfBoundsException.class) + public void testReadOnlyCSRMatrixGetFail1() { + Matrix matrix = csrMatrix(); + matrix.get(7, 5); + } + + @Test(expected = IndexOutOfBoundsException.class) + public void 
testReadOnlyCSRMatrixGetFail2() { + Matrix matrix = csrMatrix(); + matrix.get(6, 7); + } + + @Test + public void testReadOnlyDenseMatrix2d() { + Matrix matrix = denseMatrix(); + Assert.assertEquals(6, matrix.numRows()); + Assert.assertEquals(6, matrix.numColumns()); + Assert.assertEquals(4, matrix.numColumns(0)); + Assert.assertEquals(3, matrix.numColumns(1)); + Assert.assertEquals(6, matrix.numColumns(2)); + Assert.assertEquals(5, matrix.numColumns(3)); + Assert.assertEquals(6, matrix.numColumns(4)); + Assert.assertEquals(6, matrix.numColumns(5)); + + Assert.assertEquals(11d, matrix.get(0, 0), 0.d); + Assert.assertEquals(12d, matrix.get(0, 1), 0.d); + Assert.assertEquals(13d, matrix.get(0, 2), 0.d); + Assert.assertEquals(14d, matrix.get(0, 3), 0.d); + Assert.assertEquals(22d, matrix.get(1, 1), 0.d); + Assert.assertEquals(23d, matrix.get(1, 2), 0.d); + Assert.assertEquals(33d, matrix.get(2, 2), 0.d); + Assert.assertEquals(34d, matrix.get(2, 3), 0.d); + Assert.assertEquals(35d, matrix.get(2, 4), 0.d); + Assert.assertEquals(36d, matrix.get(2, 5), 0.d); + Assert.assertEquals(44d, matrix.get(3, 3), 0.d); + Assert.assertEquals(45d, matrix.get(3, 4), 0.d); + Assert.assertEquals(56d, matrix.get(4, 5), 0.d); + Assert.assertEquals(66d, matrix.get(5, 5), 0.d); + + Assert.assertEquals(0.d, matrix.get(5, 4), 0.d); + + Assert.assertEquals(0.d, matrix.get(1, 0), 0.d); + Assert.assertEquals(0.d, matrix.get(1, 3), 0.d); + Assert.assertEquals(0.d, matrix.get(1, 0), 0.d); + } + + @Test + public void testReadOnlyDenseMatrix2dSparseInput() { + Matrix matrix = denseMatrixSparseInput(); + Assert.assertEquals(6, matrix.numRows()); + Assert.assertEquals(6, matrix.numColumns()); + Assert.assertEquals(4, matrix.numColumns(0)); + Assert.assertEquals(3, matrix.numColumns(1)); + Assert.assertEquals(6, matrix.numColumns(2)); + Assert.assertEquals(5, matrix.numColumns(3)); + Assert.assertEquals(6, matrix.numColumns(4)); + Assert.assertEquals(6, matrix.numColumns(5)); + + Assert.assertEquals(11d, 
matrix.get(0, 0), 0.d); + Assert.assertEquals(12d, matrix.get(0, 1), 0.d); + Assert.assertEquals(13d, matrix.get(0, 2), 0.d); + Assert.assertEquals(14d, matrix.get(0, 3), 0.d); + Assert.assertEquals(22d, matrix.get(1, 1), 0.d); + Assert.assertEquals(23d, matrix.get(1, 2), 0.d); + Assert.assertEquals(33d, matrix.get(2, 2), 0.d); + Assert.assertEquals(34d, matrix.get(2, 3), 0.d); + Assert.assertEquals(35d, matrix.get(2, 4), 0.d); + Assert.assertEquals(36d, matrix.get(2, 5), 0.d); + Assert.assertEquals(44d, matrix.get(3, 3), 0.d); + Assert.assertEquals(45d, matrix.get(3, 4), 0.d); + Assert.assertEquals(56d, matrix.get(4, 5), 0.d); + Assert.assertEquals(66d, matrix.get(5, 5), 0.d); + + Assert.assertEquals(0.d, matrix.get(5, 4), 0.d); + + Assert.assertEquals(0.d, matrix.get(1, 0), 0.d); + Assert.assertEquals(0.d, matrix.get(1, 3), 0.d); + Assert.assertEquals(0.d, matrix.get(1, 0), 0.d); + } + + @Test + public void testReadOnlyDenseMatrix2dFromLibSVM() { + Matrix matrix = denseMatrixFromLibSVM(); + Assert.assertEquals(6, matrix.numRows()); + Assert.assertEquals(6, matrix.numColumns()); + Assert.assertEquals(4, matrix.numColumns(0)); + Assert.assertEquals(3, matrix.numColumns(1)); + Assert.assertEquals(6, matrix.numColumns(2)); + Assert.assertEquals(5, matrix.numColumns(3)); + Assert.assertEquals(6, matrix.numColumns(4)); + Assert.assertEquals(6, matrix.numColumns(5)); + + Assert.assertEquals(11d, matrix.get(0, 0), 0.d); + Assert.assertEquals(12d, matrix.get(0, 1), 0.d); + Assert.assertEquals(13d, matrix.get(0, 2), 0.d); + Assert.assertEquals(14d, matrix.get(0, 3), 0.d); + Assert.assertEquals(22d, matrix.get(1, 1), 0.d); + Assert.assertEquals(23d, matrix.get(1, 2), 0.d); + Assert.assertEquals(33d, matrix.get(2, 2), 0.d); + Assert.assertEquals(34d, matrix.get(2, 3), 0.d); + Assert.assertEquals(35d, matrix.get(2, 4), 0.d); + Assert.assertEquals(36d, matrix.get(2, 5), 0.d); + Assert.assertEquals(44d, matrix.get(3, 3), 0.d); + Assert.assertEquals(45d, matrix.get(3, 4), 0.d); 
+ Assert.assertEquals(56d, matrix.get(4, 5), 0.d); + Assert.assertEquals(66d, matrix.get(5, 5), 0.d); + + Assert.assertEquals(0.d, matrix.get(5, 4), 0.d); + + Assert.assertEquals(0.d, matrix.get(1, 0), 0.d); + Assert.assertEquals(0.d, matrix.get(1, 3), 0.d); + Assert.assertEquals(0.d, matrix.get(1, 0), 0.d); + } + + @Test + public void testReadOnlyDenseMatrix2dNoRow() { + Matrix matrix = new DenseMatrixBuilder(1024).buildMatrix(true); + Assert.assertEquals(0, matrix.numRows()); + Assert.assertEquals(0, matrix.numColumns()); + } + + @Test(expected = UnsupportedOperationException.class) + public void testReadOnlyDenseMatrix2dFailToChangeDefaultValue() { + Matrix matrix = denseMatrix(); + matrix.setDefaultValue(Double.NaN); + } + + @Test(expected = IndexOutOfBoundsException.class) + public void testReadOnlyDenseMatrix2dFailOutOfBound1() { + Matrix matrix = denseMatrix(); + matrix.get(7, 5); + } + + @Test(expected = IndexOutOfBoundsException.class) + public void testReadOnlyDenseMatrix2dFailOutOfBound2() { + Matrix matrix = denseMatrix(); + matrix.get(6, 7); + } + + private static Matrix csrMatrix() { + /* + 11 12 13 14 0 0 + 0 22 23 0 0 0 + 0 0 33 34 35 36 + 0 0 0 44 45 0 + 0 0 0 0 0 56 + 0 0 0 0 0 66 + */ + CSRMatrixBuilder builder = new CSRMatrixBuilder(1024); + builder.nextColumn(0, 11).nextColumn(1, 12).nextColumn(2, 13).nextColumn(3, 14).nextRow(); + builder.nextColumn(1, 22).nextColumn(2, 23).nextRow(); + builder.nextColumn(2, 33).nextColumn(3, 34).nextColumn(4, 35).nextColumn(5, 36).nextRow(); + builder.nextColumn(3, 44).nextColumn(4, 45).nextRow(); + builder.nextColumn(5, 56).nextRow(); + builder.nextColumn(5, 66).nextRow(); + return builder.buildMatrix(true); + } + + private static Matrix csrMatrixFromLibSVM() { + /* + 11 12 13 14 0 0 + 0 22 23 0 0 0 + 0 0 33 34 35 36 + 0 0 0 44 45 0 + 0 0 0 0 0 56 + 0 0 0 0 0 66 + */ + CSRMatrixBuilder builder = new CSRMatrixBuilder(1024); + builder.nextRow(new String[] {"0:11", "1:12", "2:13", "3:14"}); + 
builder.nextRow(new String[] {"1:22", "2:23"}); + builder.nextRow(new String[] {"2:33", "3:34", "4:35", "5:36"}); + builder.nextRow(new String[] {"3:44", "4:45"}); + builder.nextRow(new String[] {"5:56"}); + builder.nextRow(new String[] {"5:66"}); + return builder.buildMatrix(true); + } + + private static Matrix denseMatrix() { + /* + 11 12 13 14 0 0 + 0 22 23 0 0 0 + 0 0 33 34 35 36 + 0 0 0 44 45 0 + 0 0 0 0 0 56 + 0 0 0 0 0 66 + */ + DenseMatrixBuilder builder = new DenseMatrixBuilder(1024); + builder.nextRow(new double[] {11, 12, 13, 14}); + builder.nextRow(new double[] {0, 22, 23}); + builder.nextRow(new double[] {0, 0, 33, 34, 35, 36}); + builder.nextRow(new double[] {0, 0, 0, 44, 45}); + builder.nextRow(new double[] {0, 0, 0, 0, 0, 56}); + builder.nextRow(new double[] {0, 0, 0, 0, 0, 66}); + return builder.buildMatrix(true); + } + + private static Matrix denseMatrixSparseInput() { + /* + 11 12 13 14 0 0 + 0 22 23 0 0 0 + 0 0 33 34 35 36 + 0 0 0 44 45 0 + 0 0 0 0 0 56 + 0 0 0 0 0 66 + */ + DenseMatrixBuilder builder = new DenseMatrixBuilder(1024); + builder.nextColumn(0, 11).nextColumn(1, 12).nextColumn(2, 13).nextColumn(3, 14).nextRow(); + builder.nextColumn(1, 22).nextColumn(2, 23).nextRow(); + builder.nextColumn(2, 33).nextColumn(3, 34).nextColumn(4, 35).nextColumn(5, 36).nextRow(); + builder.nextColumn(3, 44).nextColumn(4, 45).nextRow(); + builder.nextColumn(5, 56).nextRow(); + builder.nextColumn(5, 66).nextRow(); + return builder.buildMatrix(true); + } + + private static Matrix denseMatrixFromLibSVM() { + DenseMatrixBuilder builder = new DenseMatrixBuilder(1024); + builder.nextRow(new String[] {"0:11", "1:12", "2:13", "3:14"}); + builder.nextRow(new String[] {"1:22", "2:23"}); + builder.nextRow(new String[] {"2:33", "3:34", "4:35", "5:36"}); + builder.nextRow(new String[] {"3:44", "4:45"}); + builder.nextRow(new String[] {"5:56"}); + builder.nextRow(new String[] {"5:66"}); + return builder.buildMatrix(true); + } + +} 
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/054a697e/core/src/test/java/hivemall/utils/collections/DoubleArrayTest.java ---------------------------------------------------------------------- diff --git a/core/src/test/java/hivemall/utils/collections/DoubleArrayTest.java b/core/src/test/java/hivemall/utils/collections/DoubleArrayTest.java new file mode 100644 index 0000000..72e76e8 --- /dev/null +++ b/core/src/test/java/hivemall/utils/collections/DoubleArrayTest.java @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package hivemall.utils.collections; + +import org.junit.Assert; +import org.junit.Test; + +/** + * Unit tests for {@link SparseDoubleArray}: conversion to a dense array and reuse after + * {@code clear()}. + */ +public class DoubleArrayTest { + + /** Builds a sparse array (capacity 3) holding {@code key -> offset + key} for keys {@code [0, count)}. */ + private static SparseDoubleArray sequential(final int count, final int offset) { + final SparseDoubleArray sparse = new SparseDoubleArray(3); + for (int key = 0; key < count; key++) { + sparse.put(key, offset + key); + } + return sparse; + } + + @Test + public void testSparseDoubleArrayToArray() { + final SparseDoubleArray sparse = sequential(10, 10); + Assert.assertEquals(10, sparse.size()); + Assert.assertEquals(10, sparse.toArray(false).length); + + final double[] dense = sparse.toArray(true); + Assert.assertEquals(10, dense.length); + for (int key = 0; key < 10; key++) { + Assert.assertEquals(10 + key, dense[key], 0.d); + } + } + + @Test + public void testSparseDoubleArrayClear() { + final SparseDoubleArray sparse = sequential(10, 10); + sparse.clear(); + Assert.assertEquals(0, sparse.size()); + Assert.assertEquals(0, sparse.get(0), 0.d); + + for (int key = 0; key < 5; key++) { + sparse.put(key, 100 + key); + } + Assert.assertEquals(5, sparse.size()); + for (int key = 0; key < 5; key++) { + Assert.assertEquals(100 + key, sparse.get(key), 0.d); + } + } + +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/054a697e/core/src/test/java/hivemall/utils/collections/IntArrayTest.java ---------------------------------------------------------------------- diff --git a/core/src/test/java/hivemall/utils/collections/IntArrayTest.java b/core/src/test/java/hivemall/utils/collections/IntArrayTest.java new file mode 100644 index 0000000..42852ea --- /dev/null +++ b/core/src/test/java/hivemall/utils/collections/IntArrayTest.java @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package hivemall.utils.collections; + +import org.junit.Assert; +import org.junit.Test; + +/** + * Unit tests for {@link FixedIntArray} and {@link SparseIntArray}: conversion to a dense array via + * {@code toArray} and reuse after {@code clear()}. + */ +public class IntArrayTest { + + /** Asserts {@code actual[k] == offset + k} for every {@code k} in {@code [0, count)}. */ + private static void assertSequence(final int[] actual, final int count, final int offset) { + for (int k = 0; k < count; k++) { + Assert.assertEquals(offset + k, actual[k]); + } + } + + @Test + public void testFixedIntArrayToArray() { + final FixedIntArray fixed = new FixedIntArray(11); + for (int k = 0; k < 10; k++) { + fixed.put(k, 10 + k); + } + // size() reflects the constructor capacity (11), not the number of puts (10) + Assert.assertEquals(11, fixed.size()); + Assert.assertEquals(11, fixed.toArray(false).length); + + final int[] dense = fixed.toArray(true); + Assert.assertEquals(11, dense.length); + assertSequence(dense, 10, 10); + } + + @Test + public void testSparseIntArrayToArray() { + final SparseIntArray sparse = new SparseIntArray(3); + for (int k = 0; k < 10; k++) { + sparse.put(k, 10 + k); + } + Assert.assertEquals(10, sparse.size()); + Assert.assertEquals(10, sparse.toArray(false).length); + + final int[] dense = sparse.toArray(true); + Assert.assertEquals(10, dense.length); + assertSequence(dense, 10, 10); + } + + @Test + public void testSparseIntArrayClear() { + final SparseIntArray sparse = new SparseIntArray(3); + for (int k = 0; k < 10; k++) { + sparse.put(k, 10 + k); + } + sparse.clear(); + Assert.assertEquals(0, sparse.size()); + Assert.assertEquals(0, sparse.get(0)); + + for (int k = 0; k < 5; k++) { + sparse.put(k, 100 + k); + } + Assert.assertEquals(5, sparse.size()); + for (int k = 0; k < 5; k++) { + Assert.assertEquals(100 + k, sparse.get(k)); + } + } + +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/054a697e/docs/gitbook/tips/rt_prediction.md
---------------------------------------------------------------------- diff --git a/docs/gitbook/tips/rt_prediction.md b/docs/gitbook/tips/rt_prediction.md index 25d9ff7..96641a3 100644 --- a/docs/gitbook/tips/rt_prediction.md +++ b/docs/gitbook/tips/rt_prediction.md @@ -199,14 +199,14 @@ Define sigmoid function used for a prediction of logistic regression as follows: ```sql DROP FUNCTION IF EXISTS sigmoid; -DELIMITER $$ +DELIMITER // CREATE FUNCTION sigmoid(x DOUBLE) RETURNS DOUBLE LANGUAGE SQL BEGIN RETURN 1.0 / (1.0 + EXP(-x)); END; -$$ +// DELIMITER ; ```
