http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/test/java/org/apache/mahout/common/IntPairWritableTest.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/test/java/org/apache/mahout/common/IntPairWritableTest.java b/community/mahout-mr/src/test/java/org/apache/mahout/common/IntPairWritableTest.java new file mode 100644 index 0000000..ceffe3e --- /dev/null +++ b/community/mahout-mr/src/test/java/org/apache/mahout/common/IntPairWritableTest.java @@ -0,0 +1,114 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.common; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.DataInput; +import java.io.DataInputStream; +import java.io.DataOutput; +import java.io.DataOutputStream; +import java.util.Arrays; + +import org.junit.Test; + +public final class IntPairWritableTest extends MahoutTestCase { + + @Test + public void testGetSet() { + IntPairWritable n = new IntPairWritable(); + + assertEquals(0, n.getFirst()); + assertEquals(0, n.getSecond()); + + n.setFirst(5); + n.setSecond(10); + + assertEquals(5, n.getFirst()); + assertEquals(10, n.getSecond()); + + n = new IntPairWritable(2,4); + + assertEquals(2, n.getFirst()); + assertEquals(4, n.getSecond()); + } + + @Test + public void testWritable() throws Exception { + IntPairWritable one = new IntPairWritable(1,2); + IntPairWritable two = new IntPairWritable(3,4); + + assertEquals(1, one.getFirst()); + assertEquals(2, one.getSecond()); + + assertEquals(3, two.getFirst()); + assertEquals(4, two.getSecond()); + + + ByteArrayOutputStream bout = new ByteArrayOutputStream(); + DataOutput out = new DataOutputStream(bout); + + two.write(out); + + byte[] b = bout.toByteArray(); + + ByteArrayInputStream bin = new ByteArrayInputStream(b); + DataInput din = new DataInputStream(bin); + + one.readFields(din); + + assertEquals(two.getFirst(), one.getFirst()); + assertEquals(two.getSecond(), one.getSecond()); + } + + @Test + public void testComparable() { + IntPairWritable[] input = { + new IntPairWritable(2,3), + new IntPairWritable(2,2), + new IntPairWritable(1,3), + new IntPairWritable(1,2), + new IntPairWritable(2,1), + new IntPairWritable(2,2), + new IntPairWritable(1,-2), + new IntPairWritable(1,-1), + new IntPairWritable(-2,-2), + new IntPairWritable(-2,-1), + new IntPairWritable(-1,-1), + new IntPairWritable(-1,-2), + new IntPairWritable(Integer.MAX_VALUE,1), + new IntPairWritable(Integer.MAX_VALUE/2,1), + new IntPairWritable(Integer.MIN_VALUE,1), + new 
IntPairWritable(Integer.MIN_VALUE/2,1) + + }; + + IntPairWritable[] sorted = new IntPairWritable[input.length]; + System.arraycopy(input, 0, sorted, 0, input.length); + Arrays.sort(sorted); + + int[] expected = { + 14, 15, 8, 9, 11, 10, 6, 7, 3, 2, 4, 1, 5, 0, 13, 12 + }; + + for (int i=0; i < input.length; i++) { + assertSame(input[expected[i]], sorted[i]); + } + + } +}
http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/test/java/org/apache/mahout/common/MahoutTestCase.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/test/java/org/apache/mahout/common/MahoutTestCase.java b/community/mahout-mr/src/test/java/org/apache/mahout/common/MahoutTestCase.java new file mode 100644 index 0000000..775c8d8 --- /dev/null +++ b/community/mahout-mr/src/test/java/org/apache/mahout/common/MahoutTestCase.java @@ -0,0 +1,148 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.common; + +import java.io.File; +import java.io.FileOutputStream; +import java.io.IOException; +import java.io.OutputStreamWriter; +import java.io.Writer; +import java.lang.reflect.Field; + +import com.google.common.base.Charsets; +import com.google.common.io.Closeables; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.FileSystem; +import org.apache.hadoop.fs.Path; +import org.junit.After; +import org.junit.Before; + +public class MahoutTestCase extends org.apache.mahout.math.MahoutTestCase { + + /** "Close enough" value for floating-point comparisons. 
*/ + public static final double EPSILON = 0.000001; + + private Path testTempDirPath; + private FileSystem fs; + + @Override + @Before + public void setUp() throws Exception { + super.setUp(); + RandomUtils.useTestSeed(); + testTempDirPath = null; + fs = null; + } + + /* Deletes the temp directory created by getTestTempDirPath(), if any, resets the cached path and file system, then delegates to super.tearDown(). */ @Override + @After + public void tearDown() throws Exception { + if (testTempDirPath != null) { + try { + fs.delete(testTempDirPath, true); + } catch (IOException e) { + throw new IllegalStateException("Test file not found", e); /* chain the IOException so the real delete failure stays visible */ + } + testTempDirPath = null; + fs = null; + } + super.tearDown(); + } + + public final Configuration getConfiguration() throws IOException { + Configuration conf = new Configuration(); + conf.set("hadoop.tmp.dir", getTestTempDir("hadoop" + Math.random()).getAbsolutePath()); + return conf; + } + + protected final Path getTestTempDirPath() throws IOException { + if (testTempDirPath == null) { + fs = FileSystem.get(getConfiguration()); + long simpleRandomLong = (long) (Long.MAX_VALUE * Math.random()); + testTempDirPath = fs.makeQualified( + new Path("/tmp/mahout-" + getClass().getSimpleName() + '-' + simpleRandomLong)); + if (!fs.mkdirs(testTempDirPath)) { + throw new IOException("Could not create " + testTempDirPath); + } + fs.deleteOnExit(testTempDirPath); + } + return testTempDirPath; + } + + protected final Path getTestTempFilePath(String name) throws IOException { + return getTestTempFileOrDirPath(name, false); + } + + protected final Path getTestTempDirPath(String name) throws IOException { + return getTestTempFileOrDirPath(name, true); + } + + private Path getTestTempFileOrDirPath(String name, boolean dir) throws IOException { + Path testTempDirPath = getTestTempDirPath(); + Path tempFileOrDir = fs.makeQualified(new Path(testTempDirPath, name)); + fs.deleteOnExit(tempFileOrDir); + if (dir && !fs.mkdirs(tempFileOrDir)) { + throw new IOException("Could not create " + tempFileOrDir); + } + return tempFileOrDir; + } + + /** + * Try to directly set a (possibly private) field on 
an Object + */ + protected static void setField(Object target, String fieldname, Object value) + throws NoSuchFieldException, IllegalAccessException { + Field field = findDeclaredField(target.getClass(), fieldname); + field.setAccessible(true); + field.set(target, value); + } + + /** + * Find a declared field in a class or one of its superclasses; the field name is matched case-insensitively and a NoSuchFieldException naming the field is thrown when no class below Object declares it + */ + private static Field findDeclaredField(Class<?> inClass, String fieldname) throws NoSuchFieldException { + while (!Object.class.equals(inClass)) { + for (Field field : inClass.getDeclaredFields()) { + if (field.getName().equalsIgnoreCase(fieldname)) { + return field; + } + } + inClass = inClass.getSuperclass(); + } + throw new NoSuchFieldException(fieldname); /* report which field could not be found */ + } + + /** + * @return a job option key string (--name) from the given option name + */ + protected static String optKey(String optionName) { + return AbstractJob.keyFor(optionName); + } + + protected static void writeLines(File file, String... lines) throws IOException { + Writer writer = new OutputStreamWriter(new FileOutputStream(file), Charsets.UTF_8); + try { + for (String line : lines) { + writer.write(line); + writer.write('\n'); + } + } finally { + Closeables.close(writer, false); + } + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/test/java/org/apache/mahout/common/MockIterator.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/test/java/org/apache/mahout/common/MockIterator.java b/community/mahout-mr/src/test/java/org/apache/mahout/common/MockIterator.java new file mode 100644 index 0000000..ce48fdc --- /dev/null +++ b/community/mahout-mr/src/test/java/org/apache/mahout/common/MockIterator.java @@ -0,0 +1,51 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.common; + +import org.apache.hadoop.io.DataInputBuffer; +import org.apache.hadoop.mapred.RawKeyValueIterator; +import org.apache.hadoop.util.Progress; + +/** Stub RawKeyValueIterator for tests: close() does nothing, getKey(), getValue() and getProgress() return null, and next() always returns true. */ public final class MockIterator implements RawKeyValueIterator { + + @Override + public void close() { + } + + @Override + public DataInputBuffer getKey() { + return null; + } + + @Override + public Progress getProgress() { + return null; + } + + @Override + public DataInputBuffer getValue() { + + return null; + } + + @Override + public boolean next() { + return true; + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/test/java/org/apache/mahout/common/StringUtilsTest.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/test/java/org/apache/mahout/common/StringUtilsTest.java b/community/mahout-mr/src/test/java/org/apache/mahout/common/StringUtilsTest.java new file mode 100644 index 0000000..0633685 --- /dev/null +++ b/community/mahout-mr/src/test/java/org/apache/mahout/common/StringUtilsTest.java @@ -0,0 +1,70 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.common; + +import com.google.common.collect.Lists; +import org.junit.Test; + +import java.util.List; + +public final class StringUtilsTest extends MahoutTestCase { + + private static class DummyTest { + private int field; + + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + if (!(obj instanceof DummyTest)) { + return false; + } + + DummyTest dt = (DummyTest) obj; + return field == dt.field; + } + + @Override + public int hashCode() { + return field; + } + + public int getField() { + return field; + } + } + + @Test + public void testStringConversion() throws Exception { + + List<String> expected = Lists.newArrayList("A", "B", "C"); + assertEquals(expected, StringUtils.fromString(StringUtils + .toString(expected))); + + // test a non serializable object + DummyTest test = new DummyTest(); + assertEquals(test, StringUtils.fromString(StringUtils.toString(test))); + } + + @Test + public void testEscape() throws Exception { + String res = StringUtils.escapeXML("\",\',&,>,<"); + assertEquals("_,_,_,_,_", res); + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/test/java/org/apache/mahout/common/distance/CosineDistanceMeasureTest.java ---------------------------------------------------------------------- diff --git 
a/community/mahout-mr/src/test/java/org/apache/mahout/common/distance/CosineDistanceMeasureTest.java b/community/mahout-mr/src/test/java/org/apache/mahout/common/distance/CosineDistanceMeasureTest.java new file mode 100644 index 0000000..6db7c9b --- /dev/null +++ b/community/mahout-mr/src/test/java/org/apache/mahout/common/distance/CosineDistanceMeasureTest.java @@ -0,0 +1,66 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.common.distance; + +import org.apache.mahout.common.MahoutTestCase; +import org.apache.mahout.math.DenseVector; +import org.apache.mahout.math.SequentialAccessSparseVector; +import org.apache.mahout.math.Vector; +import org.junit.Test; + +public final class CosineDistanceMeasureTest extends MahoutTestCase { + + @Test + public void testMeasure() { + + DistanceMeasure distanceMeasure = new CosineDistanceMeasure(); + + Vector[] vectors = { + new DenseVector(new double[]{1, 0, 0, 0, 0, 0}), + new DenseVector(new double[]{1, 1, 1, 0, 0, 0}), + new DenseVector(new double[]{1, 1, 1, 1, 1, 1}) + }; + + double[][] distanceMatrix = new double[3][3]; + + for (int a = 0; a < 3; a++) { + for (int b = 0; b < 3; b++) { + distanceMatrix[a][b] = distanceMeasure.distance(vectors[a], vectors[b]); + } + } + + assertEquals(0.0, distanceMatrix[0][0], EPSILON); + assertTrue(distanceMatrix[0][0] < distanceMatrix[0][1]); + assertTrue(distanceMatrix[0][1] < distanceMatrix[0][2]); + + assertEquals(0.0, distanceMatrix[1][1], EPSILON); + assertTrue(distanceMatrix[1][0] > distanceMatrix[1][1]); + assertTrue(distanceMatrix[1][2] < distanceMatrix[1][0]); + + assertEquals(0.0, distanceMatrix[2][2], EPSILON); + assertTrue(distanceMatrix[2][0] > distanceMatrix[2][1]); + assertTrue(distanceMatrix[2][1] > distanceMatrix[2][2]); + + // Two equal vectors (despite them being zero) should have 0 distance. 
+ assertEquals(0, + distanceMeasure.distance(new SequentialAccessSparseVector(1), + new SequentialAccessSparseVector(1)), + EPSILON); + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/test/java/org/apache/mahout/common/distance/DefaultDistanceMeasureTest.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/test/java/org/apache/mahout/common/distance/DefaultDistanceMeasureTest.java b/community/mahout-mr/src/test/java/org/apache/mahout/common/distance/DefaultDistanceMeasureTest.java new file mode 100644 index 0000000..ad1608c --- /dev/null +++ b/community/mahout-mr/src/test/java/org/apache/mahout/common/distance/DefaultDistanceMeasureTest.java @@ -0,0 +1,103 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.common.distance; + +import org.apache.mahout.common.MahoutTestCase; +import org.apache.mahout.math.DenseVector; +import org.apache.mahout.math.RandomAccessSparseVector; +import org.apache.mahout.math.Vector; +import org.junit.Test; + +public abstract class DefaultDistanceMeasureTest extends MahoutTestCase { + + protected abstract DistanceMeasure distanceMeasureFactory(); + + @Test + public void testMeasure() { + + DistanceMeasure distanceMeasure = distanceMeasureFactory(); + + Vector[] vectors = { + new DenseVector(new double[]{1, 1, 1, 1, 1, 1}), + new DenseVector(new double[]{2, 2, 2, 2, 2, 2}), + new DenseVector(new double[]{6, 6, 6, 6, 6, 6}), + new DenseVector(new double[]{-1,-1,-1,-1,-1,-1}) + }; + + compare(distanceMeasure, vectors); + + vectors = new Vector[4]; + + vectors[0] = new RandomAccessSparseVector(5); + vectors[0].setQuick(0, 1); + vectors[0].setQuick(3, 1); + vectors[0].setQuick(4, 1); + + vectors[1] = new RandomAccessSparseVector(5); + vectors[1].setQuick(0, 2); + vectors[1].setQuick(3, 2); + vectors[1].setQuick(4, 2); + + vectors[2] = new RandomAccessSparseVector(5); + vectors[2].setQuick(0, 6); + vectors[2].setQuick(3, 6); + vectors[2].setQuick(4, 6); + + vectors[3] = new RandomAccessSparseVector(5); + + compare(distanceMeasure, vectors); + } + + private static void compare(DistanceMeasure distanceMeasure, Vector[] vectors) { + double[][] distanceMatrix = new double[4][4]; + + for (int a = 0; a < 4; a++) { + for (int b = 0; b < 4; b++) { + distanceMatrix[a][b] = distanceMeasure.distance(vectors[a], vectors[b]); + } + } + + assertEquals("Distance from first vector to itself is not zero", 0.0, distanceMatrix[0][0], EPSILON); + assertTrue(distanceMatrix[0][0] < distanceMatrix[0][1]); + assertTrue(distanceMatrix[0][1] < distanceMatrix[0][2]); + + assertEquals("Distance from second vector to itself is not zero", 0.0, distanceMatrix[1][1], EPSILON); + assertTrue(distanceMatrix[1][0] > distanceMatrix[1][1]); + 
assertTrue(distanceMatrix[1][2] > distanceMatrix[1][0]); + + assertEquals("Distance from third vector to itself is not zero", 0.0, distanceMatrix[2][2], EPSILON); + assertTrue(distanceMatrix[2][0] > distanceMatrix[2][1]); + assertTrue(distanceMatrix[2][1] > distanceMatrix[2][2]); + + for (int a = 0; a < 4; a++) { + for (int b = 0; b < 4; b++) { + assertTrue("Distance between vectors less than zero: " + + distanceMatrix[a][b] + " = " + distanceMeasure + + ".distance("+ vectors[a].asFormatString() + ", " + + vectors[b].asFormatString() + ')', + distanceMatrix[a][b] >= 0); + if (vectors[a].plus(vectors[b]).norm(2) == 0 && vectors[a].norm(2) > 0) { + assertTrue("Distance from v to -v is equal to zero" + + vectors[a].asFormatString() + " = -" + vectors[b].asFormatString(), + distanceMatrix[a][b] > 0); + } + } + } + + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/test/java/org/apache/mahout/common/distance/DefaultWeightedDistanceMeasureTest.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/test/java/org/apache/mahout/common/distance/DefaultWeightedDistanceMeasureTest.java b/community/mahout-mr/src/test/java/org/apache/mahout/common/distance/DefaultWeightedDistanceMeasureTest.java new file mode 100644 index 0000000..a8f1d0b --- /dev/null +++ b/community/mahout-mr/src/test/java/org/apache/mahout/common/distance/DefaultWeightedDistanceMeasureTest.java @@ -0,0 +1,56 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.common.distance; + +import org.apache.mahout.math.DenseVector; +import org.apache.mahout.math.Vector; +import org.junit.Test; + +public abstract class DefaultWeightedDistanceMeasureTest extends DefaultDistanceMeasureTest { + + @Override + public abstract WeightedDistanceMeasure distanceMeasureFactory(); + + @Test + public void testMeasureWeighted() { + + WeightedDistanceMeasure distanceMeasure = distanceMeasureFactory(); + + Vector[] vectors = { + new DenseVector(new double[]{9, 9, 1}), + new DenseVector(new double[]{1, 9, 9}), + new DenseVector(new double[]{9, 1, 9}), + }; + distanceMeasure.setWeights(new DenseVector(new double[]{1, 1000, 1})); + + double[][] distanceMatrix = new double[3][3]; + + for (int a = 0; a < 3; a++) { + for (int b = 0; b < 3; b++) { + distanceMatrix[a][b] = distanceMeasure.distance(vectors[a], vectors[b]); + } + } + + assertEquals(0.0, distanceMatrix[0][0], EPSILON); + assertTrue(distanceMatrix[0][1] < distanceMatrix[0][2]); + + + } + + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/test/java/org/apache/mahout/common/distance/TestChebyshevMeasure.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/test/java/org/apache/mahout/common/distance/TestChebyshevMeasure.java b/community/mahout-mr/src/test/java/org/apache/mahout/common/distance/TestChebyshevMeasure.java new file mode 100644 index 0000000..185adf3 --- /dev/null +++ 
b/community/mahout-mr/src/test/java/org/apache/mahout/common/distance/TestChebyshevMeasure.java @@ -0,0 +1,55 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.common.distance; + +import org.apache.mahout.common.MahoutTestCase; +import org.apache.mahout.math.DenseVector; +import org.apache.mahout.math.Vector; +import org.junit.Test; + +public final class TestChebyshevMeasure extends MahoutTestCase { + + @Test + public void testMeasure() { + + DistanceMeasure chebyshevDistanceMeasure = new ChebyshevDistanceMeasure(); + + Vector[] vectors = { + new DenseVector(new double[]{1, 0, 0, 0, 0, 0}), + new DenseVector(new double[]{1, 1, 1, 0, 0, 0}), + new DenseVector(new double[]{1, 1, 1, 1, 1, 1}) + }; + double[][] distances = {{0.0, 1.0, 1.0}, {1.0, 0.0, 1.0}, {1.0, 1.0, 0.0}}; + + double[][] chebyshevDistanceMatrix = new double[3][3]; + + for (int a = 0; a < 3; a++) { + for (int b = 0; b < 3; b++) { + chebyshevDistanceMatrix[a][b] = chebyshevDistanceMeasure.distance(vectors[a], vectors[b]); + } + } + for (int a = 0; a < 3; a++) { + for (int b = 0; b < 3; b++) { + assertEquals(distances[a][b], chebyshevDistanceMatrix[a][b], EPSILON); + } + } + + assertEquals(0.0, 
chebyshevDistanceMatrix[0][0], EPSILON); + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/test/java/org/apache/mahout/common/distance/TestEuclideanDistanceMeasure.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/test/java/org/apache/mahout/common/distance/TestEuclideanDistanceMeasure.java b/community/mahout-mr/src/test/java/org/apache/mahout/common/distance/TestEuclideanDistanceMeasure.java new file mode 100644 index 0000000..cc9e9e7 --- /dev/null +++ b/community/mahout-mr/src/test/java/org/apache/mahout/common/distance/TestEuclideanDistanceMeasure.java @@ -0,0 +1,26 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.common.distance; + +public final class TestEuclideanDistanceMeasure extends DefaultDistanceMeasureTest { + + @Override + public DistanceMeasure distanceMeasureFactory() { + return new EuclideanDistanceMeasure(); + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/test/java/org/apache/mahout/common/distance/TestMahalanobisDistanceMeasure.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/test/java/org/apache/mahout/common/distance/TestMahalanobisDistanceMeasure.java b/community/mahout-mr/src/test/java/org/apache/mahout/common/distance/TestMahalanobisDistanceMeasure.java new file mode 100644 index 0000000..8e3d205 --- /dev/null +++ b/community/mahout-mr/src/test/java/org/apache/mahout/common/distance/TestMahalanobisDistanceMeasure.java @@ -0,0 +1,56 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.common.distance; + +import org.apache.mahout.common.MahoutTestCase; +import org.apache.mahout.math.Matrix; +import org.apache.mahout.math.DenseMatrix; +import org.apache.mahout.math.DenseVector; +import org.apache.mahout.math.Vector; +import org.junit.Test; + + +/** + * To launch this test only : mvn test -Dtest=org.apache.mahout.common.distance.TestMahalanobisDistanceMeasure + */ +public final class TestMahalanobisDistanceMeasure extends MahoutTestCase { + + @Test + public void testMeasure() { + double[][] invCovValues = { { 2.2, 0.4 }, { 0.4, 2.8 } }; + double[] meanValues = { -2.3, -0.9 }; + Matrix invCov = new DenseMatrix(invCovValues); + Vector meanVector = new DenseVector(meanValues); + MahalanobisDistanceMeasure distanceMeasure = new MahalanobisDistanceMeasure(); + distanceMeasure.setInverseCovarianceMatrix(invCov); + distanceMeasure.setMeanVector(meanVector); + double[] v1 = { -1.9, -2.3 }; + double[] v2 = { -2.9, -1.3 }; + double dist = distanceMeasure.distance(new DenseVector(v1),new DenseVector(v2)); + assertEquals(2.0493901531919194, dist, EPSILON); + //now set the covariance Matrix + distanceMeasure.setCovarianceMatrix(invCov); + //check the inverse covariance times covariance equals identity + Matrix identity = distanceMeasure.getInverseCovarianceMatrix().times(invCov); + assertEquals(1, identity.get(0,0), EPSILON); + assertEquals(1, identity.get(1,1), EPSILON); + assertEquals(0, identity.get(1,0), EPSILON); + assertEquals(0, identity.get(0,1), EPSILON); + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/test/java/org/apache/mahout/common/distance/TestManhattanDistanceMeasure.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/test/java/org/apache/mahout/common/distance/TestManhattanDistanceMeasure.java b/community/mahout-mr/src/test/java/org/apache/mahout/common/distance/TestManhattanDistanceMeasure.java new file 
mode 100644 index 0000000..97a5612 --- /dev/null +++ b/community/mahout-mr/src/test/java/org/apache/mahout/common/distance/TestManhattanDistanceMeasure.java @@ -0,0 +1,26 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.common.distance; + +public final class TestManhattanDistanceMeasure extends DefaultDistanceMeasureTest { + + @Override + public DistanceMeasure distanceMeasureFactory() { + return new ManhattanDistanceMeasure(); + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/test/java/org/apache/mahout/common/distance/TestMinkowskiMeasure.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/test/java/org/apache/mahout/common/distance/TestMinkowskiMeasure.java b/community/mahout-mr/src/test/java/org/apache/mahout/common/distance/TestMinkowskiMeasure.java new file mode 100644 index 0000000..d2cd85e --- /dev/null +++ b/community/mahout-mr/src/test/java/org/apache/mahout/common/distance/TestMinkowskiMeasure.java @@ -0,0 +1,64 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.common.distance; + +import org.apache.mahout.common.MahoutTestCase; +import org.apache.mahout.math.DenseVector; +import org.apache.mahout.math.Vector; +import org.junit.Test; + +public final class TestMinkowskiMeasure extends MahoutTestCase { + + @Test + public void testMeasure() { + + DistanceMeasure minkowskiDistanceMeasure = new MinkowskiDistanceMeasure(1.5); + DistanceMeasure manhattanDistanceMeasure = new ManhattanDistanceMeasure(); + DistanceMeasure euclideanDistanceMeasure = new EuclideanDistanceMeasure(); + + Vector[] vectors = { + new DenseVector(new double[]{1, 0, 0, 0, 0, 0}), + new DenseVector(new double[]{1, 1, 1, 0, 0, 0}), + new DenseVector(new double[]{1, 1, 1, 1, 1, 1}) + }; + + double[][] minkowskiDistanceMatrix = new double[3][3]; + double[][] manhattanDistanceMatrix = new double[3][3]; + double[][] euclideanDistanceMatrix = new double[3][3]; + + for (int a = 0; a < 3; a++) { + for (int b = 0; b < 3; b++) { + minkowskiDistanceMatrix[a][b] = minkowskiDistanceMeasure.distance(vectors[a], vectors[b]); + manhattanDistanceMatrix[a][b] = manhattanDistanceMeasure.distance(vectors[a], vectors[b]); + euclideanDistanceMatrix[a][b] = euclideanDistanceMeasure.distance(vectors[a], vectors[b]); + } + } + + for (int a = 0; a < 3; a++) { + for 
(int b = 0; b < 3; b++) { + assertTrue(minkowskiDistanceMatrix[a][b] <= manhattanDistanceMatrix[a][b]); + assertTrue(minkowskiDistanceMatrix[a][b] >= euclideanDistanceMatrix[a][b]); + } + } + + assertEquals(0.0, minkowskiDistanceMatrix[0][0], EPSILON); + assertTrue(minkowskiDistanceMatrix[0][0] < minkowskiDistanceMatrix[0][1]); + assertTrue(minkowskiDistanceMatrix[0][1] < minkowskiDistanceMatrix[0][2]); + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/test/java/org/apache/mahout/common/distance/TestTanimotoDistanceMeasure.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/test/java/org/apache/mahout/common/distance/TestTanimotoDistanceMeasure.java b/community/mahout-mr/src/test/java/org/apache/mahout/common/distance/TestTanimotoDistanceMeasure.java new file mode 100644 index 0000000..01f9134 --- /dev/null +++ b/community/mahout-mr/src/test/java/org/apache/mahout/common/distance/TestTanimotoDistanceMeasure.java @@ -0,0 +1,25 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.common.distance; + +public final class TestTanimotoDistanceMeasure extends DefaultWeightedDistanceMeasureTest { + @Override + public TanimotoDistanceMeasure distanceMeasureFactory() { + return new TanimotoDistanceMeasure(); + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/test/java/org/apache/mahout/common/distance/TestWeightedEuclideanDistanceMeasureTest.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/test/java/org/apache/mahout/common/distance/TestWeightedEuclideanDistanceMeasureTest.java b/community/mahout-mr/src/test/java/org/apache/mahout/common/distance/TestWeightedEuclideanDistanceMeasureTest.java new file mode 100644 index 0000000..b99d165 --- /dev/null +++ b/community/mahout-mr/src/test/java/org/apache/mahout/common/distance/TestWeightedEuclideanDistanceMeasureTest.java @@ -0,0 +1,25 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.common.distance; + +public final class TestWeightedEuclideanDistanceMeasureTest extends DefaultWeightedDistanceMeasureTest { + @Override + public WeightedDistanceMeasure distanceMeasureFactory() { + return new WeightedEuclideanDistanceMeasure(); + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/test/java/org/apache/mahout/common/distance/TestWeightedManhattanDistanceMeasure.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/test/java/org/apache/mahout/common/distance/TestWeightedManhattanDistanceMeasure.java b/community/mahout-mr/src/test/java/org/apache/mahout/common/distance/TestWeightedManhattanDistanceMeasure.java new file mode 100644 index 0000000..77d4a01 --- /dev/null +++ b/community/mahout-mr/src/test/java/org/apache/mahout/common/distance/TestWeightedManhattanDistanceMeasure.java @@ -0,0 +1,26 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.common.distance; + +public final class TestWeightedManhattanDistanceMeasure extends DefaultWeightedDistanceMeasureTest { + + @Override + public WeightedManhattanDistanceMeasure distanceMeasureFactory() { + return new WeightedManhattanDistanceMeasure(); + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/test/java/org/apache/mahout/common/iterator/CountingIteratorTest.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/test/java/org/apache/mahout/common/iterator/CountingIteratorTest.java b/community/mahout-mr/src/test/java/org/apache/mahout/common/iterator/CountingIteratorTest.java new file mode 100644 index 0000000..d38178c --- /dev/null +++ b/community/mahout-mr/src/test/java/org/apache/mahout/common/iterator/CountingIteratorTest.java @@ -0,0 +1,44 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.common.iterator; + +import java.util.Iterator; + +import org.apache.mahout.common.MahoutTestCase; +import org.junit.Test; + +public final class CountingIteratorTest extends MahoutTestCase { + + @Test + public void testEmptyCase() { + assertFalse(new CountingIterator(0).hasNext()); + } + + @Test + public void testCount() { + Iterator<Integer> it = new CountingIterator(3); + assertTrue(it.hasNext()); + assertEquals(0, (int) it.next()); + assertTrue(it.hasNext()); + assertEquals(1, (int) it.next()); + assertTrue(it.hasNext()); + assertEquals(2, (int) it.next()); + assertFalse(it.hasNext()); + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/test/java/org/apache/mahout/common/iterator/SamplerCase.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/test/java/org/apache/mahout/common/iterator/SamplerCase.java b/community/mahout-mr/src/test/java/org/apache/mahout/common/iterator/SamplerCase.java new file mode 100644 index 0000000..b67d34b --- /dev/null +++ b/community/mahout-mr/src/test/java/org/apache/mahout/common/iterator/SamplerCase.java @@ -0,0 +1,101 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.common.iterator; + +import java.util.Collections; +import java.util.Iterator; +import java.util.Arrays; +import java.util.List; + +import org.apache.mahout.common.MahoutTestCase; +import org.junit.Test; + +public abstract class SamplerCase extends MahoutTestCase { + // these provide access to the underlying implementation + protected abstract Iterator<Integer> createSampler(int n, Iterator<Integer> source); + + protected abstract boolean isSorted(); + + @Test + public void testEmptyCase() { + assertFalse(createSampler(100, new CountingIterator(0)).hasNext()); + } + + @Test + public void testSmallInput() { + Iterator<Integer> t = createSampler(10, new CountingIterator(1)); + assertTrue(t.hasNext()); + assertEquals(0, t.next().intValue()); + assertFalse(t.hasNext()); + + t = createSampler(10, new CountingIterator(1)); + assertTrue(t.hasNext()); + assertEquals(0, t.next().intValue()); + assertFalse(t.hasNext()); + } + + @Test + public void testAbsurdSize() { + Iterator<Integer> t = createSampler(0, new CountingIterator(2)); + assertFalse(t.hasNext()); + } + + @Test + public void testExactSizeMatch() { + Iterator<Integer> t = createSampler(10, new CountingIterator(10)); + for (int i = 0; i < 10; i++) { + assertTrue(t.hasNext()); + assertEquals(i, t.next().intValue()); + } + assertFalse(t.hasNext()); + } + + @Test + public void testSample() { + Iterator<Integer> source = new CountingIterator(100); + Iterator<Integer> t = createSampler(15, source); + + // this is just a regression test, not a real test + List<Integer> expectedValues = Arrays.asList(52,28,2,60,50,32,65,79,78,9,40,33,96,25,48); + if (isSorted()) { + Collections.sort(expectedValues); + } + Iterator<Integer> expected = expectedValues.iterator(); + int last = Integer.MIN_VALUE; + for (int i = 0; i < 15; i++) { + assertTrue(t.hasNext()); + int actual = t.next(); + if 
(isSorted()) { + assertTrue(actual >= last); + last = actual; + } else { + // any of the first few values should be in the original places + if (actual < 15) { + assertEquals(i, actual); + } + } + + assertTrue(actual >= 0 && actual < 100); + + // this is just a regression test, but still of some value + assertEquals(expected.next().intValue(), actual); + assertFalse(source.hasNext()); + } + assertFalse(t.hasNext()); + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/test/java/org/apache/mahout/common/iterator/TestFixedSizeSampler.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/test/java/org/apache/mahout/common/iterator/TestFixedSizeSampler.java b/community/mahout-mr/src/test/java/org/apache/mahout/common/iterator/TestFixedSizeSampler.java new file mode 100644 index 0000000..91e092f --- /dev/null +++ b/community/mahout-mr/src/test/java/org/apache/mahout/common/iterator/TestFixedSizeSampler.java @@ -0,0 +1,33 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.common.iterator; + +import java.util.Iterator; + +public final class TestFixedSizeSampler extends SamplerCase { + + @Override + protected Iterator<Integer> createSampler(int n, Iterator<Integer> source) { + return new FixedSizeSamplingIterator<>(n, source); + } + + @Override + protected boolean isSorted() { + return false; + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/test/java/org/apache/mahout/common/iterator/TestSamplingIterator.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/test/java/org/apache/mahout/common/iterator/TestSamplingIterator.java b/community/mahout-mr/src/test/java/org/apache/mahout/common/iterator/TestSamplingIterator.java new file mode 100644 index 0000000..802eb86 --- /dev/null +++ b/community/mahout-mr/src/test/java/org/apache/mahout/common/iterator/TestSamplingIterator.java @@ -0,0 +1,77 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.common.iterator; + +import java.util.Iterator; + +import org.apache.mahout.common.MahoutTestCase; +import org.junit.Test; + +public final class TestSamplingIterator extends MahoutTestCase { + + @Test + public void testEmptyCase() { + assertFalse(new SamplingIterator<>(new CountingIterator(0), 0.9999).hasNext()); + assertFalse(new SamplingIterator<>(new CountingIterator(0), 1).hasNext()); + } + + @Test + public void testSmallInput() { + Iterator<Integer> t = new SamplingIterator<>(new CountingIterator(1), 0.9999); + assertTrue(t.hasNext()); + assertEquals(0, t.next().intValue()); + assertFalse(t.hasNext()); + } + + @Test(expected = IllegalArgumentException.class) + public void testBadRate1() { + new SamplingIterator<>(new CountingIterator(1), 0.0); + } + + @Test(expected = IllegalArgumentException.class) + public void testBadRate2() { + new SamplingIterator<>(new CountingIterator(1), 1.1); + } + + @Test + public void testExactSizeMatch() { + Iterator<Integer> t = new SamplingIterator<>(new CountingIterator(10), 1); + for (int i = 0; i < 10; i++) { + assertTrue(t.hasNext()); + assertEquals(i, t.next().intValue()); + } + assertFalse(t.hasNext()); + } + + @Test + public void testSample() { + for (int i = 0; i < 1000; i++) { + Iterator<Integer> t = new SamplingIterator<>(new CountingIterator(1000), 0.1); + int k = 0; + while (t.hasNext()) { + int v = t.next(); + k++; + assertTrue(v >= 0); + assertTrue(v < 1000); + } + double sd = Math.sqrt(0.9 * 0.1 * 1000); + assertTrue(k >= 100 - 4 * sd); + assertTrue(k <= 100 + 4 * sd); + } + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/test/java/org/apache/mahout/common/iterator/TestStableFixedSizeSampler.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/test/java/org/apache/mahout/common/iterator/TestStableFixedSizeSampler.java 
b/community/mahout-mr/src/test/java/org/apache/mahout/common/iterator/TestStableFixedSizeSampler.java new file mode 100644 index 0000000..558899f --- /dev/null +++ b/community/mahout-mr/src/test/java/org/apache/mahout/common/iterator/TestStableFixedSizeSampler.java @@ -0,0 +1,33 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.common.iterator; + +import java.util.Iterator; + +public final class TestStableFixedSizeSampler extends SamplerCase { + + @Override + protected Iterator<Integer> createSampler(int n, Iterator<Integer> source) { + return new StableFixedSizeSamplingIterator<>(n, source); + } + + @Override + protected boolean isSorted() { + return true; + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/test/java/org/apache/mahout/common/lucene/AnalyzerUtilsTest.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/test/java/org/apache/mahout/common/lucene/AnalyzerUtilsTest.java b/community/mahout-mr/src/test/java/org/apache/mahout/common/lucene/AnalyzerUtilsTest.java new file mode 100644 index 0000000..f94d63e --- /dev/null +++ b/community/mahout-mr/src/test/java/org/apache/mahout/common/lucene/AnalyzerUtilsTest.java @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.common.lucene; + +import org.apache.lucene.analysis.cjk.CJKAnalyzer; +import org.apache.lucene.analysis.standard.StandardAnalyzer; +import org.junit.Test; + +import static org.junit.Assert.assertNotNull; + +public class AnalyzerUtilsTest { + + @Test + public void createStandardAnalyzer() throws Exception { + assertNotNull(AnalyzerUtils.createAnalyzer(StandardAnalyzer.class.getName())); + } + + @Test + public void createCJKAnalyzer() throws Exception { + assertNotNull(AnalyzerUtils.createAnalyzer(CJKAnalyzer.class.getName())); + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/test/java/org/apache/mahout/driver/MahoutDriverTest.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/test/java/org/apache/mahout/driver/MahoutDriverTest.java b/community/mahout-mr/src/test/java/org/apache/mahout/driver/MahoutDriverTest.java new file mode 100644 index 0000000..e0bdc98 --- /dev/null +++ b/community/mahout-mr/src/test/java/org/apache/mahout/driver/MahoutDriverTest.java @@ -0,0 +1,32 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.driver; + +import org.junit.Test; + +/** + * Tests if MahoutDriver can be run directly through its main method. + */ +public final class MahoutDriverTest { + + @Test + public void testMain() throws Throwable { + MahoutDriver.main(new String[] {"canopy", "help"}); + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/test/java/org/apache/mahout/ep/EvolutionaryProcessTest.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/test/java/org/apache/mahout/ep/EvolutionaryProcessTest.java b/community/mahout-mr/src/test/java/org/apache/mahout/ep/EvolutionaryProcessTest.java new file mode 100644 index 0000000..e7a3b3e --- /dev/null +++ b/community/mahout-mr/src/test/java/org/apache/mahout/ep/EvolutionaryProcessTest.java @@ -0,0 +1,81 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.ep; + +import org.apache.mahout.common.MahoutTestCase; +import org.junit.Test; + +import java.io.DataInput; +import java.io.DataOutput; +import java.io.IOException; + +public final class EvolutionaryProcessTest extends MahoutTestCase { + + @Test + public void testConverges() throws Exception { + State<Foo, Double> s0 = new State<>(new double[5], 1); + s0.setPayload(new Foo()); + EvolutionaryProcess<Foo, Double> ep = new EvolutionaryProcess<>(10, 100, s0); + + State<Foo, Double> best = null; + for (int i = 0; i < 20; i++) { + best = ep.parallelDo(new EvolutionaryProcess.Function<Payload<Double>>() { + @Override + public double apply(Payload<Double> payload, double[] params) { + int i = 1; + double sum = 0; + for (double x : params) { + sum += i * (x - i) * (x - i); + i++; + } + return -sum; + } + }); + + ep.mutatePopulation(3); + + System.out.printf("%10.3f %.3f\n", best.getValue(), best.getOmni()); + } + + ep.close(); + assertNotNull(best); + assertEquals(0.0, best.getValue(), 0.02); + } + + private static class Foo implements Payload<Double> { + @Override + public Foo copy() { + return this; + } + + @Override + public void update(double[] params) { + // ignore + } + + @Override + public void write(DataOutput dataOutput) throws IOException { + // no-op + } + + @Override + public void readFields(DataInput dataInput) throws IOException { + // no-op + } + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/test/java/org/apache/mahout/math/MatrixWritableTest.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/test/java/org/apache/mahout/math/MatrixWritableTest.java b/community/mahout-mr/src/test/java/org/apache/mahout/math/MatrixWritableTest.java new file mode 100644 index 0000000..226d4b1 --- /dev/null +++ b/community/mahout-mr/src/test/java/org/apache/mahout/math/MatrixWritableTest.java @@ -0,0 +1,148 @@ +/** + * Licensed to the Apache 
Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.math; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.DataInputStream; +import java.io.DataOutputStream; +import java.io.IOException; +import java.util.Map; + +import com.google.common.collect.Maps; +import com.google.common.io.Closeables; +import org.apache.hadoop.io.Writable; +import org.junit.Test; + +public final class MatrixWritableTest extends MahoutTestCase { + + @Test + public void testSparseMatrixWritable() throws Exception { + Matrix m = new SparseMatrix(5, 5); + m.set(1, 2, 3.0); + m.set(3, 4, 5.0); + Map<String, Integer> bindings = Maps.newHashMap(); + bindings.put("A", 0); + bindings.put("B", 1); + bindings.put("C", 2); + bindings.put("D", 3); + bindings.put("default", 4); + m.setRowLabelBindings(bindings); + m.setColumnLabelBindings(bindings); + doTestMatrixWritableEquals(m); + } + + @Test + public void testSparseRowMatrixWritable() throws Exception { + Matrix m = new SparseRowMatrix(5, 5); + m.set(1, 2, 3.0); + m.set(3, 4, 5.0); + Map<String, Integer> bindings = Maps.newHashMap(); + bindings.put("A", 0); + bindings.put("B", 1); + bindings.put("C", 2); + bindings.put("D", 3); + bindings.put("default", 4); + 
m.setRowLabelBindings(bindings); + m.setColumnLabelBindings(bindings); + doTestMatrixWritableEquals(m); + } + + @Test + public void testDenseMatrixWritable() throws Exception { + Matrix m = new DenseMatrix(5,5); + m.set(1, 2, 3.0); + m.set(3, 4, 5.0); + Map<String, Integer> bindings = Maps.newHashMap(); + bindings.put("A", 0); + bindings.put("B", 1); + bindings.put("C", 2); + bindings.put("D", 3); + bindings.put("default", 4); + m.setRowLabelBindings(bindings); + m.setColumnLabelBindings(bindings); + doTestMatrixWritableEquals(m); + } + + private static void doTestMatrixWritableEquals(Matrix m) throws IOException { + Writable matrixWritable = new MatrixWritable(m); + MatrixWritable matrixWritable2 = new MatrixWritable(); + writeAndRead(matrixWritable, matrixWritable2); + Matrix m2 = matrixWritable2.get(); + compareMatrices(m, m2); + doCheckBindings(m2.getRowLabelBindings()); + doCheckBindings(m2.getColumnLabelBindings()); + } + + private static void compareMatrices(Matrix m, Matrix m2) { + assertEquals(m.numRows(), m2.numRows()); + assertEquals(m.numCols(), m2.numCols()); + for (int r = 0; r < m.numRows(); r++) { + for (int c = 0; c < m.numCols(); c++) { + assertEquals(m.get(r, c), m2.get(r, c), EPSILON); + } + } + Map<String,Integer> bindings = m.getRowLabelBindings(); + Map<String, Integer> bindings2 = m2.getRowLabelBindings(); + assertEquals(bindings == null, bindings2 == null); + if (bindings != null) { + assertEquals(bindings.size(), m.numRows()); + assertEquals(bindings.size(), bindings2.size()); + for (Map.Entry<String,Integer> entry : bindings.entrySet()) { + assertEquals(entry.getValue(), bindings2.get(entry.getKey())); + } + } + bindings = m.getColumnLabelBindings(); + bindings2 = m2.getColumnLabelBindings(); + assertEquals(bindings == null, bindings2 == null); + if (bindings != null) { + assertEquals(bindings.size(), bindings2.size()); + for (Map.Entry<String,Integer> entry : bindings.entrySet()) { + assertEquals(entry.getValue(), 
bindings2.get(entry.getKey())); + } + } + } + + private static void doCheckBindings(Map<String,Integer> labels) { + assertTrue("Missing label", labels.keySet().contains("A")); + assertTrue("Missing label", labels.keySet().contains("B")); + assertTrue("Missing label", labels.keySet().contains("C")); + assertTrue("Missing label", labels.keySet().contains("D")); + assertTrue("Missing label", labels.keySet().contains("default")); + } + + private static void writeAndRead(Writable toWrite, Writable toRead) throws IOException { + ByteArrayOutputStream baos = new ByteArrayOutputStream(); + DataOutputStream dos = new DataOutputStream(baos); + try { + toWrite.write(dos); + } finally { + Closeables.close(dos, false); + } + + ByteArrayInputStream bais = new ByteArrayInputStream(baos.toByteArray()); + DataInputStream dis = new DataInputStream(bais); + try { + toRead.readFields(dis); + } finally { + Closeables.close(dis, true); + } + } + + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/test/java/org/apache/mahout/math/VarintTest.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/test/java/org/apache/mahout/math/VarintTest.java b/community/mahout-mr/src/test/java/org/apache/mahout/math/VarintTest.java new file mode 100644 index 0000000..0b1a664 --- /dev/null +++ b/community/mahout-mr/src/test/java/org/apache/mahout/math/VarintTest.java @@ -0,0 +1,189 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.math; + +import org.junit.Test; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.DataInput; +import java.io.DataInputStream; +import java.io.DataOutput; +import java.io.DataOutputStream; + +/** + * Tests {@link Varint}. + */ +public final class VarintTest extends MahoutTestCase { + + @Test + public void testUnsignedLong() throws Exception { + ByteArrayOutputStream baos = new ByteArrayOutputStream(); + DataOutput out = new DataOutputStream(baos); + Varint.writeUnsignedVarLong(0L, out); + for (long i = 1L; i > 0L && i <= (1L << 62); i <<= 1) { + Varint.writeUnsignedVarLong(i-1, out); + Varint.writeUnsignedVarLong(i, out); + } + Varint.writeUnsignedVarLong(Long.MAX_VALUE, out); + + DataInput in = new DataInputStream(new ByteArrayInputStream(baos.toByteArray())); + assertEquals(0L, Varint.readUnsignedVarLong(in)); + for (long i = 1L; i > 0L && i <= (1L << 62); i <<= 1) { + assertEquals(i-1, Varint.readUnsignedVarLong(in)); + assertEquals(i, Varint.readUnsignedVarLong(in)); + } + assertEquals(Long.MAX_VALUE, Varint.readUnsignedVarLong(in)); + } + + @Test + public void testSignedPositiveLong() throws Exception { + ByteArrayOutputStream baos = new ByteArrayOutputStream(); + DataOutput out = new DataOutputStream(baos); + Varint.writeSignedVarLong(0L, out); + for (long i = 1L; i <= (1L << 61); i <<= 1) { + Varint.writeSignedVarLong(i-1, out); + Varint.writeSignedVarLong(i, out); + } + Varint.writeSignedVarLong((1L << 62) - 1, out); + Varint.writeSignedVarLong((1L << 62), out); + 
Varint.writeSignedVarLong(Long.MAX_VALUE, out); + + DataInput in = new DataInputStream(new ByteArrayInputStream(baos.toByteArray())); + assertEquals(0L, Varint.readSignedVarLong(in)); + for (long i = 1L; i <= (1L << 61); i <<= 1) { + assertEquals(i-1, Varint.readSignedVarLong(in)); + assertEquals(i, Varint.readSignedVarLong(in)); + } + assertEquals((1L << 62) - 1, Varint.readSignedVarLong(in)); + assertEquals((1L << 62), Varint.readSignedVarLong(in)); + assertEquals(Long.MAX_VALUE, Varint.readSignedVarLong(in)); + } + + @Test + public void testSignedNegativeLong() throws Exception { + ByteArrayOutputStream baos = new ByteArrayOutputStream(); + DataOutput out = new DataOutputStream(baos); + for (long i = -1L; i >= -(1L << 62); i <<= 1) { + Varint.writeSignedVarLong(i, out); + Varint.writeSignedVarLong(i+1, out); + } + Varint.writeSignedVarLong(Long.MIN_VALUE, out); + Varint.writeSignedVarLong(Long.MIN_VALUE+1, out); + DataInput in = new DataInputStream(new ByteArrayInputStream(baos.toByteArray())); + for (long i = -1L; i >= -(1L << 62); i <<= 1) { + assertEquals(i, Varint.readSignedVarLong(in)); + assertEquals(i+1, Varint.readSignedVarLong(in)); + } + assertEquals(Long.MIN_VALUE, Varint.readSignedVarLong(in)); + assertEquals(Long.MIN_VALUE+1, Varint.readSignedVarLong(in)); + } + + @Test + public void testUnsignedInt() throws Exception { + ByteArrayOutputStream baos = new ByteArrayOutputStream(); + DataOutput out = new DataOutputStream(baos); + Varint.writeUnsignedVarInt(0, out); + for (int i = 1; i > 0 && i <= (1 << 30); i <<= 1) { + Varint.writeUnsignedVarLong(i-1, out); + Varint.writeUnsignedVarLong(i, out); + } + Varint.writeUnsignedVarLong(Integer.MAX_VALUE, out); + + DataInput in = new DataInputStream(new ByteArrayInputStream(baos.toByteArray())); + assertEquals(0, Varint.readUnsignedVarInt(in)); + for (int i = 1; i > 0 && i <= (1 << 30); i <<= 1) { + assertEquals(i-1, Varint.readUnsignedVarInt(in)); + assertEquals(i, Varint.readUnsignedVarInt(in)); + } + 
assertEquals(Integer.MAX_VALUE, Varint.readUnsignedVarInt(in)); + } + + @Test + public void testSignedPositiveInt() throws Exception { + ByteArrayOutputStream baos = new ByteArrayOutputStream(); + DataOutput out = new DataOutputStream(baos); + Varint.writeSignedVarInt(0, out); + for (int i = 1; i <= (1 << 29); i <<= 1) { + Varint.writeSignedVarLong(i-1, out); + Varint.writeSignedVarLong(i, out); + } + Varint.writeSignedVarInt((1 << 30) - 1, out); + Varint.writeSignedVarInt((1 << 30), out); + Varint.writeSignedVarInt(Integer.MAX_VALUE, out); + + DataInput in = new DataInputStream(new ByteArrayInputStream(baos.toByteArray())); + assertEquals(0, Varint.readSignedVarInt(in)); + for (int i = 1; i <= (1 << 29); i <<= 1) { + assertEquals(i-1, Varint.readSignedVarInt(in)); + assertEquals(i, Varint.readSignedVarInt(in)); + } + assertEquals((1L << 30) - 1, Varint.readSignedVarInt(in)); + assertEquals((1L << 30), Varint.readSignedVarInt(in)); + assertEquals(Integer.MAX_VALUE, Varint.readSignedVarInt(in)); + } + + @Test + public void testSignedNegativeInt() throws Exception { + ByteArrayOutputStream baos = new ByteArrayOutputStream(); + DataOutput out = new DataOutputStream(baos); + for (int i = -1; i >= -(1 << 30); i <<= 1) { + Varint.writeSignedVarInt(i, out); + Varint.writeSignedVarInt(i+1, out); + } + Varint.writeSignedVarInt(Integer.MIN_VALUE, out); + Varint.writeSignedVarInt(Integer.MIN_VALUE+1, out); + DataInput in = new DataInputStream(new ByteArrayInputStream(baos.toByteArray())); + for (int i = -1; i >= -(1 << 30); i <<= 1) { + assertEquals(i, Varint.readSignedVarInt(in)); + assertEquals(i+1, Varint.readSignedVarInt(in)); + } + assertEquals(Integer.MIN_VALUE, Varint.readSignedVarInt(in)); + assertEquals(Integer.MIN_VALUE+1, Varint.readSignedVarInt(in)); + } + + @Test + public void testUnsignedSize() throws Exception { + ByteArrayOutputStream baos = new ByteArrayOutputStream(); + DataOutput out = new DataOutputStream(baos); + int expectedSize = 0; + for (int exponent 
= 0; exponent <= 62; exponent++) { + Varint.writeUnsignedVarLong(1L << exponent, out); + expectedSize += 1 + exponent / 7; + assertEquals(expectedSize, baos.size()); + } + } + + @Test + public void testSignedSize() throws Exception { + ByteArrayOutputStream baos = new ByteArrayOutputStream(); + DataOutput out = new DataOutputStream(baos); + int expectedSize = 0; + for (int exponent = 0; exponent <= 61; exponent++) { + Varint.writeSignedVarLong(1L << exponent, out); + expectedSize += 1 + ((exponent + 1) / 7); + assertEquals(expectedSize, baos.size()); + } + for (int exponent = 0; exponent <= 61; exponent++) { + Varint.writeSignedVarLong(-(1L << exponent)-1, out); + expectedSize += 1 + ((exponent + 1) / 7); + assertEquals(expectedSize, baos.size()); + } + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/test/java/org/apache/mahout/math/VectorWritableTest.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/test/java/org/apache/mahout/math/VectorWritableTest.java b/community/mahout-mr/src/test/java/org/apache/mahout/math/VectorWritableTest.java new file mode 100644 index 0000000..60fb8b4 --- /dev/null +++ b/community/mahout-mr/src/test/java/org/apache/mahout/math/VectorWritableTest.java @@ -0,0 +1,123 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more contributor license + * agreements. See the NOTICE file distributed with this work for additional information regarding + * copyright ownership. The ASF licenses this file to You under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance with the License. 
You may obtain a + * copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + */ + +package org.apache.mahout.math; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.DataInputStream; +import java.io.DataOutputStream; +import java.io.IOException; + +import org.apache.hadoop.io.Writable; +import org.apache.mahout.math.Vector.Element; +import org.junit.Test; + +import com.carrotsearch.randomizedtesting.RandomizedTest; +import com.carrotsearch.randomizedtesting.annotations.Repeat; +import com.google.common.io.Closeables; + +public final class VectorWritableTest extends RandomizedTest { + private static final int MAX_VECTOR_SIZE = 100; + + public void createRandom(Vector v) { + int size = randomInt(v.size() - 1); + for (int i = 0; i < size; ++i) { + v.set(randomInt(v.size() - 1), randomDouble()); + } + + int zeros = Math.max(2, size / 4); + for (Element e : v.nonZeroes()) { + if (e.index() % zeros == 0) { + e.set(0.0); + } + } + } + + @Test + @Repeat(iterations = 20) + public void testViewSequentialAccessSparseVectorWritable() throws Exception { + Vector v = new SequentialAccessSparseVector(MAX_VECTOR_SIZE); + createRandom(v); + Vector view = new VectorView(v, 0, v.size()); + doTestVectorWritableEquals(view); + } + + @Test + @Repeat(iterations = 20) + public void testSequentialAccessSparseVectorWritable() throws Exception { + Vector v = new SequentialAccessSparseVector(MAX_VECTOR_SIZE); + createRandom(v); + doTestVectorWritableEquals(v); + } + + @Test + @Repeat(iterations = 20) + public void testRandomAccessSparseVectorWritable() throws Exception { + Vector v = new 
RandomAccessSparseVector(MAX_VECTOR_SIZE); + createRandom(v); + doTestVectorWritableEquals(v); + } + + @Test + @Repeat(iterations = 20) + public void testDenseVectorWritable() throws Exception { + Vector v = new DenseVector(MAX_VECTOR_SIZE); + createRandom(v); + doTestVectorWritableEquals(v); + } + + @Test + @Repeat(iterations = 20) + public void testNamedVectorWritable() throws Exception { + Vector v = new DenseVector(MAX_VECTOR_SIZE); + v = new NamedVector(v, "Victor"); + createRandom(v); + doTestVectorWritableEquals(v); + } + + private static void doTestVectorWritableEquals(Vector v) throws IOException { + Writable vectorWritable = new VectorWritable(v); + VectorWritable vectorWritable2 = new VectorWritable(); + writeAndRead(vectorWritable, vectorWritable2); + Vector v2 = vectorWritable2.get(); + if (v instanceof NamedVector) { + assertTrue(v2 instanceof NamedVector); + NamedVector nv = (NamedVector) v; + NamedVector nv2 = (NamedVector) v2; + assertEquals(nv.getName(), nv2.getName()); + assertEquals("Victor", nv.getName()); + } + assertEquals(v, v2); + } + + private static void writeAndRead(Writable toWrite, Writable toRead) throws IOException { + ByteArrayOutputStream baos = new ByteArrayOutputStream(); + DataOutputStream dos = new DataOutputStream(baos); + try { + toWrite.write(dos); + } finally { + Closeables.close(dos, false); + } + + ByteArrayInputStream bais = new ByteArrayInputStream(baos.toByteArray()); + DataInputStream dis = new DataInputStream(bais); + try { + toRead.readFields(dis); + } finally { + Closeables.close(dos, true); + } + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/test/java/org/apache/mahout/math/hadoop/MathHelper.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/test/java/org/apache/mahout/math/hadoop/MathHelper.java b/community/mahout-mr/src/test/java/org/apache/mahout/math/hadoop/MathHelper.java new file mode 100644 index 
0000000..082c162 --- /dev/null +++ b/community/mahout-mr/src/test/java/org/apache/mahout/math/hadoop/MathHelper.java @@ -0,0 +1,236 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.math.hadoop; + +import com.google.common.io.Closeables; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.FileSystem; +import org.apache.hadoop.fs.Path; +import org.apache.hadoop.io.IntWritable; +import org.apache.hadoop.io.SequenceFile; +import org.apache.mahout.common.MahoutTestCase; +import org.apache.mahout.common.Pair; +import org.apache.mahout.common.iterator.sequencefile.SequenceFileIterable; +import org.apache.mahout.math.DenseMatrix; +import org.apache.mahout.math.DenseVector; +import org.apache.mahout.math.Matrix; +import org.apache.mahout.math.RandomAccessSparseVector; +import org.apache.mahout.math.Vector; +import org.apache.mahout.math.Vector.Element; +import org.apache.mahout.math.VectorWritable; +import org.apache.mahout.math.map.OpenIntObjectHashMap; +import org.easymock.EasyMock; +import org.easymock.IArgumentMatcher; +import org.junit.Assert; + +import java.io.IOException; +import java.text.DecimalFormat; +import java.text.DecimalFormatSymbols; +import 
java.util.Locale; + +/** + * a collection of small helper methods useful for unit-testing mathematical operations + */ +public final class MathHelper { + + private MathHelper() {} + + /** + * convenience method to create a {@link Vector.Element} + */ + public static Vector.Element elem(int index, double value) { + return new ElementToCheck(index, value); + } + + /** + * a simple implementation of {@link Vector.Element} + */ + static class ElementToCheck implements Vector.Element { + private final int index; + private double value; + + ElementToCheck(int index, double value) { + this.index = index; + this.value = value; + } + @Override + public double get() { + return value; + } + @Override + public int index() { + return index; + } + @Override + public void set(double value) { + this.value = value; + } + } + + /** + * applies an {@link IArgumentMatcher} to a {@link VectorWritable} that checks whether all elements are included + */ + public static VectorWritable vectorMatches(final Vector.Element... elements) { + EasyMock.reportMatcher(new IArgumentMatcher() { + @Override + public boolean matches(Object argument) { + if (argument instanceof VectorWritable) { + Vector v = ((VectorWritable) argument).get(); + return consistsOf(v, elements); + } + return false; + } + + @Override + public void appendTo(StringBuffer buffer) {} + }); + return null; + } + + /** + * checks whether the {@link Vector} is equivalent to the set of {@link Vector.Element}s + */ + public static boolean consistsOf(Vector vector, Vector.Element... 
elements) { + if (elements.length != numberOfNoNZeroNonNaNElements(vector)) { + return false; + } + for (Vector.Element element : elements) { + if (Math.abs(element.get() - vector.get(element.index())) > MahoutTestCase.EPSILON) { + return false; + } + } + return true; + } + + /** + * returns the number of elements in the {@link Vector} that are neither 0 nor NaN + */ + public static int numberOfNoNZeroNonNaNElements(Vector vector) { + int elementsInVector = 0; + for (Element currentElement : vector.nonZeroes()) { + if (!Double.isNaN(currentElement.get())) { + elementsInVector++; + } + } + return elementsInVector; + } + + /** + * read a {@link Matrix} from a SequenceFile<IntWritable,VectorWritable> + */ + public static Matrix readMatrix(Configuration conf, Path path, int rows, int columns) { + boolean readOneRow = false; + Matrix matrix = new DenseMatrix(rows, columns); + for (Pair<IntWritable,VectorWritable> record : + new SequenceFileIterable<IntWritable,VectorWritable>(path, true, conf)) { + IntWritable key = record.getFirst(); + VectorWritable value = record.getSecond(); + readOneRow = true; + int row = key.get(); + for (Element element : value.get().nonZeroes()) { + matrix.set(row, element.index(), element.get()); + } + } + if (!readOneRow) { + throw new IllegalStateException("Not a single row read!"); + } + return matrix; + } + + /** + * read a {@link Matrix} from a SequenceFile<IntWritable,VectorWritable> + */ + public static OpenIntObjectHashMap<Vector> readMatrixRows(Configuration conf, Path path) { + boolean readOneRow = false; + OpenIntObjectHashMap<Vector> rows = new OpenIntObjectHashMap<>(); + for (Pair<IntWritable,VectorWritable> record : + new SequenceFileIterable<IntWritable,VectorWritable>(path, true, conf)) { + IntWritable key = record.getFirst(); + readOneRow = true; + rows.put(key.get(), record.getSecond().get()); + } + if (!readOneRow) { + throw new IllegalStateException("Not a single row read!"); + } + return rows; + } + + /** + * write a 
two-dimensional double array to an SequenceFile<IntWritable,VectorWritable> + */ + public static void writeDistributedRowMatrix(double[][] entries, FileSystem fs, Configuration conf, Path path) + throws IOException { + SequenceFile.Writer writer = null; + try { + writer = new SequenceFile.Writer(fs, conf, path, IntWritable.class, VectorWritable.class); + for (int n = 0; n < entries.length; n++) { + Vector v = new RandomAccessSparseVector(entries[n].length); + for (int m = 0; m < entries[n].length; m++) { + v.setQuick(m, entries[n][m]); + } + writer.append(new IntWritable(n), new VectorWritable(v)); + } + } finally { + Closeables.close(writer, false); + } + } + + public static void assertMatrixEquals(Matrix expected, Matrix actual) { + Assert.assertEquals(expected.numRows(), actual.numRows()); + Assert.assertEquals(actual.numCols(), actual.numCols()); + for (int row = 0; row < expected.numRows(); row++) { + for (int col = 0; col < expected.numCols(); col ++) { + Assert.assertEquals("Non-matching values in [" + row + ',' + col + ']', + expected.get(row, col), actual.get(row, col), MahoutTestCase.EPSILON); + } + } + } + + public static String nice(Vector v) { + if (!v.isSequentialAccess()) { + v = new DenseVector(v); + } + + DecimalFormat df = new DecimalFormat("0.00", DecimalFormatSymbols.getInstance(Locale.ENGLISH)); + + StringBuilder buffer = new StringBuilder("["); + String separator = ""; + for (Vector.Element e : v.all()) { + buffer.append(separator); + if (Double.isNaN(e.get())) { + buffer.append(" - "); + } else { + if (e.get() >= 0) { + buffer.append(' '); + } + buffer.append(df.format(e.get())); + } + separator = "\t"; + } + buffer.append(" ]"); + return buffer.toString(); + } + + public static String nice(Matrix matrix) { + StringBuilder info = new StringBuilder(); + for (int n = 0; n < matrix.numRows(); n++) { + info.append(nice(matrix.viewRow(n))).append('\n'); + } + return info.toString(); + } +}
