http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/utils/sampling/ReservoirSampler.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/utils/sampling/ReservoirSampler.java b/core/src/main/java/hivemall/utils/sampling/ReservoirSampler.java new file mode 100644 index 0000000..1fb3a08 --- /dev/null +++ b/core/src/main/java/hivemall/utils/sampling/ReservoirSampler.java @@ -0,0 +1,100 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package hivemall.utils.sampling; + +import java.util.Arrays; +import java.util.List; +import java.util.Random; + +/** + * Vitter's reservoir sampling implementation that randomly chooses k items from a list containing n + * items. + * + * @link http://en.wikipedia.org/wiki/Reservoir_sampling + * @link http://portal.acm.org/citation.cfm?id=3165 + */ +public final class ReservoirSampler<T> { + + private final T[] samples; + private final int numSamples; + private int position; + + private final Random rand; + + @SuppressWarnings("unchecked") + public ReservoirSampler(int sampleSize) { + if (sampleSize <= 0) { + throw new IllegalArgumentException("sampleSize must be greater than 1: " + sampleSize); + } + this.samples = (T[]) new Object[sampleSize]; + this.numSamples = sampleSize; + this.position = 0; + this.rand = new Random(); + } + + @SuppressWarnings("unchecked") + public ReservoirSampler(int sampleSize, long seed) { + this.samples = (T[]) new Object[sampleSize]; + this.numSamples = sampleSize; + this.position = 0; + this.rand = new Random(seed); + } + + public ReservoirSampler(T[] samples) { + this.samples = samples; + this.numSamples = samples.length; + this.position = 0; + this.rand = new Random(); + } + + public ReservoirSampler(T[] samples, long seed) { + this.samples = samples; + this.numSamples = samples.length; + this.position = 0; + this.rand = new Random(seed); + } + + public T[] getSample() { + return samples; + } + + public List<T> getSamplesAsList() { + return Arrays.asList(samples); + } + + public void add(T item) { + if (item == null) { + return; + } + if (position < numSamples) {// reservoir not yet full, just append + samples[position] = item; + } else {// find a item to replace + int replaceIndex = rand.nextInt(position + 1); + if (replaceIndex < numSamples) { + samples[replaceIndex] = item; + } + } + position++; + } + + public void clear() { + Arrays.fill(samples, null); + this.position = 0; + } +}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/utils/stream/IntIterator.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/utils/stream/IntIterator.java b/core/src/main/java/hivemall/utils/stream/IntIterator.java new file mode 100644 index 0000000..794d81e --- /dev/null +++ b/core/src/main/java/hivemall/utils/stream/IntIterator.java @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package hivemall.utils.stream; + +public interface IntIterator { + + boolean hasNext(); + + int next(); + +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/utils/stream/IntStream.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/utils/stream/IntStream.java b/core/src/main/java/hivemall/utils/stream/IntStream.java new file mode 100644 index 0000000..4130177 --- /dev/null +++ b/core/src/main/java/hivemall/utils/stream/IntStream.java @@ -0,0 +1,28 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package hivemall.utils.stream; + +import javax.annotation.Nonnull; + +public interface IntStream { + + @Nonnull + IntIterator iterator(); + +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/utils/stream/StreamUtils.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/utils/stream/StreamUtils.java b/core/src/main/java/hivemall/utils/stream/StreamUtils.java new file mode 100644 index 0000000..7bd7b63 --- /dev/null +++ b/core/src/main/java/hivemall/utils/stream/StreamUtils.java @@ -0,0 +1,180 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package hivemall.utils.stream; + +import hivemall.utils.io.DeflaterOutputStream; +import hivemall.utils.io.FastByteArrayInputStream; +import hivemall.utils.io.FastMultiByteArrayOutputStream; +import hivemall.utils.io.IOUtils; + +import java.io.DataInputStream; +import java.io.DataOutputStream; +import java.io.IOException; +import java.util.NoSuchElementException; +import java.util.zip.Deflater; +import java.util.zip.Inflater; +import java.util.zip.InflaterInputStream; + +import javax.annotation.Nonnegative; +import javax.annotation.Nonnull; + +public final class StreamUtils { + + private StreamUtils() {} + + @Nonnull + public static IntStream toCompressedIntStream(@Nonnull final int[] src) { + return toCompressedIntStream(src, Deflater.DEFAULT_COMPRESSION); + } + + @Nonnull + public static IntStream toCompressedIntStream(@Nonnull final int[] src, final int level) { + FastMultiByteArrayOutputStream bos = new FastMultiByteArrayOutputStream(16384); + Deflater deflater = new Deflater(level, true); + DeflaterOutputStream defos = new DeflaterOutputStream(bos, deflater, 8192); + DataOutputStream dos = new DataOutputStream(defos); + + final int count = src.length; + final byte[] compressed; + try { + for (int i = 0; i < count; i++) { + dos.writeInt(src[i]); + } + defos.finish(); + compressed = bos.toByteArray_clear(); + } catch (IOException e) { + throw new IllegalStateException("Failed to compress int[]", e); + } finally { + IOUtils.closeQuietly(dos); + } + + return new InflateIntStream(compressed, count); + } + + @Nonnull + public static IntStream toArrayIntStream(@Nonnull int[] array) { + return new ArrayIntStream(array); + } + + static final class ArrayIntStream implements IntStream { + + @Nonnull + private final int[] array; + + ArrayIntStream(@Nonnull int[] array) { + this.array = array; + } + + @Override + public ArrayIntIterator iterator() { + return new ArrayIntIterator(array); + } + + } + + static final class ArrayIntIterator implements IntIterator { + + @Nonnull + private final int[] array; + @Nonnegative + private final int count; + @Nonnegative + private int index; + + ArrayIntIterator(@Nonnull int[] array) { + this.array = array; + this.count = array.length; + this.index = 0; + } + + @Override + public boolean hasNext() { + return index < count; + } + + @Override + public int next() { + if (index < count) {// hasNext() + return array[index++]; + } + throw new NoSuchElementException(); + } + + } + + static final class InflateIntStream implements IntStream { + + @Nonnull + private final byte[] compressed; + @Nonnegative + private final int count; + + InflateIntStream(@Nonnull byte[] compressed, @Nonnegative int count) { + this.compressed = compressed; + this.count = count; + } + + @Override + public InflatedIntIterator iterator() { + FastByteArrayInputStream bis = new FastByteArrayInputStream(compressed); + InflaterInputStream infis = new InflaterInputStream(bis, new Inflater(true), 512); + DataInputStream in = new DataInputStream(infis); + return new InflatedIntIterator(in, count); + } + + } + + static final class InflatedIntIterator implements IntIterator { + + @Nonnull + private final DataInputStream in; + @Nonnegative + private final int count; + @Nonnegative + private int index; + + InflatedIntIterator(@Nonnull DataInputStream in, @Nonnegative int count) { + this.in = in; + this.count = count; + this.index = 0; + } + + @Override + public boolean hasNext() { + return index < count; + } + + @Override + public int next() { + if (index < count) {// hasNext() + final int v; + try { + v = in.readInt(); + } catch (IOException e) { + throw new IllegalStateException("Invalid input at " + index, e); + } + index++; + return v; + } + throw new NoSuchElementException(); + } + + } + +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/test/java/hivemall/fm/FFMPredictionModelTest.java ---------------------------------------------------------------------- diff --git a/core/src/test/java/hivemall/fm/FFMPredictionModelTest.java b/core/src/test/java/hivemall/fm/FFMPredictionModelTest.java index a65a69a..076387f 100644 --- a/core/src/test/java/hivemall/fm/FFMPredictionModelTest.java +++ b/core/src/test/java/hivemall/fm/FFMPredictionModelTest.java @@ -19,7 +19,7 @@ package hivemall.fm; import hivemall.utils.buffer.HeapBuffer; -import hivemall.utils.collections.Int2LongOpenHashTable; +import hivemall.utils.collections.maps.Int2LongOpenHashTable; import java.io.IOException; http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/test/java/hivemall/math/matrix/MatrixBuilderTest.java ---------------------------------------------------------------------- diff --git a/core/src/test/java/hivemall/math/matrix/MatrixBuilderTest.java b/core/src/test/java/hivemall/math/matrix/MatrixBuilderTest.java new file mode 100644 index 0000000..decd7df --- /dev/null +++ b/core/src/test/java/hivemall/math/matrix/MatrixBuilderTest.java @@ -0,0 +1,644 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package hivemall.math.matrix; + +import hivemall.math.matrix.Matrix; +import hivemall.math.matrix.RowMajorMatrix; +import hivemall.math.matrix.builders.CSCMatrixBuilder; +import hivemall.math.matrix.builders.CSRMatrixBuilder; +import hivemall.math.matrix.builders.ColumnMajorDenseMatrixBuilder; +import hivemall.math.matrix.builders.DoKMatrixBuilder; +import hivemall.math.matrix.builders.RowMajorDenseMatrixBuilder; +import hivemall.math.matrix.dense.ColumnMajorDenseMatrix2d; +import hivemall.math.matrix.dense.RowMajorDenseMatrix2d; +import hivemall.math.matrix.sparse.CSCMatrix; +import hivemall.math.matrix.sparse.CSRMatrix; +import hivemall.math.matrix.sparse.DoKMatrix; + +import org.junit.Assert; +import org.junit.Test; + +public class MatrixBuilderTest { + + @Test + public void testReadOnlyCSRMatrix() { + Matrix matrix = csrMatrix(); + Assert.assertEquals(6, matrix.numRows()); + Assert.assertEquals(6, matrix.numColumns()); + Assert.assertEquals(4, matrix.numColumns(0)); + Assert.assertEquals(2, matrix.numColumns(1)); + Assert.assertEquals(4, matrix.numColumns(2)); + Assert.assertEquals(2, matrix.numColumns(3)); + Assert.assertEquals(1, matrix.numColumns(4)); + Assert.assertEquals(1, matrix.numColumns(5)); + + Assert.assertEquals(11d, matrix.get(0, 0), 0.d); + Assert.assertEquals(12d, matrix.get(0, 1), 0.d); + Assert.assertEquals(13d, matrix.get(0, 2), 0.d); + Assert.assertEquals(14d, matrix.get(0, 3), 0.d); + Assert.assertEquals(22d, matrix.get(1, 1), 0.d); + Assert.assertEquals(23d, matrix.get(1, 2), 0.d); + Assert.assertEquals(33d, matrix.get(2, 2), 0.d); + Assert.assertEquals(34d, matrix.get(2, 3), 0.d); + Assert.assertEquals(35d, matrix.get(2, 4), 0.d); + Assert.assertEquals(36d, matrix.get(2, 5), 0.d); + Assert.assertEquals(44d, matrix.get(3, 3), 0.d); + Assert.assertEquals(45d, matrix.get(3, 4), 0.d); + Assert.assertEquals(56d, matrix.get(4, 5), 0.d); + Assert.assertEquals(66d, matrix.get(5, 5), 0.d); + + Assert.assertEquals(0.d, matrix.get(5, 4), 0.d); + Assert.assertEquals(-1.d, matrix.get(5, 4, -1.d), 0.d); + + Assert.assertEquals(Double.NaN, matrix.get(5, 4, Double.NaN), 0.d); + } + + @Test + public void testReadOnlyCSRMatrixFromLibSVM() { + Matrix matrix = csrMatrixFromLibSVM(); + Assert.assertEquals(6, matrix.numRows()); + Assert.assertEquals(6, matrix.numColumns()); + Assert.assertEquals(4, matrix.numColumns(0)); + Assert.assertEquals(2, matrix.numColumns(1)); + Assert.assertEquals(4, matrix.numColumns(2)); + Assert.assertEquals(2, matrix.numColumns(3)); + Assert.assertEquals(1, matrix.numColumns(4)); + Assert.assertEquals(1, matrix.numColumns(5)); + + Assert.assertEquals(11d, matrix.get(0, 0), 0.d); + Assert.assertEquals(12d, matrix.get(0, 1), 0.d); + Assert.assertEquals(13d, matrix.get(0, 2), 0.d); + Assert.assertEquals(14d, matrix.get(0, 3), 0.d); + Assert.assertEquals(22d, matrix.get(1, 1), 0.d); + Assert.assertEquals(23d, matrix.get(1, 2), 0.d); + Assert.assertEquals(33d, matrix.get(2, 2), 0.d); + Assert.assertEquals(34d, matrix.get(2, 3), 0.d); + Assert.assertEquals(35d, matrix.get(2, 4), 0.d); + Assert.assertEquals(36d, matrix.get(2, 5), 0.d); + Assert.assertEquals(44d, matrix.get(3, 3), 0.d); + Assert.assertEquals(45d, matrix.get(3, 4), 0.d); + Assert.assertEquals(56d, matrix.get(4, 5), 0.d); + Assert.assertEquals(66d, matrix.get(5, 5), 0.d); + + Assert.assertEquals(0.d, matrix.get(5, 4), 0.d); + Assert.assertEquals(-1.d, matrix.get(5, 4, -1.d), 0.d); + + Assert.assertEquals(Double.NaN, matrix.get(5, 4, Double.NaN), 0.d); + } + + @Test + public void testReadOnlyCSRMatrixNoRow() { + CSRMatrixBuilder builder = new CSRMatrixBuilder(1024); + Matrix matrix = builder.buildMatrix(); + Assert.assertEquals(0, matrix.numRows()); + Assert.assertEquals(0, matrix.numColumns()); + } + + @Test(expected = IndexOutOfBoundsException.class) + public void testReadOnlyCSRMatrixGetFail1() { + Matrix matrix = csrMatrix(); + matrix.get(7, 5); + } + + @Test(expected = IndexOutOfBoundsException.class) + public void testReadOnlyCSRMatrixGetFail2() { + Matrix matrix = csrMatrix(); + matrix.get(6, 7); + } + + @Test + public void testCSCMatrixFromLibSVM() { + CSCMatrix matrix = cscMatrixFromLibSVM(); + Assert.assertEquals(6, matrix.numRows()); + Assert.assertEquals(6, matrix.numColumns()); + Assert.assertEquals(4, matrix.numColumns(0)); + Assert.assertEquals(2, matrix.numColumns(1)); + Assert.assertEquals(4, matrix.numColumns(2)); + Assert.assertEquals(2, matrix.numColumns(3)); + Assert.assertEquals(1, matrix.numColumns(4)); + Assert.assertEquals(1, matrix.numColumns(5)); + + Assert.assertEquals(11d, matrix.get(0, 0), 0.d); + Assert.assertEquals(12d, matrix.get(0, 1), 0.d); + Assert.assertEquals(13d, matrix.get(0, 2), 0.d); + Assert.assertEquals(14d, matrix.get(0, 3), 0.d); + Assert.assertEquals(22d, matrix.get(1, 1), 0.d); + Assert.assertEquals(23d, matrix.get(1, 2), 0.d); + Assert.assertEquals(33d, matrix.get(2, 2), 0.d); + Assert.assertEquals(34d, matrix.get(2, 3), 0.d); + Assert.assertEquals(35d, matrix.get(2, 4), 0.d); + Assert.assertEquals(36d, matrix.get(2, 5), 0.d); + Assert.assertEquals(44d, matrix.get(3, 3), 0.d); + Assert.assertEquals(45d, matrix.get(3, 4), 0.d); + Assert.assertEquals(56d, matrix.get(4, 5), 0.d); + Assert.assertEquals(66d, matrix.get(5, 5), 0.d); + + Assert.assertEquals(0.d, matrix.get(5, 4), 0.d); + Assert.assertEquals(-1.d, matrix.get(5, 4, -1.d), 0.d); + + Assert.assertEquals(Double.NaN, matrix.get(5, 4, Double.NaN), 0.d); + } + + @Test + public void testCSC2CSR() { + CSCMatrix csc = cscMatrixFromLibSVM(); + RowMajorMatrix csr = csc.toRowMajorMatrix(); + Assert.assertTrue(csr instanceof CSRMatrix); + Assert.assertEquals(6, csr.numRows()); + Assert.assertEquals(6, csr.numColumns()); + Assert.assertEquals(4, csr.numColumns(0)); + Assert.assertEquals(2, csr.numColumns(1)); + Assert.assertEquals(4, csr.numColumns(2)); + Assert.assertEquals(2, csr.numColumns(3)); + Assert.assertEquals(1, csr.numColumns(4)); + Assert.assertEquals(1, csr.numColumns(5)); + + Assert.assertEquals(11d, csr.get(0, 0), 0.d); + Assert.assertEquals(12d, csr.get(0, 1), 0.d); + Assert.assertEquals(13d, csr.get(0, 2), 0.d); + Assert.assertEquals(14d, csr.get(0, 3), 0.d); + Assert.assertEquals(22d, csr.get(1, 1), 0.d); + Assert.assertEquals(23d, csr.get(1, 2), 0.d); + Assert.assertEquals(33d, csr.get(2, 2), 0.d); + Assert.assertEquals(34d, csr.get(2, 3), 0.d); + Assert.assertEquals(35d, csr.get(2, 4), 0.d); + Assert.assertEquals(36d, csr.get(2, 5), 0.d); + Assert.assertEquals(44d, csr.get(3, 3), 0.d); + Assert.assertEquals(45d, csr.get(3, 4), 0.d); + Assert.assertEquals(56d, csr.get(4, 5), 0.d); + Assert.assertEquals(66d, csr.get(5, 5), 0.d); + + Assert.assertEquals(0.d, csr.get(5, 4), 0.d); + Assert.assertEquals(-1.d, csr.get(5, 4, -1.d), 0.d); + + Assert.assertEquals(Double.NaN, csr.get(5, 4, Double.NaN), 0.d); + } + + @Test + public void testCSC2CSR2CSR() { + CSCMatrix csc = cscMatrixFromLibSVM(); + CSCMatrix csc2 = csc.toRowMajorMatrix().toColumnMajorMatrix(); + Assert.assertEquals(csc.nnz(), csc2.nnz()); + Assert.assertEquals(6, csc2.numRows()); + Assert.assertEquals(6, csc2.numColumns()); + Assert.assertEquals(4, csc2.numColumns(0)); + Assert.assertEquals(2, csc2.numColumns(1)); + Assert.assertEquals(4, csc2.numColumns(2)); + Assert.assertEquals(2, csc2.numColumns(3)); + Assert.assertEquals(1, csc2.numColumns(4)); + Assert.assertEquals(1, csc2.numColumns(5)); + + Assert.assertEquals(11d, csc2.get(0, 0), 0.d); + Assert.assertEquals(12d, csc2.get(0, 1), 0.d); + Assert.assertEquals(13d, csc2.get(0, 2), 0.d); + Assert.assertEquals(14d, csc2.get(0, 3), 0.d); + Assert.assertEquals(22d, csc2.get(1, 1), 0.d); + Assert.assertEquals(23d, csc2.get(1, 2), 0.d); + Assert.assertEquals(33d, csc2.get(2, 2), 0.d); + Assert.assertEquals(34d, csc2.get(2, 3), 0.d); + Assert.assertEquals(35d, csc2.get(2, 4), 0.d); + Assert.assertEquals(36d, csc2.get(2, 5), 0.d); + Assert.assertEquals(44d, csc2.get(3, 3), 0.d); + Assert.assertEquals(45d, csc2.get(3, 4), 0.d); + Assert.assertEquals(56d, csc2.get(4, 5), 0.d); + Assert.assertEquals(66d, csc2.get(5, 5), 0.d); + + Assert.assertEquals(0.d, csc2.get(5, 4), 0.d); + Assert.assertEquals(-1.d, csc2.get(5, 4, -1.d), 0.d); + + Assert.assertEquals(Double.NaN, csc2.get(5, 4, Double.NaN), 0.d); + } + + + @Test + public void testDoKMatrixFromLibSVM() { + Matrix matrix = dokMatrixFromLibSVM(); + Assert.assertEquals(6, matrix.numRows()); + Assert.assertEquals(6, matrix.numColumns()); + Assert.assertEquals(4, matrix.numColumns(0)); + Assert.assertEquals(2, matrix.numColumns(1)); + Assert.assertEquals(4, matrix.numColumns(2)); + Assert.assertEquals(2, matrix.numColumns(3)); + Assert.assertEquals(1, matrix.numColumns(4)); + Assert.assertEquals(1, matrix.numColumns(5)); + + Assert.assertEquals(11d, matrix.get(0, 0), 0.d); + Assert.assertEquals(12d, matrix.get(0, 1), 0.d); + Assert.assertEquals(13d, matrix.get(0, 2), 0.d); + Assert.assertEquals(14d, matrix.get(0, 3), 0.d); + Assert.assertEquals(22d, matrix.get(1, 1), 0.d); + Assert.assertEquals(23d, matrix.get(1, 2), 0.d); + Assert.assertEquals(33d, matrix.get(2, 2), 0.d); + Assert.assertEquals(34d, matrix.get(2, 3), 0.d); + Assert.assertEquals(35d, matrix.get(2, 4), 0.d); + Assert.assertEquals(36d, matrix.get(2, 5), 0.d); + Assert.assertEquals(44d, matrix.get(3, 3), 0.d); + Assert.assertEquals(45d, matrix.get(3, 4), 0.d); + Assert.assertEquals(56d, matrix.get(4, 5), 0.d); + Assert.assertEquals(66d, matrix.get(5, 5), 0.d); + + Assert.assertEquals(0.d, matrix.get(5, 4), 0.d); + Assert.assertEquals(-1.d, matrix.get(5, 4, -1.d), 0.d); + + Assert.assertEquals(Double.NaN, matrix.get(5, 4, Double.NaN), 0.d); + } + + @Test + public void testReadOnlyDenseMatrix2d() { + Matrix matrix = rowMajorDenseMatrix(); + Assert.assertEquals(6, matrix.numRows()); + Assert.assertEquals(6, matrix.numColumns()); + Assert.assertEquals(4, matrix.numColumns(0)); + Assert.assertEquals(3, matrix.numColumns(1)); + Assert.assertEquals(6, matrix.numColumns(2)); + Assert.assertEquals(5, matrix.numColumns(3)); + Assert.assertEquals(6, matrix.numColumns(4)); + Assert.assertEquals(6, matrix.numColumns(5)); + + Assert.assertEquals(11d, matrix.get(0, 0), 0.d); + Assert.assertEquals(12d, matrix.get(0, 1), 0.d); + Assert.assertEquals(13d, matrix.get(0, 2), 0.d); + Assert.assertEquals(14d, matrix.get(0, 3), 0.d); + Assert.assertEquals(22d, matrix.get(1, 1), 0.d); + Assert.assertEquals(23d, matrix.get(1, 2), 0.d); + Assert.assertEquals(33d, matrix.get(2, 2), 0.d); + Assert.assertEquals(34d, matrix.get(2, 3), 0.d); + Assert.assertEquals(35d, matrix.get(2, 4), 0.d); + Assert.assertEquals(36d, matrix.get(2, 5), 0.d); + Assert.assertEquals(44d, matrix.get(3, 3), 0.d); + Assert.assertEquals(45d, matrix.get(3, 4), 0.d); + Assert.assertEquals(56d, matrix.get(4, 5), 0.d); + Assert.assertEquals(66d, matrix.get(5, 5), 0.d); + + Assert.assertEquals(0.d, matrix.get(5, 4), 0.d); + + Assert.assertEquals(0.d, matrix.get(1, 0), 0.d); + Assert.assertEquals(0.d, matrix.get(1, 3), 0.d); + Assert.assertEquals(0.d, matrix.get(1, 0), 0.d); + } + + @Test + public void testReadOnlyDenseMatrix2dSparseInput() { + Matrix matrix = denseMatrixSparseInput(); + Assert.assertEquals(6, matrix.numRows()); + Assert.assertEquals(6, matrix.numColumns()); + Assert.assertEquals(4, matrix.numColumns(0)); + Assert.assertEquals(3, matrix.numColumns(1)); + Assert.assertEquals(6, matrix.numColumns(2)); + Assert.assertEquals(5, matrix.numColumns(3)); + Assert.assertEquals(6, matrix.numColumns(4)); + Assert.assertEquals(6, matrix.numColumns(5)); + + Assert.assertEquals(11d, matrix.get(0, 0), 0.d); + Assert.assertEquals(12d, matrix.get(0, 1), 0.d); + Assert.assertEquals(13d, matrix.get(0, 2), 0.d); + Assert.assertEquals(14d, matrix.get(0, 3), 0.d); + Assert.assertEquals(22d, matrix.get(1, 1), 0.d); + Assert.assertEquals(23d, matrix.get(1, 2), 0.d); + Assert.assertEquals(33d, matrix.get(2, 2), 0.d); + Assert.assertEquals(34d, matrix.get(2, 3), 0.d); + Assert.assertEquals(35d, matrix.get(2, 4), 0.d); + Assert.assertEquals(36d, matrix.get(2, 5), 0.d); + Assert.assertEquals(44d, matrix.get(3, 3), 0.d); + Assert.assertEquals(45d, matrix.get(3, 4), 0.d); + Assert.assertEquals(56d, matrix.get(4, 5), 0.d); + Assert.assertEquals(66d, matrix.get(5, 5), 0.d); + + Assert.assertEquals(0.d, matrix.get(5, 4), 0.d); + + Assert.assertEquals(0.d, matrix.get(1, 0), 0.d); + Assert.assertEquals(0.d, matrix.get(1, 3), 0.d); + Assert.assertEquals(0.d, matrix.get(1, 0), 0.d); + } + + @Test + public void testReadOnlyDenseMatrix2dFromLibSVM() { + Matrix matrix = denseMatrixFromLibSVM(); + Assert.assertEquals(6, matrix.numRows()); + Assert.assertEquals(6, matrix.numColumns()); + Assert.assertEquals(4, matrix.numColumns(0)); + Assert.assertEquals(3, matrix.numColumns(1)); + Assert.assertEquals(6, matrix.numColumns(2)); + Assert.assertEquals(5, matrix.numColumns(3)); + Assert.assertEquals(6, matrix.numColumns(4)); + Assert.assertEquals(6, matrix.numColumns(5)); + + Assert.assertEquals(11d, matrix.get(0, 0), 0.d); + Assert.assertEquals(12d, matrix.get(0, 1), 0.d); + Assert.assertEquals(13d, matrix.get(0, 2), 0.d); + Assert.assertEquals(14d, matrix.get(0, 3), 0.d); + Assert.assertEquals(22d, matrix.get(1, 1), 0.d); + Assert.assertEquals(23d, matrix.get(1, 2), 0.d); + Assert.assertEquals(33d, matrix.get(2, 2), 0.d); + Assert.assertEquals(34d, matrix.get(2, 3), 0.d); + Assert.assertEquals(35d, matrix.get(2, 4), 0.d); + Assert.assertEquals(36d, matrix.get(2, 5), 0.d); + Assert.assertEquals(44d, matrix.get(3, 3), 0.d); + Assert.assertEquals(45d, matrix.get(3, 4), 0.d); + Assert.assertEquals(56d, matrix.get(4, 5), 0.d); + Assert.assertEquals(66d, matrix.get(5, 5), 0.d); + + Assert.assertEquals(0.d, matrix.get(5, 4), 0.d); + + Assert.assertEquals(0.d, matrix.get(1, 0), 0.d); + Assert.assertEquals(0.d, matrix.get(1, 3), 0.d); + Assert.assertEquals(0.d, matrix.get(1, 0), 0.d); + } + + @Test + public void testReadOnlyDenseMatrix2dNoRow() { + Matrix matrix = new RowMajorDenseMatrixBuilder(1024).buildMatrix(); + Assert.assertEquals(0, matrix.numRows()); + Assert.assertEquals(0, matrix.numColumns()); + } + + @Test(expected = IndexOutOfBoundsException.class) + public void testReadOnlyDenseMatrix2dFailOutOfBound1() { + Matrix matrix = rowMajorDenseMatrix(); + matrix.get(7, 5); + } + + @Test(expected = IndexOutOfBoundsException.class) + public void testReadOnlyDenseMatrix2dFailOutOfBound2() { + Matrix matrix = rowMajorDenseMatrix(); + matrix.get(6, 7); + } + + @Test + public void testColumnMajorDenseMatrix2d() { + ColumnMajorDenseMatrix2d colMatrix = columnMajorDenseMatrix(); + + Assert.assertEquals(6, colMatrix.numRows()); + Assert.assertEquals(6, colMatrix.numColumns()); + Assert.assertEquals(4, colMatrix.numColumns(0)); + Assert.assertEquals(2, colMatrix.numColumns(1)); + Assert.assertEquals(4, colMatrix.numColumns(2)); + Assert.assertEquals(2, colMatrix.numColumns(3)); + Assert.assertEquals(1, colMatrix.numColumns(4)); + Assert.assertEquals(1, colMatrix.numColumns(5)); + + Assert.assertEquals(11d, colMatrix.get(0, 0), 0.d); + Assert.assertEquals(12d, colMatrix.get(0, 1), 0.d); + Assert.assertEquals(13d, colMatrix.get(0, 2), 0.d); + Assert.assertEquals(14d, colMatrix.get(0, 3), 0.d); + Assert.assertEquals(22d, colMatrix.get(1, 1), 0.d); + Assert.assertEquals(23d, colMatrix.get(1, 2), 0.d); + Assert.assertEquals(33d, colMatrix.get(2, 2), 0.d); + Assert.assertEquals(34d, colMatrix.get(2, 3), 0.d); + Assert.assertEquals(35d, colMatrix.get(2, 4), 0.d); + Assert.assertEquals(36d, colMatrix.get(2, 5), 0.d); + Assert.assertEquals(44d, colMatrix.get(3, 3), 0.d); + Assert.assertEquals(45d, colMatrix.get(3, 4), 0.d); + Assert.assertEquals(56d, colMatrix.get(4, 5), 0.d); + Assert.assertEquals(66d, colMatrix.get(5, 5), 0.d); + + Assert.assertEquals(0.d, colMatrix.get(5, 4), 0.d); + + Assert.assertEquals(0.d, colMatrix.get(1, 0), 0.d); + Assert.assertEquals(0.d, colMatrix.get(1, 3), 0.d); + Assert.assertEquals(0.d, colMatrix.get(1, 0), 0.d); + } + + @Test + public void testDenseMatrixColumnMajor2RowMajor() { + ColumnMajorDenseMatrix2d colMatrix = columnMajorDenseMatrix(); + RowMajorDenseMatrix2d rowMatrix = colMatrix.toRowMajorMatrix(); + + Assert.assertEquals(6, rowMatrix.numRows()); + Assert.assertEquals(6, rowMatrix.numColumns()); + Assert.assertEquals(4, rowMatrix.numColumns(0)); + Assert.assertEquals(3, rowMatrix.numColumns(1)); + Assert.assertEquals(6, rowMatrix.numColumns(2)); + Assert.assertEquals(5, rowMatrix.numColumns(3)); + Assert.assertEquals(6, rowMatrix.numColumns(4)); + Assert.assertEquals(6, rowMatrix.numColumns(5)); + + Assert.assertEquals(11d, rowMatrix.get(0, 0), 0.d); + Assert.assertEquals(12d, rowMatrix.get(0, 1), 0.d); + Assert.assertEquals(13d, rowMatrix.get(0, 2), 0.d); + Assert.assertEquals(14d, rowMatrix.get(0, 3), 0.d); + Assert.assertEquals(22d, rowMatrix.get(1, 1), 0.d); + Assert.assertEquals(23d, rowMatrix.get(1, 2), 0.d); + Assert.assertEquals(33d, rowMatrix.get(2, 2), 0.d); + Assert.assertEquals(34d, rowMatrix.get(2, 3), 0.d); + Assert.assertEquals(35d, rowMatrix.get(2, 4), 0.d); + Assert.assertEquals(36d, rowMatrix.get(2, 5), 0.d); + Assert.assertEquals(44d, rowMatrix.get(3, 3), 0.d); + Assert.assertEquals(45d, rowMatrix.get(3, 4), 0.d); + Assert.assertEquals(56d, rowMatrix.get(4, 5), 0.d); + Assert.assertEquals(66d, rowMatrix.get(5, 5), 0.d); + + Assert.assertEquals(0.d, rowMatrix.get(5, 4), 0.d); + + Assert.assertEquals(0.d, rowMatrix.get(1, 0), 0.d); + Assert.assertEquals(0.d, rowMatrix.get(1, 3), 0.d); + Assert.assertEquals(0.d, rowMatrix.get(1, 0), 0.d); + + // convert back to column major matrix + + colMatrix = rowMatrix.toColumnMajorMatrix(); + + Assert.assertEquals(6, colMatrix.numRows()); + Assert.assertEquals(6, colMatrix.numColumns()); + Assert.assertEquals(4, colMatrix.numColumns(0)); + Assert.assertEquals(2, colMatrix.numColumns(1)); + Assert.assertEquals(4, colMatrix.numColumns(2)); + Assert.assertEquals(2, colMatrix.numColumns(3)); + Assert.assertEquals(1, colMatrix.numColumns(4)); + Assert.assertEquals(1, colMatrix.numColumns(5)); + + Assert.assertEquals(11d, colMatrix.get(0, 0), 0.d); + Assert.assertEquals(12d, colMatrix.get(0, 1), 0.d); + Assert.assertEquals(13d, colMatrix.get(0, 2), 0.d); + Assert.assertEquals(14d, colMatrix.get(0, 3), 0.d); + Assert.assertEquals(22d, colMatrix.get(1, 1), 0.d); + Assert.assertEquals(23d, colMatrix.get(1, 2), 0.d); + Assert.assertEquals(33d, colMatrix.get(2, 2), 0.d); + Assert.assertEquals(34d, colMatrix.get(2, 3), 0.d); + Assert.assertEquals(35d, colMatrix.get(2, 4), 0.d); + Assert.assertEquals(36d, colMatrix.get(2, 5), 0.d); + Assert.assertEquals(44d, colMatrix.get(3, 3), 0.d); + Assert.assertEquals(45d, colMatrix.get(3, 4), 0.d); + Assert.assertEquals(56d, colMatrix.get(4, 5), 0.d); + Assert.assertEquals(66d, colMatrix.get(5, 5), 0.d); + + Assert.assertEquals(0.d, colMatrix.get(5, 4), 0.d); + + Assert.assertEquals(0.d, colMatrix.get(1, 0), 0.d); + Assert.assertEquals(0.d, colMatrix.get(1, 3), 0.d); + Assert.assertEquals(0.d, colMatrix.get(1, 0), 0.d); + } + + @Test + public void testCSRMatrixNullRow() { + CSRMatrixBuilder builder = new CSRMatrixBuilder(1024); + builder.nextColumn(0, 11).nextColumn(1, 12).nextColumn(2, 13).nextColumn(3, 14).nextRow(); + builder.nextColumn(1, 22).nextColumn(2, 23).nextRow(); + builder.nextRow(); + builder.nextColumn(3, 66).nextRow(); + Matrix matrix = builder.buildMatrix(); + Assert.assertEquals(4, matrix.numRows()); + } + + private static CSRMatrix csrMatrix() { + /* + 11 12 13 14 0 0 + 0 22 23 0 0 0 + 0 0 33 34 35 36 + 0 0 0 44 45 0 + 0 0 0 0 0 56 + 0 0 0 0 0 66 + */ + CSRMatrixBuilder builder = new CSRMatrixBuilder(1024); + builder.nextColumn(0, 11).nextColumn(1, 12).nextColumn(2, 13).nextColumn(3, 14).nextRow(); + builder.nextColumn(1, 22).nextColumn(2, 23).nextRow(); + builder.nextColumn(2, 33).nextColumn(3, 34).nextColumn(4, 35).nextColumn(5, 36).nextRow(); + builder.nextColumn(3, 44).nextColumn(4, 45).nextRow(); + builder.nextColumn(5, 56).nextRow(); + builder.nextColumn(5, 66).nextRow(); + return builder.buildMatrix(); + } + + private static CSRMatrix csrMatrixFromLibSVM() { + /* + 11 12 13 14 0 0 + 0 22 23 0 0 0 + 0 0 33 34 35 36 + 0 0 0 44 45 0 + 0 0 0 0 0 56 + 0 0 0 0 0 66 + */ + CSRMatrixBuilder builder = new CSRMatrixBuilder(1024); + builder.nextRow(new String[] {"0:11", "1:12", "2:13", "3:14"}); + builder.nextRow(new String[] {"1:22", "2:23"}); + builder.nextRow(new String[] {"2:33", "3:34", "4:35", "5:36"}); + builder.nextRow(new String[] {"3:44", "4:45"}); + builder.nextRow(new String[] {"5:56"}); + builder.nextRow(new String[] {"5:66"}); + return builder.buildMatrix(); + } + + private static CSCMatrix cscMatrixFromLibSVM() { + /* + 11 12 13 14 0 0 + 0 22 23 0 0 0 + 0 0 33 34 35 36 + 0 0 0 44 45 0 + 0 0 0 0 0 56 + 0 0 0 0 0 66 + */ + CSCMatrixBuilder builder = new CSCMatrixBuilder(1024); + builder.nextRow(new String[] {"0:11", "1:12", "2:13", "3:14"}); + builder.nextRow(new String[] {"1:22", "2:23"}); + builder.nextRow(new String[] {"2:33", "3:34", "4:35", "5:36"}); + builder.nextRow(new String[] {"3:44", "4:45"}); + builder.nextRow(new String[] {"5:56"}); + builder.nextRow(new String[] {"5:66"}); + return builder.buildMatrix(); + } + + + private static DoKMatrix dokMatrixFromLibSVM() { + /* + 11 12 13 14 0 0 + 0 22 23 0 0 0 + 0 0 33 34 35 36 + 0 0 0 44 45 0 + 0 0 0 0 0 56 + 0 0 0 0 0 66 + */ + DoKMatrixBuilder builder = new DoKMatrixBuilder(1024); + builder.nextRow(new String[] {"0:11", "1:12", "2:13", "3:14"}); + builder.nextRow(new String[] {"1:22", "2:23"}); + builder.nextRow(new String[] {"2:33", "3:34", "4:35", "5:36"}); + builder.nextRow(new String[] {"3:44", "4:45"}); + builder.nextRow(new String[] {"5:56"}); + builder.nextRow(new String[] {"5:66"}); + return builder.buildMatrix(); + } + + private static RowMajorDenseMatrix2d rowMajorDenseMatrix() { + /* + 11 12 13 14 0 0 + 0 22 23 0 0 0 + 0 0 33 34 35 36 + 0 0 0 44 45 0 + 0 0 0 0 0 56 + 0 0 0 0 0 66 + */ + RowMajorDenseMatrixBuilder builder = new RowMajorDenseMatrixBuilder(1024); + builder.nextRow(new double[] {11, 12, 13, 14}); + builder.nextRow(new double[] {0, 22, 23}); + builder.nextRow(new double[] {0, 0, 33, 34, 35, 36}); + builder.nextRow(new double[] {0, 0, 0, 44, 45}); + builder.nextRow(new double[] {0, 0, 0, 0, 0, 56}); + builder.nextRow(new double[] {0, 0, 0, 0, 0, 66}); + return builder.buildMatrix(); + } + + private static ColumnMajorDenseMatrix2d columnMajorDenseMatrix() { + /* + 11 12 13 14 0 0 + 0 22 23 0 0 0 + 0 0 33 34 35 36 + 0 0 0 44 45 0 + 0 0 0 0 0 56 + 0 0 0 0 0 66 + */ + ColumnMajorDenseMatrixBuilder builder = new ColumnMajorDenseMatrixBuilder(1024); + builder.nextRow(new double[] {11, 12, 13, 14}); + builder.nextRow(new double[] {0, 22, 23}); + builder.nextRow(new double[] {0, 0, 33, 34, 35, 36}); + builder.nextRow(new double[] {0, 0, 0, 44, 45}); + builder.nextRow(new double[] {0, 0, 0, 0, 0, 56}); + builder.nextRow(new double[] {0, 0, 0, 0, 0, 66}); + return builder.buildMatrix(); + } + + private static RowMajorDenseMatrix2d denseMatrixSparseInput() { + /* + 11 12 13 14 0 0 + 0 22 23 0 0 0 + 0 0 33 34 35 36 + 0 0 0 44 45 0 + 0 0 0 0 0 56 + 0 0 0 0 0 66 + */ + RowMajorDenseMatrixBuilder builder = new RowMajorDenseMatrixBuilder(1024); + builder.nextColumn(0, 11).nextColumn(1, 12).nextColumn(2, 13).nextColumn(3, 14).nextRow(); + builder.nextColumn(1, 22).nextColumn(2, 23).nextRow(); + builder.nextColumn(2, 33).nextColumn(3, 34).nextColumn(4, 35).nextColumn(5, 36).nextRow(); + builder.nextColumn(3, 44).nextColumn(4, 45).nextRow(); + builder.nextColumn(5, 56).nextRow(); + builder.nextColumn(5, 66).nextRow(); + return builder.buildMatrix(); + } + + private static RowMajorDenseMatrix2d denseMatrixFromLibSVM() { + RowMajorDenseMatrixBuilder builder = new RowMajorDenseMatrixBuilder(1024); + builder.nextRow(new String[] {"0:11", "1:12", "2:13", "3:14"}); + builder.nextRow(new String[] {"1:22", "2:23"}); + builder.nextRow(new String[] {"2:33", "3:34", "4:35", "5:36"}); + builder.nextRow(new String[] {"3:44", "4:45"}); + builder.nextRow(new String[] {"5:56"}); + builder.nextRow(new String[] {"5:66"}); + return builder.buildMatrix(); + } + +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/test/java/hivemall/math/matrix/ints/IntMatrixTest.java ---------------------------------------------------------------------- diff --git a/core/src/test/java/hivemall/math/matrix/ints/IntMatrixTest.java b/core/src/test/java/hivemall/math/matrix/ints/IntMatrixTest.java new file mode 100644 index 0000000..f6a52fe --- /dev/null +++ b/core/src/test/java/hivemall/math/matrix/ints/IntMatrixTest.java @@ -0,0 +1,361 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package hivemall.math.matrix.ints; + +import hivemall.math.matrix.ints.ColumnMajorDenseIntMatrix2d; +import hivemall.math.matrix.ints.DoKIntMatrix; +import hivemall.math.vector.VectorProcedure; +import hivemall.utils.lang.mutable.MutableInt; + +import org.junit.Assert; +import org.junit.Test; + +public class IntMatrixTest { + + @Test + public void testDoKMatrixRowMajor() { + DoKIntMatrix matrix = DoKIntMatrix.build(rowMajorData(), true, true); + + Assert.assertEquals(6, matrix.numRows()); + Assert.assertEquals(6, matrix.numColumns()); + + Assert.assertEquals(11, matrix.get(0, 0)); + Assert.assertEquals(12, matrix.get(0, 1)); + Assert.assertEquals(13, matrix.get(0, 2)); + Assert.assertEquals(14, matrix.get(0, 3)); + Assert.assertEquals(22, matrix.get(1, 1)); + Assert.assertEquals(23, matrix.get(1, 2)); + Assert.assertEquals(33, matrix.get(2, 2)); + Assert.assertEquals(34, matrix.get(2, 3)); + Assert.assertEquals(35, matrix.get(2, 4)); + Assert.assertEquals(36, matrix.get(2, 5)); + Assert.assertEquals(44, matrix.get(3, 3)); + Assert.assertEquals(45, matrix.get(3, 4)); + Assert.assertEquals(56, matrix.get(4, 5)); + Assert.assertEquals(66, matrix.get(5, 5)); + + Assert.assertEquals(0, matrix.get(5, 4)); + Assert.assertEquals(0, matrix.get(1, 0)); + Assert.assertEquals(0, matrix.get(1, 3)); + Assert.assertEquals(-1, matrix.get(1, 0, -1)); + } + + @Test + public void testDoKMatrixColumnMajor() { + DoKIntMatrix matrix = DoKIntMatrix.build(columnMajorData(), false, true); + + Assert.assertEquals(6, matrix.numRows()); + Assert.assertEquals(6, matrix.numColumns()); + + Assert.assertEquals(11, matrix.get(0, 0)); + Assert.assertEquals(12, matrix.get(0, 1)); + Assert.assertEquals(13, matrix.get(0, 2)); + Assert.assertEquals(14, matrix.get(0, 3)); + Assert.assertEquals(22, matrix.get(1, 1)); + Assert.assertEquals(23, matrix.get(1, 2)); + Assert.assertEquals(33, matrix.get(2, 2)); + Assert.assertEquals(34, matrix.get(2, 3)); + Assert.assertEquals(35, matrix.get(2, 4)); + Assert.assertEquals(36, matrix.get(2, 5)); + Assert.assertEquals(44, matrix.get(3, 3)); + Assert.assertEquals(45, matrix.get(3, 4)); + Assert.assertEquals(56, matrix.get(4, 5)); + Assert.assertEquals(66, matrix.get(5, 5)); + + Assert.assertEquals(0, matrix.get(5, 4)); + Assert.assertEquals(0, matrix.get(1, 0)); + Assert.assertEquals(0, matrix.get(1, 3)); + Assert.assertEquals(-1, matrix.get(1, 0, -1)); + } + + @Test + public void testDoKMatrixColumnMajorNonZeroOnlyFalse() { + DoKIntMatrix matrix = DoKIntMatrix.build(columnMajorData(), false, false); + + Assert.assertEquals(6, matrix.numRows()); + Assert.assertEquals(6, matrix.numColumns()); + + Assert.assertEquals(0, matrix.get(5, 4)); + Assert.assertEquals(0, matrix.get(1, 0)); + Assert.assertEquals(0, matrix.get(1, 3)); + Assert.assertEquals(0, matrix.get(1, 3, -1)); + Assert.assertEquals(-1, matrix.get(1, 0, -1)); + + matrix.setDefaultValue(-1); + Assert.assertEquals(-1, matrix.get(5, 4)); + Assert.assertEquals(-1, matrix.get(1, 0)); + Assert.assertEquals(0, matrix.get(1, 3)); + Assert.assertEquals(0, matrix.get(1, 0, 0)); + } + + @Test + public void testColumnMajorDenseMatrix() { + ColumnMajorDenseIntMatrix2d matrix = new ColumnMajorDenseIntMatrix2d(columnMajorData(), 6); + Assert.assertEquals(6, matrix.numRows()); + Assert.assertEquals(6, matrix.numColumns()); + + Assert.assertEquals(11, matrix.get(0, 0)); + Assert.assertEquals(12, matrix.get(0, 1)); + Assert.assertEquals(13, matrix.get(0, 2)); + Assert.assertEquals(14, matrix.get(0, 3)); + Assert.assertEquals(22, matrix.get(1, 1)); + Assert.assertEquals(23, matrix.get(1, 2)); + Assert.assertEquals(33, matrix.get(2, 2)); + Assert.assertEquals(34, matrix.get(2, 3)); + Assert.assertEquals(35, matrix.get(2, 4)); + Assert.assertEquals(36, matrix.get(2, 5)); + Assert.assertEquals(44, matrix.get(3, 3)); + Assert.assertEquals(45, matrix.get(3, 4)); + Assert.assertEquals(56, matrix.get(4, 5)); + Assert.assertEquals(66, matrix.get(5, 5)); + + Assert.assertEquals(0, matrix.get(5, 4)); + Assert.assertEquals(0, matrix.get(1, 0)); + Assert.assertEquals(0, matrix.get(1, 3)); + Assert.assertEquals(-1, matrix.get(1, 0, -1)); + } + + @Test + public void testColumnMajorDenseMatrixEachColumn() { + ColumnMajorDenseIntMatrix2d matrix = new ColumnMajorDenseIntMatrix2d(columnMajorData(), 6); + matrix.setDefaultValue(-1); + + final MutableInt count = new MutableInt(0); + for (int j = 0; j < 6; j++) { + matrix.eachInColumn(j, new VectorProcedure() { + @Override + public void apply(int i, int value) { + count.addValue(1); + } + }, false); + } + Assert.assertEquals(1 + 2 + 3 + 4 + 4 + 6, count.getValue()); + + count.setValue(0); + for (int j = 0; j < 6; j++) { + matrix.eachInColumn(j, new VectorProcedure() { + @Override + public void apply(int i, int value) { + count.addValue(1); + } + }, true); + } + Assert.assertEquals(6 * 6, count.getValue()); + + count.setValue(0); + for (int j = 0; j < 6; j++) { + matrix.eachNonZeroInColumn(j, new VectorProcedure() { + @Override + public void apply(int i, int value) { + count.addValue(1); + } + }); + } + Assert.assertEquals(1 + 2 + 3 + 3 + 2 + 3, count.getValue()); + + // change default value to zero + matrix.setDefaultValue(0); + + count.setValue(0); + for (int j = 0; j < 6; j++) { + matrix.eachInColumn(j, new VectorProcedure() { + @Override + public void apply(int i, int value) { + count.addValue(1); + } + }, false); + } + Assert.assertEquals(1 + 2 + 3 + 4 + 4 + 6, count.getValue()); + + count.setValue(0); + for (int j = 0; j < 6; j++) { + matrix.eachInColumn(j, new VectorProcedure() { + @Override + public void apply(int i, int value) { + count.addValue(1); + } + }, true); + } + Assert.assertEquals(6 * 6, count.getValue()); + + count.setValue(0); + for (int j = 0; j < 6; j++) { + matrix.eachNonZeroInColumn(j, new VectorProcedure() { + @Override + public void apply(int i, int value) { + count.addValue(1); + } + }); + } + Assert.assertEquals(1 + 2 + 3 + 3 + 2 + 3, count.getValue()); + } + + @Test + public void testDoKMatrixColumnMajorNonZeroOnlyFalseEachColumn() { + DoKIntMatrix matrix = DoKIntMatrix.build(columnMajorData(), false, false); + matrix.setDefaultValue(-1); + + final MutableInt count = new MutableInt(0); + for (int j = 0; j < 6; j++) { + matrix.eachInColumn(j, new VectorProcedure() { + @Override + public void apply(int i, int value) { + count.addValue(1); + } + }, false); + } + Assert.assertEquals(1 + 2 + 3 + 4 + 4 + 6, count.getValue()); + + count.setValue(0); + for (int j = 0; j < 6; j++) { + matrix.eachInColumn(j, new VectorProcedure() { + @Override + public void apply(int i, int value) { + count.addValue(1); + } + }, true); + } + Assert.assertEquals(6 * 6, count.getValue()); + + count.setValue(0); + for (int j = 0; j < 6; j++) { + matrix.eachNonZeroInColumn(j, new VectorProcedure() { + @Override + public void apply(int i, int value) { + count.addValue(1); + } + }); + } + Assert.assertEquals(1 + 2 + 3 + 3 + 2 + 3, count.getValue()); + + // change default value to zero + matrix.setDefaultValue(0); + + count.setValue(0); + for (int j = 0; j < 6; j++) { + matrix.eachInColumn(j, new VectorProcedure() { + @Override + public void apply(int i, int value) { + count.addValue(1); + } + }, false); + } + Assert.assertEquals(1 + 2 + 3 + 4 + 4 + 6, count.getValue()); + + count.setValue(0); + for (int j = 0; j < 6; j++) { + matrix.eachInColumn(j, new VectorProcedure() { + @Override + public void apply(int i, int value) { + count.addValue(1); + } + }, true); + } + Assert.assertEquals(6 * 6, count.getValue()); + + count.setValue(0); + for (int j = 0; j < 6; j++) { + matrix.eachNonZeroInColumn(j, new VectorProcedure() { + @Override + public void apply(int i, int value) { + count.addValue(1); + } + }); + } + Assert.assertEquals(1 + 2 + 3 + 3 + 2 + 3, count.getValue()); + } + + @Test + public void testDoKMatrixRowMajorNonZeroOnlyFalseEachColumn() { + DoKIntMatrix matrix = DoKIntMatrix.build(rowMajorData(), true, false); + matrix.setDefaultValue(-1); + + final MutableInt count = new MutableInt(0); + for (int i = 0; i < 6; i++) { + matrix.eachInRow(i, new VectorProcedure() { + @Override + public void apply(int i, int value) { + count.addValue(1); + } + }, false); + } + Assert.assertEquals(4 + 3 + 6 + 5 + 6 + 6, count.getValue()); + + count.setValue(0); + for (int i = 0; i < 6; i++) { + matrix.eachInRow(i, new VectorProcedure() { + @Override + public void apply(int i, int value) { + count.addValue(1); + } + }, true); + } + Assert.assertEquals(6 * 6, count.getValue()); + + count.setValue(0); + for (int i = 0; i < 6; i++) { + matrix.eachNonZeroInRow(i, new VectorProcedure() { + @Override + public void apply(int i, int value) { + count.addValue(1); + } + }); + } + Assert.assertEquals(4 + 2 + 4 + 2 + 1 + 1, count.getValue()); + } + + private static int[][] rowMajorData() { + /* + 11 12 13 14 0 0 + 0 22 23 0 0 0 + 0 0 33 34 35 36 + 0 0 0 44 45 0 + 0 0 0 0 0 56 + 0 0 0 0 0 66 + */ + int[][] data = new int[6][]; + data[0] = new int[] {11, 12, 13, 14}; + data[1] = new int[] {0, 22, 23}; + data[2] = new int[] {0, 0, 33, 34, 35, 36}; + data[3] = new int[] {0, 0, 0, 44, 45}; + data[4] = new int[] {0, 0, 0, 0, 0, 56}; + data[5] = new int[] {0, 0, 0, 0, 0, 66}; + return data; + } + + private static int[][] columnMajorData() { + /* + 11 12 13 14 0 0 + 0 22 23 0 0 0 + 0 0 33 34 35 36 + 0 0 0 44 45 0 + 0 0 0 0 0 56 + 0 0 0 0 0 66 + */ + int[][] data = new int[6][]; + data[0] = new int[] {11}; + data[1] = new int[] {12, 22}; + data[2] = new int[] {13, 23, 33}; + data[3] = new int[] {14, 0, 34, 44}; + data[4] = new int[] {0, 0, 35, 45}; + data[5] = new int[] {0, 0, 36, 0, 56, 66}; + return data; + } + +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/test/java/hivemall/matrix/MatrixBuilderTest.java ---------------------------------------------------------------------- diff --git a/core/src/test/java/hivemall/matrix/MatrixBuilderTest.java b/core/src/test/java/hivemall/matrix/MatrixBuilderTest.java deleted file mode 100644 index 5545631..0000000 --- a/core/src/test/java/hivemall/matrix/MatrixBuilderTest.java +++ /dev/null @@ -1,329 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -package hivemall.matrix; - -import org.junit.Assert; -import org.junit.Test; - -public class MatrixBuilderTest { - - @Test - public void testReadOnlyCSRMatrix() { - Matrix matrix = csrMatrix(); - Assert.assertEquals(6, matrix.numRows()); - Assert.assertEquals(6, matrix.numColumns()); - Assert.assertEquals(4, matrix.numColumns(0)); - Assert.assertEquals(2, matrix.numColumns(1)); - Assert.assertEquals(4, matrix.numColumns(2)); - Assert.assertEquals(2, matrix.numColumns(3)); - Assert.assertEquals(1, matrix.numColumns(4)); - Assert.assertEquals(1, matrix.numColumns(5)); - - Assert.assertEquals(11d, matrix.get(0, 0), 0.d); - Assert.assertEquals(12d, matrix.get(0, 1), 0.d); - Assert.assertEquals(13d, matrix.get(0, 2), 0.d); - Assert.assertEquals(14d, matrix.get(0, 3), 0.d); - Assert.assertEquals(22d, matrix.get(1, 1), 0.d); - Assert.assertEquals(23d, matrix.get(1, 2), 0.d); - Assert.assertEquals(33d, matrix.get(2, 2), 0.d); - Assert.assertEquals(34d, matrix.get(2, 3), 0.d); - Assert.assertEquals(35d, matrix.get(2, 4), 0.d); - Assert.assertEquals(36d, matrix.get(2, 5), 0.d); - Assert.assertEquals(44d, matrix.get(3, 3), 0.d); - Assert.assertEquals(45d, matrix.get(3, 4), 0.d); - Assert.assertEquals(56d, matrix.get(4, 5), 0.d); - Assert.assertEquals(66d, matrix.get(5, 5), 0.d); - - Assert.assertEquals(0.d, matrix.get(5, 4), 0.d); - Assert.assertEquals(-1.d, matrix.get(5, 4, -1.d), 0.d); - - matrix.setDefaultValue(Double.NaN); - Assert.assertEquals(Double.NaN, matrix.get(5, 4), 0.d); - } - - @Test - public void testReadOnlyCSRMatrixFromLibSVM() { - Matrix matrix = csrMatrixFromLibSVM(); - Assert.assertEquals(6, matrix.numRows()); - Assert.assertEquals(6, matrix.numColumns()); - Assert.assertEquals(4, matrix.numColumns(0)); - Assert.assertEquals(2, matrix.numColumns(1)); - Assert.assertEquals(4, matrix.numColumns(2)); - Assert.assertEquals(2, matrix.numColumns(3)); - Assert.assertEquals(1, matrix.numColumns(4)); - Assert.assertEquals(1, matrix.numColumns(5)); - - Assert.assertEquals(11d, matrix.get(0, 0), 0.d); - Assert.assertEquals(12d, matrix.get(0, 1), 0.d); - Assert.assertEquals(13d, matrix.get(0, 2), 0.d); - Assert.assertEquals(14d, matrix.get(0, 3), 0.d); - Assert.assertEquals(22d, matrix.get(1, 1), 0.d); - Assert.assertEquals(23d, matrix.get(1, 2), 0.d); - Assert.assertEquals(33d, matrix.get(2, 2), 0.d); - Assert.assertEquals(34d, matrix.get(2, 3), 0.d); - Assert.assertEquals(35d, matrix.get(2, 4), 0.d); - Assert.assertEquals(36d, matrix.get(2, 5), 0.d); - Assert.assertEquals(44d, matrix.get(3, 3), 0.d); - Assert.assertEquals(45d, matrix.get(3, 4), 0.d); - Assert.assertEquals(56d, matrix.get(4, 5), 0.d); - Assert.assertEquals(66d, matrix.get(5, 5), 0.d); - - Assert.assertEquals(0.d, matrix.get(5, 4), 0.d); - Assert.assertEquals(-1.d, matrix.get(5, 4, -1.d), 0.d); - - matrix.setDefaultValue(Double.NaN); - Assert.assertEquals(Double.NaN, matrix.get(5, 4), 0.d); - } - - - @Test - public void testReadOnlyCSRMatrixNoRow() { - CSRMatrixBuilder builder = new CSRMatrixBuilder(1024); - Matrix matrix = builder.buildMatrix(true); - Assert.assertEquals(0, matrix.numRows()); - Assert.assertEquals(0, matrix.numColumns()); - } - - @Test(expected = IndexOutOfBoundsException.class) - public void testReadOnlyCSRMatrixGetFail1() { - Matrix matrix = csrMatrix(); - matrix.get(7, 5); - } - - @Test(expected = IndexOutOfBoundsException.class) - public void testReadOnlyCSRMatrixGetFail2() { - Matrix matrix = csrMatrix(); - matrix.get(6, 7); - } - - @Test - public void testReadOnlyDenseMatrix2d() { - Matrix matrix = denseMatrix(); - Assert.assertEquals(6, matrix.numRows()); - Assert.assertEquals(6, matrix.numColumns()); - Assert.assertEquals(4, matrix.numColumns(0)); - Assert.assertEquals(3, matrix.numColumns(1)); - Assert.assertEquals(6, matrix.numColumns(2)); - Assert.assertEquals(5, matrix.numColumns(3)); - Assert.assertEquals(6, matrix.numColumns(4)); - Assert.assertEquals(6, matrix.numColumns(5)); - - Assert.assertEquals(11d, matrix.get(0, 0), 0.d); - Assert.assertEquals(12d, matrix.get(0, 1), 0.d); - Assert.assertEquals(13d, matrix.get(0, 2), 0.d); - Assert.assertEquals(14d, matrix.get(0, 3), 0.d); - Assert.assertEquals(22d, matrix.get(1, 1), 0.d); - Assert.assertEquals(23d, matrix.get(1, 2), 0.d); - Assert.assertEquals(33d, matrix.get(2, 2), 0.d); - Assert.assertEquals(34d, matrix.get(2, 3), 0.d); - Assert.assertEquals(35d, matrix.get(2, 4), 0.d); - Assert.assertEquals(36d, matrix.get(2, 5), 0.d); - Assert.assertEquals(44d, matrix.get(3, 3), 0.d); - Assert.assertEquals(45d, matrix.get(3, 4), 0.d); - Assert.assertEquals(56d, matrix.get(4, 5), 0.d); - Assert.assertEquals(66d, matrix.get(5, 5), 0.d); - - Assert.assertEquals(0.d, matrix.get(5, 4), 0.d); - - Assert.assertEquals(0.d, matrix.get(1, 0), 0.d); - Assert.assertEquals(0.d, matrix.get(1, 3), 0.d); - Assert.assertEquals(0.d, matrix.get(1, 0), 0.d); - } - - @Test - public void testReadOnlyDenseMatrix2dSparseInput() { - Matrix matrix = denseMatrixSparseInput(); - Assert.assertEquals(6, matrix.numRows()); - Assert.assertEquals(6, matrix.numColumns()); - Assert.assertEquals(4, matrix.numColumns(0)); - Assert.assertEquals(3, matrix.numColumns(1)); - Assert.assertEquals(6, matrix.numColumns(2)); - Assert.assertEquals(5, matrix.numColumns(3)); - Assert.assertEquals(6, matrix.numColumns(4)); - Assert.assertEquals(6, matrix.numColumns(5)); - - Assert.assertEquals(11d, matrix.get(0, 0), 0.d); - Assert.assertEquals(12d, matrix.get(0, 1), 0.d); - Assert.assertEquals(13d, matrix.get(0, 2), 0.d); - Assert.assertEquals(14d, matrix.get(0, 3), 0.d); - Assert.assertEquals(22d, matrix.get(1, 1), 0.d); - Assert.assertEquals(23d, matrix.get(1, 2), 0.d); - Assert.assertEquals(33d, matrix.get(2, 2), 0.d); - Assert.assertEquals(34d, matrix.get(2, 3), 0.d); - Assert.assertEquals(35d, matrix.get(2, 4), 0.d); - Assert.assertEquals(36d, matrix.get(2, 5), 0.d); - Assert.assertEquals(44d, matrix.get(3, 3), 0.d); - Assert.assertEquals(45d, matrix.get(3, 4), 0.d); - Assert.assertEquals(56d, matrix.get(4, 5), 0.d); - Assert.assertEquals(66d, matrix.get(5, 5), 0.d); - - Assert.assertEquals(0.d, matrix.get(5, 4), 0.d); - - Assert.assertEquals(0.d, matrix.get(1, 0), 0.d); - Assert.assertEquals(0.d, matrix.get(1, 3), 0.d); - Assert.assertEquals(0.d, matrix.get(1, 0), 0.d); - } - - @Test - public void testReadOnlyDenseMatrix2dFromLibSVM() { - Matrix matrix = denseMatrixFromLibSVM(); - Assert.assertEquals(6, matrix.numRows()); - Assert.assertEquals(6, matrix.numColumns()); - Assert.assertEquals(4, matrix.numColumns(0)); - Assert.assertEquals(3, matrix.numColumns(1)); - Assert.assertEquals(6, matrix.numColumns(2)); - Assert.assertEquals(5, matrix.numColumns(3)); - Assert.assertEquals(6, matrix.numColumns(4)); - Assert.assertEquals(6, matrix.numColumns(5)); - - Assert.assertEquals(11d, matrix.get(0, 0), 0.d); - Assert.assertEquals(12d, matrix.get(0, 1), 0.d); - Assert.assertEquals(13d, matrix.get(0, 2), 0.d); - Assert.assertEquals(14d, matrix.get(0, 3), 0.d); - Assert.assertEquals(22d, matrix.get(1, 1), 0.d); - Assert.assertEquals(23d, matrix.get(1, 2), 0.d); - Assert.assertEquals(33d, matrix.get(2, 2), 0.d); - Assert.assertEquals(34d, matrix.get(2, 3), 0.d); - Assert.assertEquals(35d, matrix.get(2, 4), 0.d); - Assert.assertEquals(36d, matrix.get(2, 5), 0.d); - Assert.assertEquals(44d, matrix.get(3, 3), 0.d); - Assert.assertEquals(45d, matrix.get(3, 4), 0.d); - Assert.assertEquals(56d, matrix.get(4, 5), 0.d); - Assert.assertEquals(66d, matrix.get(5, 5), 0.d); - - Assert.assertEquals(0.d, matrix.get(5, 4), 0.d); - - Assert.assertEquals(0.d, matrix.get(1, 0), 0.d); - Assert.assertEquals(0.d, matrix.get(1, 3), 0.d); - Assert.assertEquals(0.d, matrix.get(1, 0), 0.d); - } - - @Test - public void testReadOnlyDenseMatrix2dNoRow() { - Matrix matrix = new DenseMatrixBuilder(1024).buildMatrix(true); - Assert.assertEquals(0, matrix.numRows()); - Assert.assertEquals(0, matrix.numColumns()); - } - - @Test(expected = UnsupportedOperationException.class) - public void testReadOnlyDenseMatrix2dFailToChangeDefaultValue() { - Matrix matrix = denseMatrix(); - matrix.setDefaultValue(Double.NaN); - } - - @Test(expected = IndexOutOfBoundsException.class) - public void testReadOnlyDenseMatrix2dFailOutOfBound1() { - Matrix matrix = denseMatrix(); - matrix.get(7, 5); - } - - @Test(expected = IndexOutOfBoundsException.class) - public void testReadOnlyDenseMatrix2dFailOutOfBound2() { - Matrix matrix = denseMatrix(); - matrix.get(6, 7); - } - - private static Matrix csrMatrix() { - /* - 11 12 13 14 0 0 - 0 22 23 0 0 0 - 0 0 33 34 35 36 - 0 0 0 44 45 0 - 0 0 0 0 0 56 - 0 0 0 0 0 66 - */ - CSRMatrixBuilder builder = new CSRMatrixBuilder(1024); - builder.nextColumn(0, 11).nextColumn(1, 12).nextColumn(2, 13).nextColumn(3, 14).nextRow(); - builder.nextColumn(1, 22).nextColumn(2, 23).nextRow(); - builder.nextColumn(2, 33).nextColumn(3, 34).nextColumn(4, 35).nextColumn(5, 36).nextRow(); - builder.nextColumn(3, 44).nextColumn(4, 45).nextRow(); - builder.nextColumn(5, 56).nextRow(); - builder.nextColumn(5, 66).nextRow(); - return builder.buildMatrix(true); - } - - private static Matrix csrMatrixFromLibSVM() { - /* - 11 12 13 14 0 0 - 0 22 23 0 0 0 - 0 0 33 34 35 36 - 0 0 0 44 45 0 - 0 0 0 0 0 56 - 0 0 0 0 0 66 - */ - CSRMatrixBuilder builder = new CSRMatrixBuilder(1024); - builder.nextRow(new String[] {"0:11", "1:12", "2:13", "3:14"}); - builder.nextRow(new String[] {"1:22", "2:23"}); - builder.nextRow(new String[] {"2:33", "3:34", "4:35", "5:36"}); - builder.nextRow(new String[] {"3:44", "4:45"}); - builder.nextRow(new String[] {"5:56"}); - builder.nextRow(new String[] {"5:66"}); - return builder.buildMatrix(true); - } - - private static Matrix denseMatrix() { - /* - 11 12 13 14 0 0 - 0 22 23 0 0 0 - 0 0 33 34 35 36 - 0 0 0 44 45 0 - 0 0 0 0 0 56 - 0 0 0 0 0 66 - */ - DenseMatrixBuilder builder = new DenseMatrixBuilder(1024); - builder.nextRow(new double[] {11, 12, 13, 14}); - builder.nextRow(new double[] {0, 22, 23}); - builder.nextRow(new double[] {0, 0, 33, 34, 35, 36}); - builder.nextRow(new double[] {0, 0, 0, 44, 45}); - builder.nextRow(new double[] {0, 0, 0, 0, 0, 56}); - builder.nextRow(new double[] {0, 0, 0, 0, 0, 66}); - return builder.buildMatrix(true); - } - - private static Matrix denseMatrixSparseInput() { - /* - 11 12 13 14 0 0 - 0 22 23 0 0 0 - 0 0 33 34 35 36 - 0 0 0 44 45 0 - 0 0 0 0 0 56 - 0 0 0 0 0 66 - */ - DenseMatrixBuilder builder = new DenseMatrixBuilder(1024); - builder.nextColumn(0, 11).nextColumn(1, 12).nextColumn(2, 13).nextColumn(3, 14).nextRow(); - builder.nextColumn(1, 22).nextColumn(2, 23).nextRow(); - builder.nextColumn(2, 33).nextColumn(3, 34).nextColumn(4, 35).nextColumn(5, 36).nextRow(); - builder.nextColumn(3, 44).nextColumn(4, 45).nextRow(); - builder.nextColumn(5, 56).nextRow(); - builder.nextColumn(5, 66).nextRow(); - return builder.buildMatrix(true); - } - - private static Matrix denseMatrixFromLibSVM() { - DenseMatrixBuilder builder = new DenseMatrixBuilder(1024); - builder.nextRow(new String[] {"0:11", "1:12", "2:13", "3:14"}); - builder.nextRow(new String[] {"1:22", "2:23"}); - builder.nextRow(new String[] {"2:33", "3:34", "4:35", "5:36"}); - builder.nextRow(new String[] {"3:44", "4:45"}); - builder.nextRow(new String[] {"5:56"}); - builder.nextRow(new String[] {"5:66"}); - return builder.buildMatrix(true); - } - -} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/test/java/hivemall/smile/classification/DecisionTreeTest.java ---------------------------------------------------------------------- diff --git a/core/src/test/java/hivemall/smile/classification/DecisionTreeTest.java b/core/src/test/java/hivemall/smile/classification/DecisionTreeTest.java index 3c6116c..bb6de6b 100644 --- a/core/src/test/java/hivemall/smile/classification/DecisionTreeTest.java +++ b/core/src/test/java/hivemall/smile/classification/DecisionTreeTest.java @@ -19,13 +19,13 @@ package hivemall.smile.classification; import static org.junit.Assert.assertEquals; -import hivemall.smile.ModelType; +import hivemall.math.matrix.Matrix; +import hivemall.math.matrix.builders.CSRMatrixBuilder; +import hivemall.math.matrix.dense.RowMajorDenseMatrix2d; +import hivemall.math.random.RandomNumberGeneratorFactory; import hivemall.smile.classification.DecisionTree.Node; import hivemall.smile.data.Attribute; -import hivemall.smile.tools.TreePredictUDF; import hivemall.smile.utils.SmileExtUtils; -import hivemall.smile.vm.StackMachine; -import hivemall.utils.lang.ArrayUtils; import java.io.BufferedInputStream; import java.io.IOException; @@ -33,14 +33,9 @@ import java.io.InputStream; import java.net.URL; import java.text.ParseException; +import javax.annotation.Nonnull; + import org.apache.hadoop.hive.ql.metadata.HiveException; -import org.apache.hadoop.hive.ql.udf.generic.GenericUDF.DeferredJavaObject; -import org.apache.hadoop.hive.ql.udf.generic.GenericUDF.DeferredObject; -import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; -import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory; -import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils; -import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; -import org.apache.hadoop.io.IntWritable; import org.junit.Assert; import org.junit.Test; @@ -52,85 +47,76 @@ import smile.validation.LOOCV; public class DecisionTreeTest { private static final boolean DEBUG = false; - /** - * Test of learn method, of class DecisionTree. - * - * @throws ParseException - * @throws IOException - */ @Test public void testWeather() throws IOException, ParseException { - URL url = new URL( - "https://gist.githubusercontent.com/myui/2c9df50db3de93a71b92/raw/3f6b4ecfd4045008059e1a2d1c4064fb8a3d372a/weather.nominal.arff"); - InputStream is = new BufferedInputStream(url.openStream()); - - ArffParser arffParser = new ArffParser(); - arffParser.setResponseIndex(4); - - AttributeDataset weather = arffParser.parse(is); - double[][] x = weather.toArray(new double[weather.size()][]); - int[] y = weather.toArray(new int[weather.size()]); - - int n = x.length; - LOOCV loocv = new LOOCV(n); - int error = 0; - for (int i = 0; i < n; i++) { - double[][] trainx = Math.slice(x, loocv.train[i]); - int[] trainy = Math.slice(y, loocv.train[i]); + int responseIndex = 4; + int numLeafs = 3; - Attribute[] attrs = SmileExtUtils.convertAttributeTypes(weather.attributes()); - DecisionTree tree = new DecisionTree(attrs, trainx, trainy, 3); - if (y[loocv.test[i]] != tree.predict(x[loocv.test[i]])) - error++; - } + // dense matrix + int error = run( + "https://gist.githubusercontent.com/myui/2c9df50db3de93a71b92/raw/3f6b4ecfd4045008059e1a2d1c4064fb8a3d372a/weather.nominal.arff", + responseIndex, numLeafs, true); + assertEquals(5, error); - debugPrint("Decision Tree error = " + error); + // sparse matrix + error = run( + "https://gist.githubusercontent.com/myui/2c9df50db3de93a71b92/raw/3f6b4ecfd4045008059e1a2d1c4064fb8a3d372a/weather.nominal.arff", + responseIndex, numLeafs, false); assertEquals(5, error); } @Test public void testIris() throws IOException, ParseException { - URL url = new URL( - "https://gist.githubusercontent.com/myui/143fa9d05bd6e7db0114/raw/500f178316b802f1cade6e3bf8dc814a96e84b1e/iris.arff"); - InputStream is = new BufferedInputStream(url.openStream()); - - ArffParser arffParser = new ArffParser(); - arffParser.setResponseIndex(4); - - AttributeDataset iris = arffParser.parse(is); - double[][] x = iris.toArray(new double[iris.size()][]); - int[] y = iris.toArray(new int[iris.size()]); - - int n = x.length; - LOOCV loocv = new LOOCV(n); - int error = 0; - for (int i = 0; i < n; i++) { - double[][] trainx = Math.slice(x, loocv.train[i]); - int[] trainy = Math.slice(y, loocv.train[i]); - - Attribute[] attrs = SmileExtUtils.convertAttributeTypes(iris.attributes()); - smile.math.Random rand = new smile.math.Random(i); - DecisionTree tree = new DecisionTree(attrs, trainx, trainy, Integer.MAX_VALUE, rand); - if (y[loocv.test[i]] != tree.predict(x[loocv.test[i]])) - error++; - } + int responseIndex = 4; + int numLeafs = Integer.MAX_VALUE; + int error = run( + "https://gist.githubusercontent.com/myui/143fa9d05bd6e7db0114/raw/500f178316b802f1cade6e3bf8dc814a96e84b1e/iris.arff", + responseIndex, numLeafs, true); + assertEquals(8, error); - debugPrint("Decision Tree error = " + error); + // sparse + error = run( + "https://gist.githubusercontent.com/myui/143fa9d05bd6e7db0114/raw/500f178316b802f1cade6e3bf8dc814a96e84b1e/iris.arff", + responseIndex, numLeafs, false); assertEquals(8, error); } @Test + public void testIrisSparseDenseEquals() throws IOException, ParseException { + int responseIndex = 4; + int numLeafs = Integer.MAX_VALUE; + runAndCompareSparseAndDense( + "https://gist.githubusercontent.com/myui/143fa9d05bd6e7db0114/raw/500f178316b802f1cade6e3bf8dc814a96e84b1e/iris.arff", + responseIndex, numLeafs); + } + + @Test public void testIrisDepth4() throws IOException, ParseException { - URL url = new URL( - "https://gist.githubusercontent.com/myui/143fa9d05bd6e7db0114/raw/500f178316b802f1cade6e3bf8dc814a96e84b1e/iris.arff"); + int responseIndex = 4; + int numLeafs = 4; + int error = run( + "https://gist.githubusercontent.com/myui/143fa9d05bd6e7db0114/raw/500f178316b802f1cade6e3bf8dc814a96e84b1e/iris.arff", + responseIndex, numLeafs, true); + assertEquals(7, error); + + // sparse + error = run( + "https://gist.githubusercontent.com/myui/143fa9d05bd6e7db0114/raw/500f178316b802f1cade6e3bf8dc814a96e84b1e/iris.arff", + responseIndex, numLeafs, false); + assertEquals(7, error); + } + + private static int run(String datasetUrl, int responseIndex, int numLeafs, boolean dense) + throws IOException, ParseException { + URL url = new URL(datasetUrl); InputStream is = new BufferedInputStream(url.openStream()); ArffParser arffParser = new ArffParser(); - arffParser.setResponseIndex(4); + arffParser.setResponseIndex(responseIndex); - AttributeDataset iris = arffParser.parse(is); - double[][] x = iris.toArray(new double[iris.size()][]); - int[] y = iris.toArray(new int[iris.size()]); + AttributeDataset ds = arffParser.parse(is); + double[][] x = ds.toArray(new double[ds.size()][]); + int[] y = ds.toArray(new int[ds.size()]); int n = x.length; LOOCV loocv = new LOOCV(n); @@ -139,52 +125,29 @@ public class DecisionTreeTest { double[][] trainx = Math.slice(x, loocv.train[i]); int[] trainy = Math.slice(y, loocv.train[i]); - Attribute[] attrs = SmileExtUtils.convertAttributeTypes(iris.attributes()); - DecisionTree tree = new DecisionTree(attrs, trainx, trainy, 4); - if (y[loocv.test[i]] != tree.predict(x[loocv.test[i]])) + Attribute[] attrs = SmileExtUtils.convertAttributeTypes(ds.attributes()); + DecisionTree tree = new DecisionTree(attrs, matrix(trainx, dense), trainy, numLeafs, + RandomNumberGeneratorFactory.createPRNG(i)); + if (y[loocv.test[i]] != tree.predict(x[loocv.test[i]])) { error++; + } } debugPrint("Decision Tree error = " + error); - assertEquals(7, error); + return error; } - @Test - public void testIrisStackmachine() throws IOException, ParseException, HiveException { - URL url = new URL( - "https://gist.githubusercontent.com/myui/143fa9d05bd6e7db0114/raw/500f178316b802f1cade6e3bf8dc814a96e84b1e/iris.arff"); + private static void runAndCompareSparseAndDense(String datasetUrl, int responseIndex, + int numLeafs) throws IOException, ParseException { + URL url = new URL(datasetUrl); InputStream is = new BufferedInputStream(url.openStream()); ArffParser arffParser = new ArffParser(); - arffParser.setResponseIndex(4); - AttributeDataset iris = arffParser.parse(is); - double[][] x = iris.toArray(new double[iris.size()][]); - int[] y = iris.toArray(new int[iris.size()]); - - int n = x.length; - LOOCV loocv = new LOOCV(n); - for (int i = 0; i < n; i++) { - double[][] trainx = Math.slice(x, loocv.train[i]); - int[] trainy = Math.slice(y, loocv.train[i]); - - Attribute[] attrs = SmileExtUtils.convertAttributeTypes(iris.attributes()); - DecisionTree tree = new DecisionTree(attrs, trainx, trainy, 4); - assertEquals(tree.predict(x[loocv.test[i]]), - predictByStackMachine(tree, x[loocv.test[i]])); - } - } - - @Test - public void testIrisJavascript() throws IOException, ParseException, HiveException { - URL url = new URL( - "https://gist.githubusercontent.com/myui/143fa9d05bd6e7db0114/raw/500f178316b802f1cade6e3bf8dc814a96e84b1e/iris.arff"); - InputStream is = new BufferedInputStream(url.openStream()); + arffParser.setResponseIndex(responseIndex); - ArffParser arffParser = new ArffParser(); - arffParser.setResponseIndex(4); - AttributeDataset iris = arffParser.parse(is); - double[][] x = iris.toArray(new double[iris.size()][]); - int[] y = iris.toArray(new int[iris.size()]); + AttributeDataset ds = arffParser.parse(is); + double[][] x = ds.toArray(new double[ds.size()][]); + int[] y = ds.toArray(new int[ds.size()]); int n = x.length; LOOCV loocv = new LOOCV(n); @@ -192,10 +155,12 @@ public class DecisionTreeTest { double[][] trainx = Math.slice(x, loocv.train[i]); int[] trainy = Math.slice(y, loocv.train[i]); - Attribute[] attrs = SmileExtUtils.convertAttributeTypes(iris.attributes()); - DecisionTree tree = new DecisionTree(attrs, trainx, trainy, 4); - assertEquals(tree.predict(x[loocv.test[i]]), - predictByJavascript(tree, x[loocv.test[i]])); + Attribute[] attrs = SmileExtUtils.convertAttributeTypes(ds.attributes()); + DecisionTree dtree = new DecisionTree(attrs, matrix(trainx, true), trainy, numLeafs, + RandomNumberGeneratorFactory.createPRNG(i)); + DecisionTree stree = new DecisionTree(attrs, matrix(trainx, false), trainy, numLeafs, + RandomNumberGeneratorFactory.createPRNG(i)); + Assert.assertEquals(dtree.predict(x[loocv.test[i]]), stree.predict(x[loocv.test[i]])); } } @@ -218,7 +183,7 @@ public class DecisionTreeTest { int[] trainy = Math.slice(y, loocv.train[i]); Attribute[] attrs = SmileExtUtils.convertAttributeTypes(iris.attributes()); - DecisionTree tree = new DecisionTree(attrs, trainx, trainy, 4); + DecisionTree tree = new DecisionTree(attrs, matrix(trainx, true), trainy, 4); byte[] b = tree.predictSerCodegen(false); Node node = DecisionTree.deserializeNode(b, b.length, false); @@ -245,7 +210,7 @@ public class DecisionTreeTest { int[] trainy = Math.slice(y, loocv.train[i]); Attribute[] attrs = SmileExtUtils.convertAttributeTypes(iris.attributes()); - DecisionTree tree = new DecisionTree(attrs, trainx, trainy, 4); + DecisionTree tree = new DecisionTree(attrs, matrix(trainx, true), trainy, 4); byte[] b1 = tree.predictSerCodegen(true); byte[] b2 = tree.predictSerCodegen(false); @@ -256,52 +221,18 @@ public class DecisionTreeTest { } } - private static int predictByStackMachine(DecisionTree tree, double[] x) throws HiveException, - IOException { - String script = tree.predictOpCodegen(StackMachine.SEP); - debugPrint(script); - - TreePredictUDF udf = new TreePredictUDF(); - udf.initialize(new ObjectInspector[] { - PrimitiveObjectInspectorFactory.javaStringObjectInspector, - PrimitiveObjectInspectorFactory.javaIntObjectInspector, - PrimitiveObjectInspectorFactory.javaStringObjectInspector, - ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.javaDoubleObjectInspector), - ObjectInspectorUtils.getConstantObjectInspector( - PrimitiveObjectInspectorFactory.javaBooleanObjectInspector, true)}); - DeferredObject[] arguments = new DeferredObject[] {new DeferredJavaObject("model_id#1"), - new DeferredJavaObject(ModelType.opscode.getId()), new DeferredJavaObject(script), - new DeferredJavaObject(ArrayUtils.toList(x)), new DeferredJavaObject(true)}; - - IntWritable result = (IntWritable) udf.evaluate(arguments); - result = (IntWritable) udf.evaluate(arguments); - udf.close(); - return result.get(); - } - - private static int predictByJavascript(DecisionTree tree, double[] x) throws HiveException, - IOException { - String script = tree.predictJsCodegen(); - debugPrint(script); - - TreePredictUDF udf = new TreePredictUDF(); - udf.initialize(new ObjectInspector[] { - PrimitiveObjectInspectorFactory.javaStringObjectInspector, - PrimitiveObjectInspectorFactory.javaIntObjectInspector, - PrimitiveObjectInspectorFactory.javaStringObjectInspector, - ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.javaDoubleObjectInspector), - ObjectInspectorUtils.getConstantObjectInspector( - PrimitiveObjectInspectorFactory.javaBooleanObjectInspector, true)}); - - DeferredObject[] arguments = new DeferredObject[] {new DeferredJavaObject("model_id#1"), - new DeferredJavaObject(ModelType.javascript.getId()), - new DeferredJavaObject(script), new DeferredJavaObject(ArrayUtils.toList(x)), - new DeferredJavaObject(true)}; - - IntWritable result = (IntWritable) udf.evaluate(arguments); - result = (IntWritable) udf.evaluate(arguments); - udf.close(); - return result.get(); + @Nonnull + private static Matrix matrix(@Nonnull final double[][] x, boolean dense) { + if (dense) { + return new RowMajorDenseMatrix2d(x, x[0].length); + } else { + int numRows = x.length; + CSRMatrixBuilder builder = new CSRMatrixBuilder(1024); + for (int i = 0; i < numRows; i++) { + builder.nextRow(x[i]); + } + return builder.buildMatrix(); + } } private static void debugPrint(String msg) {
