http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/clustering/topdown/postprocessor/ClusterOutputPostProcessorTest.java ---------------------------------------------------------------------- diff --git a/mr/src/test/java/org/apache/mahout/clustering/topdown/postprocessor/ClusterOutputPostProcessorTest.java b/mr/src/test/java/org/apache/mahout/clustering/topdown/postprocessor/ClusterOutputPostProcessorTest.java new file mode 100644 index 0000000..dbb950a --- /dev/null +++ b/mr/src/test/java/org/apache/mahout/clustering/topdown/postprocessor/ClusterOutputPostProcessorTest.java @@ -0,0 +1,205 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.mahout.clustering.topdown.postprocessor; + +import java.io.IOException; +import java.util.List; +import java.util.Map; +import java.util.Map.Entry; + +import org.apache.commons.lang3.ArrayUtils; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.FileStatus; +import org.apache.hadoop.fs.FileSystem; +import org.apache.hadoop.fs.FileUtil; +import org.apache.hadoop.fs.Path; +import org.apache.hadoop.io.IntWritable; +import org.apache.hadoop.io.LongWritable; +import org.apache.hadoop.io.SequenceFile; +import org.apache.hadoop.io.Writable; +import org.apache.mahout.clustering.ClusteringTestUtils; +import org.apache.mahout.clustering.canopy.CanopyDriver; +import org.apache.mahout.clustering.classify.WeightedVectorWritable; +import org.apache.mahout.clustering.topdown.PathDirectory; +import org.apache.mahout.common.DummyOutputCollector; +import org.apache.mahout.common.MahoutTestCase; +import org.apache.mahout.common.Pair; +import org.apache.mahout.common.distance.ManhattanDistanceMeasure; +import org.apache.mahout.common.iterator.sequencefile.SequenceFileIterable; +import org.apache.mahout.math.RandomAccessSparseVector; +import org.apache.mahout.math.Vector; +import org.apache.mahout.math.VectorWritable; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import com.google.common.collect.Lists; + +public final class ClusterOutputPostProcessorTest extends MahoutTestCase { + + private static final double[][] REFERENCE = { {1, 1}, {2, 1}, {1, 2}, {4, 4}, {5, 4}, {4, 5}, {5, 5}}; + + private FileSystem fs; + + private Path outputPath; + + private Configuration conf; + + @Override + @Before + public void setUp() throws Exception { + super.setUp(); + Configuration conf = getConfiguration(); + fs = FileSystem.get(conf); + } + + private static List<VectorWritable> getPointsWritable(double[][] raw) { + List<VectorWritable> points = Lists.newArrayList(); + for (double[] fr : raw) { + Vector vec = new 
RandomAccessSparseVector(fr.length); + vec.assign(fr); + points.add(new VectorWritable(vec)); + } + return points; + } + + /** + * Story: User wants to use cluster post processor after canopy clustering and then run clustering on the + * output clusters + */ + @Test + public void testTopDownClustering() throws Exception { + List<VectorWritable> points = getPointsWritable(REFERENCE); + + Path pointsPath = getTestTempDirPath("points"); + conf = getConfiguration(); + ClusteringTestUtils.writePointsToFile(points, new Path(pointsPath, "file1"), fs, conf); + ClusteringTestUtils.writePointsToFile(points, new Path(pointsPath, "file2"), fs, conf); + + outputPath = getTestTempDirPath("output"); + + topLevelClustering(pointsPath, conf); + + Map<String,Path> postProcessedClusterDirectories = ouputPostProcessing(conf); + + assertPostProcessedOutput(postProcessedClusterDirectories); + + bottomLevelClustering(postProcessedClusterDirectories); + } + + private void assertTopLevelCluster(Entry<String,Path> cluster) { + String clusterId = cluster.getKey(); + Path clusterPath = cluster.getValue(); + + try { + if ("0".equals(clusterId)) { + assertPointsInFirstTopLevelCluster(clusterPath); + } else if ("1".equals(clusterId)) { + assertPointsInSecondTopLevelCluster(clusterPath); + } + } catch (IOException e) { + Assert.fail("Exception occurred while asserting top level cluster."); + } + + } + + private void assertPointsInFirstTopLevelCluster(Path clusterPath) throws IOException { + List<Vector> vectorsInCluster = getVectorsInCluster(clusterPath); + for (Vector vector : vectorsInCluster) { + Assert.assertTrue(ArrayUtils.contains(new String[] {"{0:1.0,1:1.0}", "{0:2.0,1:1.0}", "{0:1.0,1:2.0}"}, + vector.asFormatString())); + } + } + + private void assertPointsInSecondTopLevelCluster(Path clusterPath) throws IOException { + List<Vector> vectorsInCluster = getVectorsInCluster(clusterPath); + for (Vector vector : vectorsInCluster) { + Assert.assertTrue(ArrayUtils.contains(new String[] 
{"{0:4.0,1:4.0}", "{0:5.0,1:4.0}", "{0:4.0,1:5.0}", + "{0:5.0,1:5.0}"}, vector.asFormatString())); + } + } + + private List<Vector> getVectorsInCluster(Path clusterPath) throws IOException { + Path[] partFilePaths = FileUtil.stat2Paths(fs.globStatus(clusterPath)); + FileStatus[] listStatus = fs.listStatus(partFilePaths); + List<Vector> vectors = Lists.newArrayList(); + for (FileStatus partFile : listStatus) { + SequenceFile.Reader topLevelClusterReader = new SequenceFile.Reader(fs, partFile.getPath(), conf); + Writable clusterIdAsKey = new LongWritable(); + VectorWritable point = new VectorWritable(); + while (topLevelClusterReader.next(clusterIdAsKey, point)) { + vectors.add(point.get()); + } + } + return vectors; + } + + private void bottomLevelClustering(Map<String,Path> postProcessedClusterDirectories) throws IOException, + InterruptedException, + ClassNotFoundException { + for (Entry<String,Path> topLevelCluster : postProcessedClusterDirectories.entrySet()) { + String clusterId = topLevelCluster.getKey(); + Path topLevelclusterPath = topLevelCluster.getValue(); + + Path bottomLevelCluster = PathDirectory.getBottomLevelClusterPath(outputPath, clusterId); + CanopyDriver.run(conf, topLevelclusterPath, bottomLevelCluster, new ManhattanDistanceMeasure(), 2.1, + 2.0, true, 0.0, true); + assertBottomLevelCluster(bottomLevelCluster); + } + } + + private void assertBottomLevelCluster(Path bottomLevelCluster) { + Path clusteredPointsPath = new Path(bottomLevelCluster, "clusteredPoints"); + + DummyOutputCollector<IntWritable,WeightedVectorWritable> collector = + new DummyOutputCollector<IntWritable,WeightedVectorWritable>(); + + // The key is the clusterId, the value is the weighted vector + for (Pair<IntWritable,WeightedVectorWritable> record : + new SequenceFileIterable<IntWritable,WeightedVectorWritable>(new Path(clusteredPointsPath, "part-m-0"), + conf)) { + collector.collect(record.getFirst(), record.getSecond()); + } + int clusterSize = collector.getKeys().size(); 
+ // First top level cluster produces two more clusters, second top level cluster is not broken again + assertTrue(clusterSize == 1 || clusterSize == 2); + + } + + private void assertPostProcessedOutput(Map<String,Path> postProcessedClusterDirectories) { + for (Entry<String,Path> cluster : postProcessedClusterDirectories.entrySet()) { + assertTopLevelCluster(cluster); + } + } + + private Map<String,Path> ouputPostProcessing(Configuration conf) throws IOException { + ClusterOutputPostProcessor clusterOutputPostProcessor = new ClusterOutputPostProcessor(outputPath, + outputPath, conf); + clusterOutputPostProcessor.process(); + return clusterOutputPostProcessor.getPostProcessedClusterDirectories(); + } + + private void topLevelClustering(Path pointsPath, Configuration conf) throws IOException, + InterruptedException, + ClassNotFoundException { + CanopyDriver.run(conf, pointsPath, outputPath, new ManhattanDistanceMeasure(), 3.1, 2.1, true, 0.0, true); + } + +}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/common/AbstractJobTest.java ---------------------------------------------------------------------- diff --git a/mr/src/test/java/org/apache/mahout/common/AbstractJobTest.java b/mr/src/test/java/org/apache/mahout/common/AbstractJobTest.java new file mode 100644 index 0000000..7683b57 --- /dev/null +++ b/mr/src/test/java/org/apache/mahout/common/AbstractJobTest.java @@ -0,0 +1,240 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.mahout.common; + +import java.io.IOException; +import java.util.Collections; +import java.util.List; +import java.util.Map; + +import com.google.common.collect.Maps; +import org.apache.hadoop.fs.Path; +import org.apache.hadoop.util.ToolRunner; +import org.apache.mahout.common.commandline.DefaultOptionCreator; +import org.junit.Test; + +public final class AbstractJobTest extends MahoutTestCase { + + interface AbstractJobFactory { + AbstractJob getJob(); + } + + @Test + public void testFlag() throws Exception { + final Map<String,List<String>> testMap = Maps.newHashMap(); + + AbstractJobFactory fact = new AbstractJobFactory() { + @Override + public AbstractJob getJob() { + return new AbstractJob() { + @Override + public int run(String[] args) throws IOException { + addFlag("testFlag", "t", "a simple test flag"); + + Map<String,List<String>> argMap = parseArguments(args); + testMap.clear(); + testMap.putAll(argMap); + return 1; + } + }; + } + }; + + // testFlag will only be present if specified on the command-line + + ToolRunner.run(fact.getJob(), new String[0]); + assertFalse("test map for absent flag", testMap.containsKey("--testFlag")); + + String[] withFlag = { "--testFlag" }; + ToolRunner.run(fact.getJob(), withFlag); + assertTrue("test map for present flag", testMap.containsKey("--testFlag")); + } + + @Test + public void testOptions() throws Exception { + final Map<String,List<String>> testMap = Maps.newHashMap(); + + AbstractJobFactory fact = new AbstractJobFactory() { + @Override + public AbstractJob getJob() { + return new AbstractJob() { + @Override + public int run(String[] args) throws IOException { + this.addOption(DefaultOptionCreator.overwriteOption().create()); + this.addOption("option", "o", "option"); + this.addOption("required", "r", "required", true /* required */); + this.addOption("notRequired", "nr", "not required", false /* not required */); + this.addOption("hasDefault", "hd", "option w/ default", "defaultValue"); + 
+ + Map<String,List<String>> argMap = parseArguments(args); + if (argMap == null) { + return -1; + } + + testMap.clear(); + testMap.putAll(argMap); + + return 0; + } + }; + } + }; + + int ret = ToolRunner.run(fact.getJob(), new String[0]); + assertEquals("-1 for missing required options", -1, ret); + + ret = ToolRunner.run(fact.getJob(), new String[]{ + "--required", "requiredArg" + }); + assertEquals("0 for no missing required options", 0, ret); + assertEquals(Collections.singletonList("requiredArg"), testMap.get("--required")); + assertEquals(Collections.singletonList("defaultValue"), testMap.get("--hasDefault")); + assertNull(testMap.get("--option")); + assertNull(testMap.get("--notRequired")); + assertFalse(testMap.containsKey("--overwrite")); + + ret = ToolRunner.run(fact.getJob(), new String[]{ + "--required", "requiredArg", + "--unknownArg" + }); + assertEquals("-1 for including unknown options", -1, ret); + + ret = ToolRunner.run(fact.getJob(), new String[]{ + "--required", "requiredArg", + "--required", "requiredArg2", + }); + assertEquals("-1 for including duplicate options", -1, ret); + + ret = ToolRunner.run(fact.getJob(), new String[]{ + "--required", "requiredArg", + "--overwrite", + "--hasDefault", "nonDefault", + "--option", "optionValue", + "--notRequired", "notRequired" + }); + assertEquals("0 for no missing required options", 0, ret); + assertEquals(Collections.singletonList("requiredArg"), testMap.get("--required")); + assertEquals(Collections.singletonList("nonDefault"), testMap.get("--hasDefault")); + assertEquals(Collections.singletonList("optionValue"), testMap.get("--option")); + assertEquals(Collections.singletonList("notRequired"), testMap.get("--notRequired")); + assertTrue(testMap.containsKey("--overwrite")); + + ret = ToolRunner.run(fact.getJob(), new String[]{ + "-r", "requiredArg", + "-ow", + "-hd", "nonDefault", + "-o", "optionValue", + "-nr", "notRequired" + }); + assertEquals("0 for no missing required options", 0, ret); + 
assertEquals(Collections.singletonList("requiredArg"), testMap.get("--required")); + assertEquals(Collections.singletonList("nonDefault"), testMap.get("--hasDefault")); + assertEquals(Collections.singletonList("optionValue"), testMap.get("--option")); + assertEquals(Collections.singletonList("notRequired"), testMap.get("--notRequired")); + assertTrue(testMap.containsKey("--overwrite")); + + } + + @Test + public void testInputOutputPaths() throws Exception { + + AbstractJobFactory fact = new AbstractJobFactory() { + @Override + public AbstractJob getJob() { + return new AbstractJob() { + @Override + public int run(String[] args) throws IOException { + addInputOption(); + addOutputOption(); + + // arg map should be null if a required option is missing. + Map<String, List<String>> argMap = parseArguments(args); + + if (argMap == null) { + return -1; + } + + Path inputPath = getInputPath(); + assertNotNull("getInputPath() returns non-null", inputPath); + + Path outputPath = getInputPath(); + assertNotNull("getOutputPath() returns non-null", outputPath); + return 0; + } + }; + } + }; + + int ret = ToolRunner.run(fact.getJob(), new String[0]); + assertEquals("-1 for missing input option", -1, ret); + + String testInputPath = "testInputPath"; + + AbstractJob job = fact.getJob(); + ret = ToolRunner.run(job, new String[]{ + "--input", testInputPath }); + assertEquals("-1 for missing output option", -1, ret); + assertEquals("input path is correct", testInputPath, job.getInputPath().toString()); + + job = fact.getJob(); + String testOutputPath = "testOutputPath"; + ret = ToolRunner.run(job, new String[]{ + "--output", testOutputPath }); + assertEquals("-1 for missing input option", -1, ret); + assertEquals("output path is correct", testOutputPath, job.getOutputPath().toString()); + + job = fact.getJob(); + ret = ToolRunner.run(job, new String[]{ + "--input", testInputPath, "--output", testOutputPath }); + assertEquals("0 for complete options", 0, ret); + assertEquals("input 
path is correct", testInputPath, job.getInputPath().toString()); + assertEquals("output path is correct", testOutputPath, job.getOutputPath().toString()); + + job = fact.getJob(); + ret = ToolRunner.run(job, new String[]{ + "--input", testInputPath, "--output", testOutputPath }); + assertEquals("0 for complete options", 0, ret); + assertEquals("input path is correct", testInputPath, job.getInputPath().toString()); + assertEquals("output path is correct", testOutputPath, job.getOutputPath().toString()); + + job = fact.getJob(); + String testInputPropertyPath = "testInputPropertyPath"; + String testOutputPropertyPath = "testOutputPropertyPath"; + ret = ToolRunner.run(job, new String[]{ + "-Dmapred.input.dir=" + testInputPropertyPath, + "-Dmapred.output.dir=" + testOutputPropertyPath }); + assertEquals("0 for complete options", 0, ret); + assertEquals("input path from property is correct", testInputPropertyPath, job.getInputPath().toString()); + assertEquals("output path from property is correct", testOutputPropertyPath, job.getOutputPath().toString()); + + job = fact.getJob(); + ret = ToolRunner.run(job, new String[]{ + "-Dmapred.input.dir=" + testInputPropertyPath, + "-Dmapred.output.dir=" + testOutputPropertyPath, + "--input", testInputPath, + "--output", testOutputPath }); + assertEquals("0 for complete options", 0, ret); + assertEquals("input command-line option precedes property", + testInputPath, job.getInputPath().toString()); + assertEquals("output command-line option precedes property", + testOutputPath, job.getOutputPath().toString()); + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/common/DistributedCacheFileLocationTest.java ---------------------------------------------------------------------- diff --git a/mr/src/test/java/org/apache/mahout/common/DistributedCacheFileLocationTest.java b/mr/src/test/java/org/apache/mahout/common/DistributedCacheFileLocationTest.java new file mode 100644 index 
0000000..5d3532c --- /dev/null +++ b/mr/src/test/java/org/apache/mahout/common/DistributedCacheFileLocationTest.java @@ -0,0 +1,46 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.common; + +import org.apache.hadoop.fs.Path; +import org.junit.Test; + +import java.io.File; +import java.net.URI; + + +public class DistributedCacheFileLocationTest extends MahoutTestCase { + + static final File FILE_I_WANT_TO_FIND = new File("file/i_want_to_find.txt"); + static final URI[] DISTRIBUTED_CACHE_FILES = new URI[] { + new File("/first/file").toURI(), new File("/second/file").toURI(), FILE_I_WANT_TO_FIND.toURI() }; + + @Test + public void nonExistingFile() { + Path path = HadoopUtil.findInCacheByPartOfFilename("no such file", DISTRIBUTED_CACHE_FILES); + assertNull(path); + } + + @Test + public void existingFile() { + Path path = HadoopUtil.findInCacheByPartOfFilename("want_to_find", DISTRIBUTED_CACHE_FILES); + assertNotNull(path); + assertEquals(FILE_I_WANT_TO_FIND.getName(), path.getName()); + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/common/DummyOutputCollector.java ---------------------------------------------------------------------- 
diff --git a/mr/src/test/java/org/apache/mahout/common/DummyOutputCollector.java b/mr/src/test/java/org/apache/mahout/common/DummyOutputCollector.java new file mode 100644 index 0000000..6951f5a --- /dev/null +++ b/mr/src/test/java/org/apache/mahout/common/DummyOutputCollector.java @@ -0,0 +1,57 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.common; + +import com.google.common.collect.Lists; +import org.apache.hadoop.io.Writable; +import org.apache.hadoop.io.WritableComparable; +import org.apache.hadoop.mapred.OutputCollector; + +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.TreeMap; + +public final class DummyOutputCollector<K extends WritableComparable, V extends Writable> + implements OutputCollector<K,V> { + + private final Map<K, List<V>> data = new TreeMap<K,List<V>>(); + + @Override + public void collect(K key,V values) { + List<V> points = data.get(key); + if (points == null) { + points = Lists.newArrayList(); + data.put(key, points); + } + points.add(values); + } + + public Map<K,List<V>> getData() { + return data; + } + + public List<V> getValue(K key) { + return data.get(key); + } + + public Set<K> getKeys() { + return data.keySet(); + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/common/DummyRecordWriter.java ---------------------------------------------------------------------- diff --git a/mr/src/test/java/org/apache/mahout/common/DummyRecordWriter.java b/mr/src/test/java/org/apache/mahout/common/DummyRecordWriter.java new file mode 100644 index 0000000..7dea174 --- /dev/null +++ b/mr/src/test/java/org/apache/mahout/common/DummyRecordWriter.java @@ -0,0 +1,223 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.common; + +import com.google.common.collect.Lists; +import com.google.common.collect.Maps; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.io.NullWritable; +import org.apache.hadoop.io.Writable; +import org.apache.hadoop.mapreduce.MapContext; +import org.apache.hadoop.mapreduce.Mapper; +import org.apache.hadoop.mapreduce.RecordWriter; +import org.apache.hadoop.mapreduce.ReduceContext; +import org.apache.hadoop.mapreduce.Reducer; +import org.apache.hadoop.mapreduce.TaskAttemptContext; +import org.apache.hadoop.mapreduce.TaskAttemptID; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.DataInputStream; +import java.io.DataOutputStream; +import java.io.IOException; +import java.lang.reflect.Constructor; +import java.lang.reflect.Method; +import java.util.List; +import java.util.Map; +import java.util.Set; + +public final class DummyRecordWriter<K extends Writable, V extends Writable> extends RecordWriter<K, V> { + + private final List<K> keysInInsertionOrder = Lists.newArrayList(); + private final Map<K, List<V>> data = Maps.newHashMap(); + + @Override + public void write(K key, V value) { + + // if the user reuses the same writable class, we need to create a new one + // otherwise the Map content will be modified after the insert + try { + + K keyToUse = key instanceof NullWritable ? 
key : (K) cloneWritable(key); + V valueToUse = (V) cloneWritable(value); + + keysInInsertionOrder.add(keyToUse); + + List<V> points = data.get(key); + if (points == null) { + points = Lists.newArrayList(); + data.put(keyToUse, points); + } + points.add(valueToUse); + + } catch (IOException e) { + throw new RuntimeException(e.getMessage(), e); + } + } + + private Writable cloneWritable(Writable original) throws IOException { + + Writable clone; + try { + clone = original.getClass().asSubclass(Writable.class).newInstance(); + } catch (Exception e) { + throw new RuntimeException("Unable to instantiate writable!", e); + } + ByteArrayOutputStream bytes = new ByteArrayOutputStream(); + + original.write(new DataOutputStream(bytes)); + clone.readFields(new DataInputStream(new ByteArrayInputStream(bytes.toByteArray()))); + + return clone; + } + + @Override + public void close(TaskAttemptContext context) { + } + + public Map<K, List<V>> getData() { + return data; + } + + public List<V> getValue(K key) { + return data.get(key); + } + + public Set<K> getKeys() { + return data.keySet(); + } + + public Iterable<K> getKeysInInsertionOrder() { + return keysInInsertionOrder; + } + + public static <K1, V1, K2, V2> Mapper<K1, V1, K2, V2>.Context build(Mapper<K1, V1, K2, V2> mapper, + Configuration configuration, + RecordWriter<K2, V2> output) { + + // Use reflection since the context types changed incompatibly between 0.20 + // and 0.23. + try { + return buildNewMapperContext(configuration, output); + } catch (Exception e) { + try { + return buildOldMapperContext(mapper, configuration, output); + } catch (Exception ex) { + throw new IllegalStateException(ex); + } + } + } + + public static <K1, V1, K2, V2> Reducer<K1, V1, K2, V2>.Context build(Reducer<K1, V1, K2, V2> reducer, + Configuration configuration, + RecordWriter<K2, V2> output, + Class<K1> keyClass, + Class<V1> valueClass) { + + // Use reflection since the context types changed incompatibly between 0.20 + // and 0.23. 
+ try { + return buildNewReducerContext(configuration, output, keyClass, valueClass); + } catch (Exception e) { + try { + return buildOldReducerContext(reducer, configuration, output, keyClass, valueClass); + } catch (Exception ex) { + throw new IllegalStateException(ex); + } + } + } + + @SuppressWarnings({"unchecked", "rawtypes"}) + private static <K1, V1, K2, V2> Mapper<K1, V1, K2, V2>.Context buildNewMapperContext( + Configuration configuration, RecordWriter<K2, V2> output) throws Exception { + Class<?> mapContextImplClass = Class.forName("org.apache.hadoop.mapreduce.task.MapContextImpl"); + Constructor<?> cons = mapContextImplClass.getConstructors()[0]; + Object mapContextImpl = cons.newInstance(configuration, + new TaskAttemptID(), null, output, null, new DummyStatusReporter(), null); + + Class<?> wrappedMapperClass = Class.forName("org.apache.hadoop.mapreduce.lib.map.WrappedMapper"); + Object wrappedMapper = wrappedMapperClass.getConstructor().newInstance(); + Method getMapContext = wrappedMapperClass.getMethod("getMapContext", MapContext.class); + return (Mapper.Context) getMapContext.invoke(wrappedMapper, mapContextImpl); + } + + @SuppressWarnings({"unchecked", "rawtypes"}) + private static <K1, V1, K2, V2> Mapper<K1, V1, K2, V2>.Context buildOldMapperContext( + Mapper<K1, V1, K2, V2> mapper, Configuration configuration, + RecordWriter<K2, V2> output) throws Exception { + Constructor<?> cons = getNestedContextConstructor(mapper.getClass()); + // first argument to the constructor is the enclosing instance + return (Mapper.Context) cons.newInstance(mapper, configuration, + new TaskAttemptID(), null, output, null, new DummyStatusReporter(), null); + } + + @SuppressWarnings({"unchecked", "rawtypes"}) + private static <K1, V1, K2, V2> Reducer<K1, V1, K2, V2>.Context buildNewReducerContext( + Configuration configuration, RecordWriter<K2, V2> output, Class<K1> keyClass, + Class<V1> valueClass) throws Exception { + Class<?> reduceContextImplClass = 
Class.forName("org.apache.hadoop.mapreduce.task.ReduceContextImpl"); + Constructor<?> cons = reduceContextImplClass.getConstructors()[0]; + Object reduceContextImpl = cons.newInstance(configuration, + new TaskAttemptID(), + new MockIterator(), + null, + null, + output, + null, + new DummyStatusReporter(), + null, + keyClass, + valueClass); + + Class<?> wrappedReducerClass = Class.forName("org.apache.hadoop.mapreduce.lib.reduce.WrappedReducer"); + Object wrappedReducer = wrappedReducerClass.getConstructor().newInstance(); + Method getReducerContext = wrappedReducerClass.getMethod("getReducerContext", ReduceContext.class); + return (Reducer.Context) getReducerContext.invoke(wrappedReducer, reduceContextImpl); + } + + @SuppressWarnings({"unchecked", "rawtypes"}) + private static <K1, V1, K2, V2> Reducer<K1, V1, K2, V2>.Context buildOldReducerContext( + Reducer<K1, V1, K2, V2> reducer, Configuration configuration, + RecordWriter<K2, V2> output, Class<K1> keyClass, + Class<V1> valueClass) throws Exception { + Constructor<?> cons = getNestedContextConstructor(reducer.getClass()); + // first argument to the constructor is the enclosing instance + return (Reducer.Context) cons.newInstance(reducer, + configuration, + new TaskAttemptID(), + new MockIterator(), + null, + null, + output, + null, + new DummyStatusReporter(), + null, + keyClass, + valueClass); + } + + private static Constructor<?> getNestedContextConstructor(Class<?> outerClass) { + for (Class<?> nestedClass : outerClass.getClasses()) { + if ("Context".equals(nestedClass.getSimpleName())) { + return nestedClass.getConstructors()[0]; + } + } + throw new IllegalStateException("Cannot find context class for " + outerClass); + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/common/DummyRecordWriterTest.java ---------------------------------------------------------------------- diff --git a/mr/src/test/java/org/apache/mahout/common/DummyRecordWriterTest.java 
b/mr/src/test/java/org/apache/mahout/common/DummyRecordWriterTest.java new file mode 100644 index 0000000..6b25448 --- /dev/null +++ b/mr/src/test/java/org/apache/mahout/common/DummyRecordWriterTest.java @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.common; + +import org.apache.hadoop.io.IntWritable; +import org.apache.mahout.math.DenseVector; +import org.apache.mahout.math.VectorWritable; +import org.junit.Assert; +import org.junit.Test; + +public class DummyRecordWriterTest { + + @Test + public void testWrite() { + DummyRecordWriter<IntWritable, VectorWritable> writer = + new DummyRecordWriter<IntWritable, VectorWritable>(); + IntWritable reusableIntWritable = new IntWritable(); + VectorWritable reusableVectorWritable = new VectorWritable(); + reusableIntWritable.set(0); + reusableVectorWritable.set(new DenseVector(new double[] { 1, 2, 3 })); + writer.write(reusableIntWritable, reusableVectorWritable); + reusableIntWritable.set(1); + reusableVectorWritable.set(new DenseVector(new double[] { 4, 5, 6 })); + writer.write(reusableIntWritable, reusableVectorWritable); + + Assert.assertEquals( + "The writer must remember the two keys that is written to it", 2, + writer.getKeys().size()); + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/common/DummyStatusReporter.java ---------------------------------------------------------------------- diff --git a/mr/src/test/java/org/apache/mahout/common/DummyStatusReporter.java b/mr/src/test/java/org/apache/mahout/common/DummyStatusReporter.java new file mode 100644 index 0000000..c6bc34b --- /dev/null +++ b/mr/src/test/java/org/apache/mahout/common/DummyStatusReporter.java @@ -0,0 +1,76 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.mahout.common;
+
+import org.easymock.EasyMock;
+
+import java.util.Map;
+
+import com.google.common.collect.Maps;
+import org.apache.hadoop.mapreduce.Counter;
+import org.apache.hadoop.mapreduce.StatusReporter;
+
+/**
+ * A {@link StatusReporter} for tests. It hands out EasyMock-created counters, creating each one
+ * lazily and returning the same instance on later lookups, and ignores progress and status
+ * updates.
+ */
+public final class DummyStatusReporter extends StatusReporter {
+
+  /** Counters looked up by enum constant. */
+  private final Map<Enum<?>, Counter> counters = Maps.newHashMap();
+  /**
+   * Counters looked up by (group, name): the outer key is the group, the inner key the name.
+   * A nested map is used rather than the concatenation {@code group + name}, which would make
+   * distinct pairs such as ("ab", "c") and ("a", "bc") share a single counter.
+   */
+  private final Map<String, Map<String, Counter>> counterGroups = Maps.newHashMap();
+
+  /**
+   * Creates a new mock {@link Counter}. If
+   * {@code org.apache.hadoop.mapreduce.counters.GenericCounter} can be loaded, that class is
+   * mocked; otherwise {@link Counter} itself is mocked.
+   */
+  private static Counter newCounter() {
+    try {
+      // 0.23 case
+      String c = "org.apache.hadoop.mapreduce.counters.GenericCounter";
+      return (Counter) EasyMock.createMockBuilder(Class.forName(c)).createMock();
+    } catch (ClassNotFoundException e) {
+      // 0.20 case
+      return EasyMock.createMockBuilder(Counter.class).createMock();
+    }
+  }
+
+  /** Returns the counter for {@code name}, creating it on first use. */
+  @Override
+  public Counter getCounter(Enum<?> name) {
+    Counter counter = counters.get(name);
+    if (counter == null) {
+      counter = newCounter();
+      counters.put(name, counter);
+    }
+    return counter;
+  }
+
+  /** Returns the counter for the ({@code group}, {@code name}) pair, creating it on first use. */
+  @Override
+  public Counter getCounter(String group, String name) {
+    Map<String, Counter> groupCounters = counterGroups.get(group);
+    if (groupCounters == null) {
+      groupCounters = Maps.newHashMap();
+      counterGroups.put(group, groupCounters);
+    }
+    Counter counter = groupCounters.get(name);
+    if (counter == null) {
+      counter = newCounter();
+      groupCounters.put(name, counter);
+    }
+    return counter;
+  }
+
+  /** No-op. */
+  @Override
+  public void progress() {
+  }
+
+  /** No-op; the status string is discarded. */
+  @Override
+  public void setStatus(String status) {
+  }
+
+  /** Always reports zero progress. */
+  @Override
+  public float getProgress() {
+    return 0.0f;
+  }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/common/IntPairWritableTest.java
----------------------------------------------------------------------
diff --git
a/mr/src/test/java/org/apache/mahout/common/IntPairWritableTest.java b/mr/src/test/java/org/apache/mahout/common/IntPairWritableTest.java new file mode 100644 index 0000000..ceffe3e --- /dev/null +++ b/mr/src/test/java/org/apache/mahout/common/IntPairWritableTest.java @@ -0,0 +1,114 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package org.apache.mahout.common;
+
+import java.io.ByteArrayInputStream;
+import java.io.ByteArrayOutputStream;
+import java.io.DataInput;
+import java.io.DataInputStream;
+import java.io.DataOutput;
+import java.io.DataOutputStream;
+import java.util.Arrays;
+
+import org.junit.Test;
+
+/**
+ * Tests for {@link IntPairWritable}: accessors, Writable round-tripping and sort order.
+ */
+public final class IntPairWritableTest extends MahoutTestCase {
+
+  /** The no-arg constructor yields (0, 0); setters and the two-arg constructor store values. */
+  @Test
+  public void testGetSet() {
+    IntPairWritable pair = new IntPairWritable();
+    assertEquals(0, pair.getFirst());
+    assertEquals(0, pair.getSecond());
+
+    pair.setFirst(5);
+    pair.setSecond(10);
+    assertEquals(5, pair.getFirst());
+    assertEquals(10, pair.getSecond());
+
+    IntPairWritable constructed = new IntPairWritable(2, 4);
+    assertEquals(2, constructed.getFirst());
+    assertEquals(4, constructed.getSecond());
+  }
+
+  /** Writing one pair and reading the bytes into another pair copies both fields. */
+  @Test
+  public void testWritable() throws Exception {
+    IntPairWritable destination = new IntPairWritable(1, 2);
+    IntPairWritable source = new IntPairWritable(3, 4);
+
+    assertEquals(1, destination.getFirst());
+    assertEquals(2, destination.getSecond());
+    assertEquals(3, source.getFirst());
+    assertEquals(4, source.getSecond());
+
+    ByteArrayOutputStream bytes = new ByteArrayOutputStream();
+    DataOutput out = new DataOutputStream(bytes);
+    source.write(out);
+
+    DataInput in = new DataInputStream(new ByteArrayInputStream(bytes.toByteArray()));
+    destination.readFields(in);
+
+    assertEquals(source.getFirst(), destination.getFirst());
+    assertEquals(source.getSecond(), destination.getSecond());
+  }
+
+  /**
+   * Sorting orders pairs by first element, then by second, including negative and extreme values.
+   * The expected order is given as indices into {@code input}. The two equal (2,2) pairs
+   * (indices 1 and 5) keep their input order because {@link Arrays#sort(Object[])} is stable.
+   */
+  @Test
+  public void testComparable() {
+    IntPairWritable[] input = {
+        new IntPairWritable(2, 3),                     // 0
+        new IntPairWritable(2, 2),                     // 1
+        new IntPairWritable(1, 3),                     // 2
+        new IntPairWritable(1, 2),                     // 3
+        new IntPairWritable(2, 1),                     // 4
+        new IntPairWritable(2, 2),                     // 5
+        new IntPairWritable(1, -2),                    // 6
+        new IntPairWritable(1, -1),                    // 7
+        new IntPairWritable(-2, -2),                   // 8
+        new IntPairWritable(-2, -1),                   // 9
+        new IntPairWritable(-1, -1),                   // 10
+        new IntPairWritable(-1, -2),                   // 11
+        new IntPairWritable(Integer.MAX_VALUE, 1),     // 12
+        new IntPairWritable(Integer.MAX_VALUE / 2, 1), // 13
+        new IntPairWritable(Integer.MIN_VALUE, 1),     // 14
+        new IntPairWritable(Integer.MIN_VALUE / 2, 1)  // 15
+    };
+
+    IntPairWritable[] sorted = input.clone();
+    Arrays.sort(sorted);
+
+    int[] expected = {
+        14, 15, 8, 9, 11, 10, 6, 7, 3, 2, 4, 1, 5, 0, 13, 12
+    };
+
+    for (int position = 0; position < input.length; position++) {
+      assertSame(input[expected[position]], sorted[position]);
+    }
+  }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/common/MahoutTestCase.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/common/MahoutTestCase.java b/mr/src/test/java/org/apache/mahout/common/MahoutTestCase.java
new file mode 100644
index 0000000..775c8d8
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/common/MahoutTestCase.java
@@ -0,0 +1,148 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.common;
+
+import java.io.File;
+import java.io.FileOutputStream;
+import java.io.IOException;
+import java.io.OutputStreamWriter;
+import java.io.Writer;
+import java.lang.reflect.Field;
+
+import com.google.common.base.Charsets;
+import com.google.common.io.Closeables;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.junit.After;
+import org.junit.Before;
+
+/**
+ * Base class for Mahout MapReduce tests. Before each test it calls
+ * {@link RandomUtils#useTestSeed()}. It lazily creates a per-test temporary directory, which
+ * {@link #tearDown()} deletes.
+ */
+public class MahoutTestCase extends org.apache.mahout.math.MahoutTestCase {
+
+  /** "Close enough" value for floating-point comparisons. */
+  public static final double EPSILON = 0.000001;
+
+  /** This test's temporary directory; null until {@link #getTestTempDirPath()} creates it. */
+  private Path testTempDirPath;
+  /** File system used to create and delete {@link #testTempDirPath}. */
+  private FileSystem fs;
+
+  @Override
+  @Before
+  public void setUp() throws Exception {
+    super.setUp();
+    RandomUtils.useTestSeed();
+    testTempDirPath = null;
+    fs = null;
+  }
+
+  /**
+   * Recursively deletes the test's temporary directory if one was created. It then resets the
+   * per-test state and delegates to the parent tear-down; both happen even if deletion failed.
+   *
+   * @throws IllegalStateException if deleting the temporary directory throws an IOException
+   */
+  @Override
+  @After
+  public void tearDown() throws Exception {
+    try {
+      if (testTempDirPath != null) {
+        try {
+          fs.delete(testTempDirPath, true);
+        } catch (IOException e) {
+          // Chain the cause so the underlying file-system error shows up in the test report.
+          throw new IllegalStateException("Could not delete test directory " + testTempDirPath, e);
+        }
+      }
+    } finally {
+      testTempDirPath = null;
+      fs = null;
+      super.tearDown();
+    }
+  }
+
+  /**
+   * Returns a new {@link Configuration}. Its {@code hadoop.tmp.dir} is set to the absolute path of
+   * a randomly named directory obtained from {@code getTestTempDir}.
+   */
+  public final Configuration getConfiguration() throws IOException {
+    Configuration conf = new Configuration();
+    conf.set("hadoop.tmp.dir", getTestTempDir("hadoop" + Math.random()).getAbsolutePath());
+    return conf;
+  }
+
+  /**
+   * Returns this test's temporary directory, creating it on the first call.
+   * The directory is {@code /tmp/mahout-<SimpleClassName>-<random>}, qualified against the file
+   * system returned by {@link FileSystem#get}, and is registered for deletion on exit.
+   *
+   * @throws IOException if the directory cannot be created
+   */
+  protected final Path getTestTempDirPath() throws IOException {
+    if (testTempDirPath == null) {
+      fs = FileSystem.get(getConfiguration());
+      long simpleRandomLong = (long) (Long.MAX_VALUE * Math.random());
+      testTempDirPath = fs.makeQualified(
+          new Path("/tmp/mahout-" + getClass().getSimpleName() + '-' + simpleRandomLong));
+      if (!fs.mkdirs(testTempDirPath)) {
+        throw new IOException("Could not create " + testTempDirPath);
+      }
+      fs.deleteOnExit(testTempDirPath);
+    }
+    return testTempDirPath;
+  }
+
+  /** Returns the path of a file called {@code name} inside this test's temporary directory. */
+  protected final Path getTestTempFilePath(String name) throws IOException {
+    return getTestTempFileOrDirPath(name, false);
+  }
+
+  /** Returns the path of a directory called {@code name} in this test's temp dir, creating it. */
+  protected final Path getTestTempDirPath(String name) throws IOException {
+    return getTestTempFileOrDirPath(name, true);
+  }
+
+  /**
+   * Resolves {@code name} against this test's temporary directory and registers it for deletion
+   * on exit. If {@code dir} is true, the path is also created as a directory.
+   *
+   * @throws IOException if {@code dir} is true and the directory could not be created
+   */
+  private Path getTestTempFileOrDirPath(String name, boolean dir) throws IOException {
+    // Deliberately not named testTempDirPath, to avoid shadowing the field.
+    Path baseDir = getTestTempDirPath();
+    Path tempFileOrDir = fs.makeQualified(new Path(baseDir, name));
+    fs.deleteOnExit(tempFileOrDir);
+    if (dir && !fs.mkdirs(tempFileOrDir)) {
+      throw new IOException("Could not create " + tempFileOrDir);
+    }
+    return tempFileOrDir;
+  }
+
+  /**
+   * Try to directly set a (possibly private) field on an Object. The field is searched for in the
+   * target's class and its superclasses, and the name is matched case-insensitively.
+   */
+  protected static void setField(Object target, String fieldname, Object value)
+      throws NoSuchFieldException, IllegalAccessException {
+    Field field = findDeclaredField(target.getClass(), fieldname);
+    field.setAccessible(true);
+    field.set(target, value);
+  }
+
+  /**
+   * Find a declared field in a class or one of its super classes, matching the name
+   * case-insensitively. Fields declared on {@link Object} itself are not searched.
+   *
+   * @throws NoSuchFieldException if no field matches; the message names the field and the class
+   */
+  private static Field findDeclaredField(Class<?> inClass, String fieldname) throws NoSuchFieldException {
+    for (Class<?> current = inClass; current != null && !Object.class.equals(current);
+        current = current.getSuperclass()) {
+      for (Field field : current.getDeclaredFields()) {
+        if (field.getName().equalsIgnoreCase(fieldname)) {
+          return field;
+        }
+      }
+    }
+    throw new NoSuchFieldException(fieldname + " in " + inClass.getName() + " or its superclasses");
+  }
+
+  /**
+   * @return a job option key string (--name) from the given option name
+   */
+  protected static String optKey(String optionName) {
+    return AbstractJob.keyFor(optionName);
+  }
+
+  /** Writes {@code lines} to {@code file} as UTF-8, each line followed by a newline character. */
+  protected static void writeLines(File file, String... lines) throws IOException {
+    Writer writer = new OutputStreamWriter(new FileOutputStream(file), Charsets.UTF_8);
+    try {
+      for (String line : lines) {
+        writer.write(line);
+        writer.write('\n');
+      }
+    } finally {
+      Closeables.close(writer, false);
+    }
+  }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/common/MockIterator.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/common/MockIterator.java b/mr/src/test/java/org/apache/mahout/common/MockIterator.java
new file mode 100644
index 0000000..ce48fdc
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/common/MockIterator.java
@@ -0,0 +1,51 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.common;
+
+import org.apache.hadoop.io.DataInputBuffer;
+import org.apache.hadoop.mapred.RawKeyValueIterator;
+import org.apache.hadoop.util.Progress;
+
+/**
+ * A do-nothing {@link RawKeyValueIterator} for tests. {@link #next()} always reports another
+ * record, while the key, value and progress accessors all return {@code null}.
+ */
+public final class MockIterator implements RawKeyValueIterator {
+
+  /** Always {@code true}: this iterator never reports exhaustion. */
+  @Override
+  public boolean next() {
+    return true;
+  }
+
+  /** Always {@code null}. */
+  @Override
+  public DataInputBuffer getKey() {
+    return null;
+  }
+
+  /** Always {@code null}. */
+  @Override
+  public DataInputBuffer getValue() {
+    return null;
+  }
+
+  /** Always {@code null}. */
+  @Override
+  public Progress getProgress() {
+    return null;
+  }
+
+  /** Nothing to release. */
+  @Override
+  public void close() {
+  }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/common/StringUtilsTest.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/common/StringUtilsTest.java b/mr/src/test/java/org/apache/mahout/common/StringUtilsTest.java
new file mode 100644
index 0000000..0633685
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/common/StringUtilsTest.java
@@ -0,0 +1,70 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.common;
+
+import com.google.common.collect.Lists;
+import org.junit.Test;
+
+import java.util.List;
+
+/** Tests for {@link StringUtils} object/string conversion and XML escaping. */
+public final class StringUtilsTest extends MahoutTestCase {
+
+  /** A plain class that does not implement Serializable; equality is by {@code field}. */
+  private static class DummyTest {
+    private int field;
+
+    @Override
+    public boolean equals(Object obj) {
+      return obj == this || (obj instanceof DummyTest && ((DummyTest) obj).field == field);
+    }
+
+    @Override
+    public int hashCode() {
+      return field;
+    }
+
+    public int getField() {
+      return field;
+    }
+  }
+
+  /** A list and a non-serializable object both survive a toString/fromString round trip. */
+  @Test
+  public void testStringConversion() throws Exception {
+    List<String> letters = Lists.newArrayList("A", "B", "C");
+    assertEquals(letters, StringUtils.fromString(StringUtils.toString(letters)));
+
+    // test a non serializable object
+    DummyTest dummy = new DummyTest();
+    assertEquals(dummy, StringUtils.fromString(StringUtils.toString(dummy)));
+  }
+
+  /** Each of the characters " ' &amp; &gt; &lt; is replaced by an underscore. */
+  @Test
+  public void testEscape() throws Exception {
+    assertEquals("_,_,_,_,_", StringUtils.escapeXML("\",\',&,>,<"));
+  }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/common/distance/CosineDistanceMeasureTest.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/common/distance/CosineDistanceMeasureTest.java b/mr/src/test/java/org/apache/mahout/common/distance/CosineDistanceMeasureTest.java
new file mode 100644
index 0000000..6db7c9b
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/common/distance/CosineDistanceMeasureTest.java
@@ -0,0 +1,66 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.common.distance;
+
+import org.apache.mahout.common.MahoutTestCase;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.SequentialAccessSparseVector;
+import org.apache.mahout.math.Vector;
+import org.junit.Test;
+
+/** Tests for {@link CosineDistanceMeasure}. */
+public final class CosineDistanceMeasureTest extends MahoutTestCase {
+
+  /**
+   * Checks that self-distances are zero and that the pairwise distances of three nested 0/1
+   * vectors have the expected ordering. Also checks that two equal zero vectors are at distance 0.
+   */
+  @Test
+  public void testMeasure() {
+    DistanceMeasure cosine = new CosineDistanceMeasure();
+
+    Vector[] vectors = {
+        new DenseVector(new double[]{1, 0, 0, 0, 0, 0}),
+        new DenseVector(new double[]{1, 1, 1, 0, 0, 0}),
+        new DenseVector(new double[]{1, 1, 1, 1, 1, 1})
+    };
+
+    int count = vectors.length;
+    double[][] d = new double[count][count];
+    for (int row = 0; row < count; row++) {
+      for (int col = 0; col < count; col++) {
+        d[row][col] = cosine.distance(vectors[row], vectors[col]);
+      }
+    }
+
+    assertEquals(0.0, d[0][0], EPSILON);
+    assertTrue(d[0][1] > d[0][0]);
+    assertTrue(d[0][2] > d[0][1]);
+
+    assertEquals(0.0, d[1][1], EPSILON);
+    assertTrue(d[1][1] < d[1][0]);
+    assertTrue(d[1][0] > d[1][2]);
+
+    assertEquals(0.0, d[2][2], EPSILON);
+    assertTrue(d[2][1] < d[2][0]);
+    assertTrue(d[2][2] < d[2][1]);
+
+    // Two equal vectors (despite them being zero) should have 0 distance.
+    assertEquals(0,
+        cosine.distance(new SequentialAccessSparseVector(1), new SequentialAccessSparseVector(1)),
+        EPSILON);
+  }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/common/distance/DefaultDistanceMeasureTest.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/common/distance/DefaultDistanceMeasureTest.java b/mr/src/test/java/org/apache/mahout/common/distance/DefaultDistanceMeasureTest.java
new file mode 100644
index 0000000..ad1608c
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/common/distance/DefaultDistanceMeasureTest.java
@@ -0,0 +1,103 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.common.distance;
+
+import org.apache.mahout.common.MahoutTestCase;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.RandomAccessSparseVector;
+import org.apache.mahout.math.Vector;
+import org.junit.Test;
+
+/**
+ * Shared checks for {@link DistanceMeasure} implementations. Subclasses supply the measure under
+ * test through {@link #distanceMeasureFactory()}.
+ */
+public abstract class DefaultDistanceMeasureTest extends MahoutTestCase {
+
+  /** Creates a new instance of the measure under test. */
+  protected abstract DistanceMeasure distanceMeasureFactory();
+
+  /**
+   * Runs the shared checks on two vector sets. The dense set is three positive multiples of the
+   * all-ones vector plus its negation. The sparse set is three multiples of one sparse vector plus
+   * an empty {@link RandomAccessSparseVector}.
+   */
+  @Test
+  public void testMeasure() {
+
+    DistanceMeasure distanceMeasure = distanceMeasureFactory();
+
+    Vector[] vectors = {
+        new DenseVector(new double[]{1, 1, 1, 1, 1, 1}),
+        new DenseVector(new double[]{2, 2, 2, 2, 2, 2}),
+        new DenseVector(new double[]{6, 6, 6, 6, 6, 6}),
+        new DenseVector(new double[]{-1,-1,-1,-1,-1,-1})
+    };
+
+    compare(distanceMeasure, vectors);
+
+    vectors = new Vector[4];
+
+    vectors[0] = new RandomAccessSparseVector(5);
+    vectors[0].setQuick(0, 1);
+    vectors[0].setQuick(3, 1);
+    vectors[0].setQuick(4, 1);
+
+    vectors[1] = new RandomAccessSparseVector(5);
+    vectors[1].setQuick(0, 2);
+    vectors[1].setQuick(3, 2);
+    vectors[1].setQuick(4, 2);
+
+    vectors[2] = new RandomAccessSparseVector(5);
+    vectors[2].setQuick(0, 6);
+    vectors[2].setQuick(3, 6);
+    vectors[2].setQuick(4, 6);
+
+    vectors[3] = new RandomAccessSparseVector(5);
+
+    compare(distanceMeasure, vectors);
+  }
+
+  /**
+   * Computes all pairwise distances between {@code vectors} and checks several properties.
+   * The first three vectors have zero self-distance and the expected ordering among their
+   * mutual distances. Every distance is non-negative. A non-zero vector is at a strictly
+   * positive distance from any vector it sums to zero with.
+   *
+   * @param distanceMeasure the measure under test
+   * @param vectors at least three vectors; only the first three take part in the ordering checks
+   */
+  private static void compare(DistanceMeasure distanceMeasure, Vector[] vectors) {
+    int count = vectors.length;
+    double[][] distanceMatrix = new double[count][count];
+
+    for (int a = 0; a < count; a++) {
+      for (int b = 0; b < count; b++) {
+        distanceMatrix[a][b] = distanceMeasure.distance(vectors[a], vectors[b]);
+      }
+    }
+
+    assertEquals("Distance from first vector to itself is not zero", 0.0, distanceMatrix[0][0], EPSILON);
+    assertTrue(distanceMatrix[0][0] < distanceMatrix[0][1]);
+    assertTrue(distanceMatrix[0][1] < distanceMatrix[0][2]);
+
+    assertEquals("Distance from second vector to itself is not zero", 0.0, distanceMatrix[1][1], EPSILON);
+    assertTrue(distanceMatrix[1][0] > distanceMatrix[1][1]);
+    assertTrue(distanceMatrix[1][2] > distanceMatrix[1][0]);
+
+    assertEquals("Distance from third vector to itself is not zero", 0.0, distanceMatrix[2][2], EPSILON);
+    assertTrue(distanceMatrix[2][0] > distanceMatrix[2][1]);
+    assertTrue(distanceMatrix[2][1] > distanceMatrix[2][2]);
+
+    for (int a = 0; a < count; a++) {
+      for (int b = 0; b < count; b++) {
+        assertTrue("Distance between vectors less than zero: "
+            + distanceMatrix[a][b] + " = " + distanceMeasure
+            + ".distance(" + vectors[a].asFormatString() + ", "
+            + vectors[b].asFormatString() + ')',
+            distanceMatrix[a][b] >= 0);
+        // vectors[b] is the negation of the non-zero vectors[a]; they must not be at distance 0.
+        if (vectors[a].plus(vectors[b]).norm(2) == 0 && vectors[a].norm(2) > 0) {
+          assertTrue("Distance from v to -v is equal to zero: "
+              + vectors[a].asFormatString() + " = -" + vectors[b].asFormatString(),
+              distanceMatrix[a][b] > 0);
+        }
+      }
+    }
+
+  }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/common/distance/DefaultWeightedDistanceMeasureTest.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/common/distance/DefaultWeightedDistanceMeasureTest.java b/mr/src/test/java/org/apache/mahout/common/distance/DefaultWeightedDistanceMeasureTest.java
new file mode 100644
index 0000000..a8f1d0b
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/common/distance/DefaultWeightedDistanceMeasureTest.java
@@ -0,0 +1,56 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.
You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.common.distance;
+
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Vector;
+import org.junit.Test;
+
+/**
+ * Extends the shared {@link DefaultDistanceMeasureTest} checks with a weighting check for
+ * {@link WeightedDistanceMeasure} implementations.
+ */
+public abstract class DefaultWeightedDistanceMeasureTest extends DefaultDistanceMeasureTest {
+
+  @Override
+  public abstract WeightedDistanceMeasure distanceMeasureFactory();
+
+  /**
+   * Uses weights (1, 1000, 1), emphasising the middle coordinate. Vector 0 (9,9,1) agrees with
+   * vector 1 (1,9,9) on that coordinate but not with vector 2 (9,1,9). So vector 0 must be
+   * strictly closer to vector 1 than to vector 2.
+   */
+  @Test
+  public void testMeasureWeighted() {
+    WeightedDistanceMeasure weighted = distanceMeasureFactory();
+    weighted.setWeights(new DenseVector(new double[]{1, 1000, 1}));
+
+    Vector[] points = {
+        new DenseVector(new double[]{9, 9, 1}),
+        new DenseVector(new double[]{1, 9, 9}),
+        new DenseVector(new double[]{9, 1, 9})
+    };
+
+    double[][] distances = new double[points.length][points.length];
+    for (int i = 0; i < points.length; i++) {
+      for (int j = 0; j < points.length; j++) {
+        distances[i][j] = weighted.distance(points[i], points[j]);
+      }
+    }
+
+    assertEquals(0.0, distances[0][0], EPSILON);
+    assertTrue(distances[0][2] > distances[0][1]);
+  }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/common/distance/TestChebyshevMeasure.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/common/distance/TestChebyshevMeasure.java b/mr/src/test/java/org/apache/mahout/common/distance/TestChebyshevMeasure.java
new file mode 100644
index 0000000..185adf3
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/common/distance/TestChebyshevMeasure.java
@@ -0,0 +1,55 @@
+/**
+ * Licensed to the
Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.common.distance;
+
+import org.apache.mahout.common.MahoutTestCase;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Vector;
+import org.junit.Test;
+
+/** Tests for {@link ChebyshevDistanceMeasure}. */
+public final class TestChebyshevMeasure extends MahoutTestCase {
+
+  /**
+   * Compares the full pairwise distance matrix of three 0/1 vectors against hand-computed values
+   * (0 on the diagonal, 1 elsewhere).
+   */
+  @Test
+  public void testMeasure() {
+    DistanceMeasure measure = new ChebyshevDistanceMeasure();
+
+    Vector[] vectors = {
+        new DenseVector(new double[]{1, 0, 0, 0, 0, 0}),
+        new DenseVector(new double[]{1, 1, 1, 0, 0, 0}),
+        new DenseVector(new double[]{1, 1, 1, 1, 1, 1})
+    };
+    double[][] expected = {{0.0, 1.0, 1.0}, {1.0, 0.0, 1.0}, {1.0, 1.0, 0.0}};
+
+    int n = vectors.length;
+    double[][] actual = new double[n][n];
+    for (int row = 0; row < n; row++) {
+      for (int col = 0; col < n; col++) {
+        actual[row][col] = measure.distance(vectors[row], vectors[col]);
+      }
+    }
+    for (int row = 0; row < n; row++) {
+      for (int col = 0; col < n; col++) {
+        assertEquals(expected[row][col], actual[row][col], EPSILON);
+      }
+    }
+
+    assertEquals(0.0, actual[0][0], EPSILON);
+  }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/common/distance/TestEuclideanDistanceMeasure.java ---------------------------------------------------------------------- diff --git a/mr/src/test/java/org/apache/mahout/common/distance/TestEuclideanDistanceMeasure.java b/mr/src/test/java/org/apache/mahout/common/distance/TestEuclideanDistanceMeasure.java new file mode 100644 index 0000000..cc9e9e7 --- /dev/null +++ b/mr/src/test/java/org/apache/mahout/common/distance/TestEuclideanDistanceMeasure.java @@ -0,0 +1,26 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package org.apache.mahout.common.distance;
+
+/** Runs the shared {@link DefaultDistanceMeasureTest} checks against {@link EuclideanDistanceMeasure}. */
+public final class TestEuclideanDistanceMeasure extends DefaultDistanceMeasureTest {
+
+  /** @return a new {@link EuclideanDistanceMeasure} */
+  @Override
+  public DistanceMeasure distanceMeasureFactory() {
+    DistanceMeasure euclidean = new EuclideanDistanceMeasure();
+    return euclidean;
+  }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/common/distance/TestMahalanobisDistanceMeasure.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/common/distance/TestMahalanobisDistanceMeasure.java b/mr/src/test/java/org/apache/mahout/common/distance/TestMahalanobisDistanceMeasure.java
new file mode 100644
index 0000000..8e3d205
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/common/distance/TestMahalanobisDistanceMeasure.java
@@ -0,0 +1,56 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.common.distance;
+
+import org.apache.mahout.common.MahoutTestCase;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.DenseMatrix;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Vector;
+import org.junit.Test;
+
+
+/**
+ * Tests for {@link MahalanobisDistanceMeasure}.
+ * To launch this test only : mvn test -Dtest=org.apache.mahout.common.distance.TestMahalanobisDistanceMeasure
+ */
+public final class TestMahalanobisDistanceMeasure extends MahoutTestCase {
+
+  /**
+   * Checks one distance against a precomputed value, using a fixed inverse covariance matrix and
+   * mean. It then sets that same matrix as the covariance. The matrix returned by
+   * {@code getInverseCovarianceMatrix()}, multiplied by it, must be the 2x2 identity.
+   */
+  @Test
+  public void testMeasure() {
+    Matrix invCov = new DenseMatrix(new double[][] { { 2.2, 0.4 }, { 0.4, 2.8 } });
+    Vector mean = new DenseVector(new double[] { -2.3, -0.9 });
+
+    MahalanobisDistanceMeasure measure = new MahalanobisDistanceMeasure();
+    measure.setInverseCovarianceMatrix(invCov);
+    measure.setMeanVector(mean);
+
+    Vector first = new DenseVector(new double[] { -1.9, -2.3 });
+    Vector second = new DenseVector(new double[] { -2.9, -1.3 });
+    assertEquals(2.0493901531919194, measure.distance(first, second), EPSILON);
+
+    // Install the same matrix as the covariance; its stored inverse times it must be the identity.
+    measure.setCovarianceMatrix(invCov);
+    Matrix product = measure.getInverseCovarianceMatrix().times(invCov);
+    for (int row = 0; row < 2; row++) {
+      for (int col = 0; col < 2; col++) {
+        assertEquals(row == col ? 1.0 : 0.0, product.get(row, col), EPSILON);
+      }
+    }
+  }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/common/distance/TestManhattanDistanceMeasure.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/common/distance/TestManhattanDistanceMeasure.java b/mr/src/test/java/org/apache/mahout/common/distance/TestManhattanDistanceMeasure.java
new file mode 100644
index 0000000..97a5612
--- /dev/null
+++
b/mr/src/test/java/org/apache/mahout/common/distance/TestManhattanDistanceMeasure.java
@@ -0,0 +1,26 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.common.distance;
+
+/** Runs the shared {@link DefaultDistanceMeasureTest} checks against {@link ManhattanDistanceMeasure}. */
+public final class TestManhattanDistanceMeasure extends DefaultDistanceMeasureTest {
+
+  /** @return a new {@link ManhattanDistanceMeasure} */
+  @Override
+  public DistanceMeasure distanceMeasureFactory() {
+    DistanceMeasure manhattan = new ManhattanDistanceMeasure();
+    return manhattan;
+  }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/common/distance/TestMinkowskiMeasure.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/common/distance/TestMinkowskiMeasure.java b/mr/src/test/java/org/apache/mahout/common/distance/TestMinkowskiMeasure.java
new file mode 100644
index 0000000..d2cd85e
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/common/distance/TestMinkowskiMeasure.java
@@ -0,0 +1,64 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.common.distance; + +import org.apache.mahout.common.MahoutTestCase; +import org.apache.mahout.math.DenseVector; +import org.apache.mahout.math.Vector; +import org.junit.Test; + +public final class TestMinkowskiMeasure extends MahoutTestCase { + + @Test + public void testMeasure() { + + DistanceMeasure minkowskiDistanceMeasure = new MinkowskiDistanceMeasure(1.5); + DistanceMeasure manhattanDistanceMeasure = new ManhattanDistanceMeasure(); + DistanceMeasure euclideanDistanceMeasure = new EuclideanDistanceMeasure(); + + Vector[] vectors = { + new DenseVector(new double[]{1, 0, 0, 0, 0, 0}), + new DenseVector(new double[]{1, 1, 1, 0, 0, 0}), + new DenseVector(new double[]{1, 1, 1, 1, 1, 1}) + }; + + double[][] minkowskiDistanceMatrix = new double[3][3]; + double[][] manhattanDistanceMatrix = new double[3][3]; + double[][] euclideanDistanceMatrix = new double[3][3]; + + for (int a = 0; a < 3; a++) { + for (int b = 0; b < 3; b++) { + minkowskiDistanceMatrix[a][b] = minkowskiDistanceMeasure.distance(vectors[a], vectors[b]); + manhattanDistanceMatrix[a][b] = manhattanDistanceMeasure.distance(vectors[a], vectors[b]); + euclideanDistanceMatrix[a][b] = euclideanDistanceMeasure.distance(vectors[a], vectors[b]); + } + } + + for (int a = 0; a < 3; a++) { + for (int b = 0; b < 3; b++) { + assertTrue(minkowskiDistanceMatrix[a][b] <= manhattanDistanceMatrix[a][b]); + 
assertTrue(minkowskiDistanceMatrix[a][b] >= euclideanDistanceMatrix[a][b]); + } + } + + assertEquals(0.0, minkowskiDistanceMatrix[0][0], EPSILON); + assertTrue(minkowskiDistanceMatrix[0][0] < minkowskiDistanceMatrix[0][1]); + assertTrue(minkowskiDistanceMatrix[0][1] < minkowskiDistanceMatrix[0][2]); + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/common/distance/TestTanimotoDistanceMeasure.java ---------------------------------------------------------------------- diff --git a/mr/src/test/java/org/apache/mahout/common/distance/TestTanimotoDistanceMeasure.java b/mr/src/test/java/org/apache/mahout/common/distance/TestTanimotoDistanceMeasure.java new file mode 100644 index 0000000..01f9134 --- /dev/null +++ b/mr/src/test/java/org/apache/mahout/common/distance/TestTanimotoDistanceMeasure.java @@ -0,0 +1,25 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.common.distance; + +public final class TestTanimotoDistanceMeasure extends DefaultWeightedDistanceMeasureTest { + @Override + public TanimotoDistanceMeasure distanceMeasureFactory() { + return new TanimotoDistanceMeasure(); + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/common/distance/TestWeightedEuclideanDistanceMeasureTest.java ---------------------------------------------------------------------- diff --git a/mr/src/test/java/org/apache/mahout/common/distance/TestWeightedEuclideanDistanceMeasureTest.java b/mr/src/test/java/org/apache/mahout/common/distance/TestWeightedEuclideanDistanceMeasureTest.java new file mode 100644 index 0000000..b99d165 --- /dev/null +++ b/mr/src/test/java/org/apache/mahout/common/distance/TestWeightedEuclideanDistanceMeasureTest.java @@ -0,0 +1,25 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.common.distance; + +public final class TestWeightedEuclideanDistanceMeasureTest extends DefaultWeightedDistanceMeasureTest { + @Override + public WeightedDistanceMeasure distanceMeasureFactory() { + return new WeightedEuclideanDistanceMeasure(); + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/common/distance/TestWeightedManhattanDistanceMeasure.java ---------------------------------------------------------------------- diff --git a/mr/src/test/java/org/apache/mahout/common/distance/TestWeightedManhattanDistanceMeasure.java b/mr/src/test/java/org/apache/mahout/common/distance/TestWeightedManhattanDistanceMeasure.java new file mode 100644 index 0000000..77d4a01 --- /dev/null +++ b/mr/src/test/java/org/apache/mahout/common/distance/TestWeightedManhattanDistanceMeasure.java @@ -0,0 +1,26 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.common.distance; + +public final class TestWeightedManhattanDistanceMeasure extends DefaultWeightedDistanceMeasureTest { + + @Override + public WeightedManhattanDistanceMeasure distanceMeasureFactory() { + return new WeightedManhattanDistanceMeasure(); + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/common/iterator/CountingIteratorTest.java ---------------------------------------------------------------------- diff --git a/mr/src/test/java/org/apache/mahout/common/iterator/CountingIteratorTest.java b/mr/src/test/java/org/apache/mahout/common/iterator/CountingIteratorTest.java new file mode 100644 index 0000000..d38178c --- /dev/null +++ b/mr/src/test/java/org/apache/mahout/common/iterator/CountingIteratorTest.java @@ -0,0 +1,44 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.common.iterator; + +import java.util.Iterator; + +import org.apache.mahout.common.MahoutTestCase; +import org.junit.Test; + +public final class CountingIteratorTest extends MahoutTestCase { + + @Test + public void testEmptyCase() { + assertFalse(new CountingIterator(0).hasNext()); + } + + @Test + public void testCount() { + Iterator<Integer> it = new CountingIterator(3); + assertTrue(it.hasNext()); + assertEquals(0, (int) it.next()); + assertTrue(it.hasNext()); + assertEquals(1, (int) it.next()); + assertTrue(it.hasNext()); + assertEquals(2, (int) it.next()); + assertFalse(it.hasNext()); + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/common/iterator/SamplerCase.java ---------------------------------------------------------------------- diff --git a/mr/src/test/java/org/apache/mahout/common/iterator/SamplerCase.java b/mr/src/test/java/org/apache/mahout/common/iterator/SamplerCase.java new file mode 100644 index 0000000..b67d34b --- /dev/null +++ b/mr/src/test/java/org/apache/mahout/common/iterator/SamplerCase.java @@ -0,0 +1,101 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.common.iterator; + +import java.util.Collections; +import java.util.Iterator; +import java.util.Arrays; +import java.util.List; + +import org.apache.mahout.common.MahoutTestCase; +import org.junit.Test; + +public abstract class SamplerCase extends MahoutTestCase { + // these provide access to the underlying implementation + protected abstract Iterator<Integer> createSampler(int n, Iterator<Integer> source); + + protected abstract boolean isSorted(); + + @Test + public void testEmptyCase() { + assertFalse(createSampler(100, new CountingIterator(0)).hasNext()); + } + + @Test + public void testSmallInput() { + Iterator<Integer> t = createSampler(10, new CountingIterator(1)); + assertTrue(t.hasNext()); + assertEquals(0, t.next().intValue()); + assertFalse(t.hasNext()); + + t = createSampler(10, new CountingIterator(1)); + assertTrue(t.hasNext()); + assertEquals(0, t.next().intValue()); + assertFalse(t.hasNext()); + } + + @Test + public void testAbsurdSize() { + Iterator<Integer> t = createSampler(0, new CountingIterator(2)); + assertFalse(t.hasNext()); + } + + @Test + public void testExactSizeMatch() { + Iterator<Integer> t = createSampler(10, new CountingIterator(10)); + for (int i = 0; i < 10; i++) { + assertTrue(t.hasNext()); + assertEquals(i, t.next().intValue()); + } + assertFalse(t.hasNext()); + } + + @Test + public void testSample() { + Iterator<Integer> source = new CountingIterator(100); + Iterator<Integer> t = createSampler(15, source); + + // this is just a regression test, not a real test + List<Integer> expectedValues = Arrays.asList(52,28,2,60,50,32,65,79,78,9,40,33,96,25,48); + if (isSorted()) { + Collections.sort(expectedValues); + } + Iterator<Integer> expected = expectedValues.iterator(); + int last = Integer.MIN_VALUE; + for (int i = 0; i < 15; i++) { + assertTrue(t.hasNext()); + int actual = t.next(); + if (isSorted()) { + assertTrue(actual >= last); + last = actual; + } else { + // any of the first few values 
should be in the original places + if (actual < 15) { + assertEquals(i, actual); + } + } + + assertTrue(actual >= 0 && actual < 100); + + // this is just a regression test, but still of some value + assertEquals(expected.next().intValue(), actual); + assertFalse(source.hasNext()); + } + assertFalse(t.hasNext()); + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/common/iterator/TestFixedSizeSampler.java ---------------------------------------------------------------------- diff --git a/mr/src/test/java/org/apache/mahout/common/iterator/TestFixedSizeSampler.java b/mr/src/test/java/org/apache/mahout/common/iterator/TestFixedSizeSampler.java new file mode 100644 index 0000000..470e6d8 --- /dev/null +++ b/mr/src/test/java/org/apache/mahout/common/iterator/TestFixedSizeSampler.java @@ -0,0 +1,33 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.common.iterator; + +import java.util.Iterator; + +public final class TestFixedSizeSampler extends SamplerCase { + + @Override + protected Iterator<Integer> createSampler(int n, Iterator<Integer> source) { + return new FixedSizeSamplingIterator<Integer>(n, source); + } + + @Override + protected boolean isSorted() { + return false; + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/common/iterator/TestSamplingIterator.java ---------------------------------------------------------------------- diff --git a/mr/src/test/java/org/apache/mahout/common/iterator/TestSamplingIterator.java b/mr/src/test/java/org/apache/mahout/common/iterator/TestSamplingIterator.java new file mode 100644 index 0000000..970ea79 --- /dev/null +++ b/mr/src/test/java/org/apache/mahout/common/iterator/TestSamplingIterator.java @@ -0,0 +1,77 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.common.iterator; + +import java.util.Iterator; + +import org.apache.mahout.common.MahoutTestCase; +import org.junit.Test; + +public final class TestSamplingIterator extends MahoutTestCase { + + @Test + public void testEmptyCase() { + assertFalse(new SamplingIterator<Integer>(new CountingIterator(0), 0.9999).hasNext()); + assertFalse(new SamplingIterator<Integer>(new CountingIterator(0), 1).hasNext()); + } + + @Test + public void testSmallInput() { + Iterator<Integer> t = new SamplingIterator<Integer>(new CountingIterator(1), 0.9999); + assertTrue(t.hasNext()); + assertEquals(0, t.next().intValue()); + assertFalse(t.hasNext()); + } + + @Test(expected = IllegalArgumentException.class) + public void testBadRate1() { + new SamplingIterator<Integer>(new CountingIterator(1), 0.0); + } + + @Test(expected = IllegalArgumentException.class) + public void testBadRate2() { + new SamplingIterator<Integer>(new CountingIterator(1), 1.1); + } + + @Test + public void testExactSizeMatch() { + Iterator<Integer> t = new SamplingIterator<Integer>(new CountingIterator(10), 1); + for (int i = 0; i < 10; i++) { + assertTrue(t.hasNext()); + assertEquals(i, t.next().intValue()); + } + assertFalse(t.hasNext()); + } + + @Test + public void testSample() { + for (int i = 0; i < 1000; i++) { + Iterator<Integer> t = new SamplingIterator<Integer>(new CountingIterator(1000), 0.1); + int k = 0; + while (t.hasNext()) { + int v = t.next(); + k++; + assertTrue(v >= 0); + assertTrue(v < 1000); + } + double sd = Math.sqrt(0.9 * 0.1 * 1000); + assertTrue(k >= 100 - 4 * sd); + assertTrue(k <= 100 + 4 * sd); + } + } +}
