http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/test/java/org/apache/mahout/cf/taste/impl/similarity/EuclideanDistanceSimilarityTest.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/test/java/org/apache/mahout/cf/taste/impl/similarity/EuclideanDistanceSimilarityTest.java b/community/mahout-mr/src/test/java/org/apache/mahout/cf/taste/impl/similarity/EuclideanDistanceSimilarityTest.java new file mode 100644 index 0000000..ca4d2b2 --- /dev/null +++ b/community/mahout-mr/src/test/java/org/apache/mahout/cf/taste/impl/similarity/EuclideanDistanceSimilarityTest.java @@ -0,0 +1,236 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.cf.taste.impl.similarity; + +import org.apache.mahout.cf.taste.common.TasteException; +import org.apache.mahout.cf.taste.common.Weighting; +import org.apache.mahout.cf.taste.model.DataModel; +import org.apache.mahout.cf.taste.similarity.ItemSimilarity; +import org.junit.Test; + +/** <p>Tests {@link EuclideanDistanceSimilarity}.</p> */ +public final class EuclideanDistanceSimilarityTest extends SimilarityTestCase { + + @Test + public void testFullCorrelation1() throws Exception { + DataModel dataModel = getDataModel( + new long[] {1, 2}, + new Double[][] { + {3.0, -2.0}, + {3.0, -2.0}, + }); + double correlation = new EuclideanDistanceSimilarity(dataModel).userSimilarity(1, 2); + assertCorrelationEquals(1.0, correlation); + } + + @Test + public void testFullCorrelation1Weighted() throws Exception { + DataModel dataModel = getDataModel( + new long[] {1, 2}, + new Double[][] { + {3.0, -2.0}, + {3.0, -2.0}, + }); + double correlation = new EuclideanDistanceSimilarity(dataModel, Weighting.WEIGHTED).userSimilarity(1, 2); + assertCorrelationEquals(1.0, correlation); + } + + @Test + public void testFullCorrelation2() throws Exception { + DataModel dataModel = getDataModel( + new long[] {1, 2}, + new Double[][] { + {3.0, 3.0}, + {3.0, 3.0}, + }); + double correlation = new EuclideanDistanceSimilarity(dataModel).userSimilarity(1, 2); + assertEquals(1.0, correlation, EPSILON); + } + + @Test + public void testNoCorrelation1() throws Exception { + DataModel dataModel = getDataModel( + new long[] {1, 2}, + new Double[][] { + {3.0, -2.0}, + {-3.0, 2.0}, + }); + double correlation = new EuclideanDistanceSimilarity(dataModel).userSimilarity(1, 2); + assertCorrelationEquals(0.1639607805437114, correlation); + } + + @Test + public void testNoCorrelation1Weighted() throws Exception { + DataModel dataModel = getDataModel( + new long[] {1, 2}, + new Double[][] { + {3.0, -2.0}, + {-3.0, 2.0}, + }); + double correlation = new 
EuclideanDistanceSimilarity(dataModel, Weighting.WEIGHTED).userSimilarity(1, 2); + assertCorrelationEquals(0.7213202601812372, correlation); + } + + @Test + public void testNoCorrelation2() throws Exception { + DataModel dataModel = getDataModel( + new long[] {1, 2}, + new Double[][] { + {null, 1.0, null}, + {null, null, 1.0}, + }); + double correlation = new EuclideanDistanceSimilarity(dataModel).userSimilarity(1, 2); + assertTrue(Double.isNaN(correlation)); + } + + @Test + public void testNoCorrelation3() throws Exception { + DataModel dataModel = getDataModel( + new long[] {1, 2}, + new Double[][] { + {90.0, 80.0, 70.0}, + {70.0, 80.0, 90.0}, + }); + double correlation = new EuclideanDistanceSimilarity(dataModel).userSimilarity(1, 2); + assertCorrelationEquals(0.05770363219029305, correlation); + } + + @Test + public void testSimple() throws Exception { + DataModel dataModel = getDataModel( + new long[] {1, 2}, + new Double[][] { + {1.0, 2.0, 3.0}, + {2.0, 5.0, 6.0}, + }); + double correlation = new EuclideanDistanceSimilarity(dataModel).userSimilarity(1, 2); + assertCorrelationEquals(0.2843646522044218, correlation); + } + + @Test + public void testSimpleWeighted() throws Exception { + DataModel dataModel = getDataModel( + new long[] {1, 2}, + new Double[][] { + {1.0, 2.0, 3.0}, + {2.0, 5.0, 6.0}, + }); + double correlation = new EuclideanDistanceSimilarity(dataModel, Weighting.WEIGHTED).userSimilarity(1, 2); + assertCorrelationEquals(0.8210911630511055, correlation); + } + + @Test + public void testFullItemCorrelation1() throws Exception { + DataModel dataModel = getDataModel( + new long[] {1, 2}, + new Double[][] { + {3.0, 3.0}, + {-2.0, -2.0}, + }); + double correlation = + new EuclideanDistanceSimilarity(dataModel).itemSimilarity(0, 1); + assertCorrelationEquals(1.0, correlation); + } + + @Test + public void testFullItemCorrelation2() throws Exception { + DataModel dataModel = getDataModel( + new long[] {1, 2}, + new Double[][] { + {3.0, 3.0}, + {3.0, 3.0}, 
+ }); + double correlation = + new EuclideanDistanceSimilarity(dataModel).itemSimilarity(0, 1); + assertEquals(1.0, correlation, EPSILON); + } + + @Test + public void testNoItemCorrelation1() throws Exception { + DataModel dataModel = getDataModel( + new long[] {1, 2}, + new Double[][] { + {3.0, -3.0}, + {-2.0, 2.0}, + }); + double correlation = + new EuclideanDistanceSimilarity(dataModel).itemSimilarity(0, 1); + assertCorrelationEquals(0.1639607805437114, correlation); + } + + @Test + public void testNoItemCorrelation2() throws Exception { + DataModel dataModel = getDataModel( + new long[] {1, 2}, + new Double[][] { + {null, 1.0, null}, + {null, null, 1.0}, + }); + double correlation = new EuclideanDistanceSimilarity(dataModel).itemSimilarity(1, 2); + assertTrue(Double.isNaN(correlation)); + } + + @Test + public void testNoItemCorrelation3() throws Exception { + DataModel dataModel = getDataModel( + new long[] {1, 2, 3}, + new Double[][] { + {90.0, 70.0}, + {80.0, 80.0}, + {70.0, 90.0}, + }); + double correlation = + new EuclideanDistanceSimilarity(dataModel).itemSimilarity(0, 1); + assertCorrelationEquals(0.05770363219029305, correlation); + } + + @Test + public void testSimpleItem() throws Exception { + DataModel dataModel = getDataModel( + new long[] {1, 2, 3}, + new Double[][] { + {1.0, 2.0}, + {2.0, 5.0}, + {3.0, 6.0}, + }); + double correlation = + new EuclideanDistanceSimilarity(dataModel).itemSimilarity(0, 1); + assertCorrelationEquals(0.2843646522044218, correlation); + } + + @Test + public void testSimpleItemWeighted() throws Exception { + DataModel dataModel = getDataModel( + new long[] {1, 2, 3}, + new Double[][] { + {1.0, 2.0}, + {2.0, 5.0}, + {3.0, 6.0}, + }); + ItemSimilarity itemSimilarity = new EuclideanDistanceSimilarity(dataModel, Weighting.WEIGHTED); + double correlation = itemSimilarity.itemSimilarity(0, 1); + assertCorrelationEquals(0.8210911630511055, correlation); + } + + @Test + public void testRefresh() throws TasteException { + // Make 
sure this doesn't throw an exception + new EuclideanDistanceSimilarity(getDataModel()).refresh(null); + } + +} \ No newline at end of file
http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/test/java/org/apache/mahout/cf/taste/impl/similarity/GenericItemSimilarityTest.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/test/java/org/apache/mahout/cf/taste/impl/similarity/GenericItemSimilarityTest.java b/community/mahout-mr/src/test/java/org/apache/mahout/cf/taste/impl/similarity/GenericItemSimilarityTest.java new file mode 100644 index 0000000..5ce255c --- /dev/null +++ b/community/mahout-mr/src/test/java/org/apache/mahout/cf/taste/impl/similarity/GenericItemSimilarityTest.java @@ -0,0 +1,104 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.mahout.cf.taste.impl.similarity; + +import com.google.common.collect.Lists; +import org.apache.mahout.cf.taste.common.TasteException; +import org.apache.mahout.cf.taste.impl.common.FastIDSet; +import org.apache.mahout.cf.taste.model.DataModel; +import org.apache.mahout.cf.taste.similarity.ItemSimilarity; +import org.junit.Test; + +import java.util.Arrays; +import java.util.List; + +/** <p>Tests {@link GenericItemSimilarity}.</p> */ +public final class GenericItemSimilarityTest extends SimilarityTestCase { + + @Test + public void testSimple() { + List<GenericItemSimilarity.ItemItemSimilarity> similarities = Lists.newArrayList(); + similarities.add(new GenericItemSimilarity.ItemItemSimilarity(1, 2, 0.5)); + similarities.add(new GenericItemSimilarity.ItemItemSimilarity(2, 1, 0.6)); + similarities.add(new GenericItemSimilarity.ItemItemSimilarity(1, 1, 0.5)); + similarities.add(new GenericItemSimilarity.ItemItemSimilarity(1, 3, 0.3)); + GenericItemSimilarity itemCorrelation = new GenericItemSimilarity(similarities); + assertEquals(1.0, itemCorrelation.itemSimilarity(1, 1), EPSILON); + assertEquals(0.6, itemCorrelation.itemSimilarity(1, 2), EPSILON); + assertEquals(0.6, itemCorrelation.itemSimilarity(2, 1), EPSILON); + assertEquals(0.3, itemCorrelation.itemSimilarity(1, 3), EPSILON); + assertTrue(Double.isNaN(itemCorrelation.itemSimilarity(3, 4))); + } + + @Test + public void testFromCorrelation() throws Exception { + DataModel dataModel = getDataModel( + new long[] {1, 2, 3}, + new Double[][] { + {1.0, 2.0}, + {2.0, 5.0}, + {3.0, 6.0}, + }); + ItemSimilarity otherSimilarity = new PearsonCorrelationSimilarity(dataModel); + ItemSimilarity itemSimilarity = new GenericItemSimilarity(otherSimilarity, dataModel); + assertCorrelationEquals(1.0, itemSimilarity.itemSimilarity(0, 0)); + assertCorrelationEquals(0.960768922830523, itemSimilarity.itemSimilarity(0, 1)); + } + + @Test + public void testAllSimilaritiesWithoutIndex() throws TasteException { + + 
List<GenericItemSimilarity.ItemItemSimilarity> itemItemSimilarities = + Arrays.asList(new GenericItemSimilarity.ItemItemSimilarity(1L, 2L, 0.2), + new GenericItemSimilarity.ItemItemSimilarity(1L, 3L, 0.2), + new GenericItemSimilarity.ItemItemSimilarity(2L, 1L, 0.2), + new GenericItemSimilarity.ItemItemSimilarity(3L, 5L, 0.2), + new GenericItemSimilarity.ItemItemSimilarity(3L, 4L, 0.2)); + + ItemSimilarity similarity = new GenericItemSimilarity(itemItemSimilarities); + + assertTrue(containsExactly(similarity.allSimilarItemIDs(1L), 2L, 3L)); + assertTrue(containsExactly(similarity.allSimilarItemIDs(2L), 1L)); + assertTrue(containsExactly(similarity.allSimilarItemIDs(3L), 1L, 5L, 4L)); + assertTrue(containsExactly(similarity.allSimilarItemIDs(4L), 3L)); + assertTrue(containsExactly(similarity.allSimilarItemIDs(5L), 3L)); + } + + @Test + public void testAllSimilaritiesWithIndex() throws TasteException { + + List<GenericItemSimilarity.ItemItemSimilarity> itemItemSimilarities = + Arrays.asList(new GenericItemSimilarity.ItemItemSimilarity(1L, 2L, 0.2), + new GenericItemSimilarity.ItemItemSimilarity(1L, 3L, 0.2), + new GenericItemSimilarity.ItemItemSimilarity(2L, 1L, 0.2), + new GenericItemSimilarity.ItemItemSimilarity(3L, 5L, 0.2), + new GenericItemSimilarity.ItemItemSimilarity(3L, 4L, 0.2)); + + ItemSimilarity similarity = new GenericItemSimilarity(itemItemSimilarities); + + assertTrue(containsExactly(similarity.allSimilarItemIDs(1L), 2L, 3L)); + assertTrue(containsExactly(similarity.allSimilarItemIDs(2L), 1L)); + assertTrue(containsExactly(similarity.allSimilarItemIDs(3L), 1L, 5L, 4L)); + assertTrue(containsExactly(similarity.allSimilarItemIDs(4L), 3L)); + assertTrue(containsExactly(similarity.allSimilarItemIDs(5L), 3L)); + } + + private static boolean containsExactly(long[] allIDs, long... 
shouldContainID) { + return new FastIDSet(allIDs).intersectionSize(new FastIDSet(shouldContainID)) == shouldContainID.length; + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/test/java/org/apache/mahout/cf/taste/impl/similarity/LogLikelihoodSimilarityTest.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/test/java/org/apache/mahout/cf/taste/impl/similarity/LogLikelihoodSimilarityTest.java b/community/mahout-mr/src/test/java/org/apache/mahout/cf/taste/impl/similarity/LogLikelihoodSimilarityTest.java new file mode 100644 index 0000000..ae9df5c --- /dev/null +++ b/community/mahout-mr/src/test/java/org/apache/mahout/cf/taste/impl/similarity/LogLikelihoodSimilarityTest.java @@ -0,0 +1,80 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.cf.taste.impl.similarity; + +import org.apache.mahout.cf.taste.model.DataModel; +import org.junit.Test; + +/** <p>Tests {@link LogLikelihoodSimilarity}.</p> */ +public final class LogLikelihoodSimilarityTest extends SimilarityTestCase { + + @Test + public void testCorrelation() throws Exception { + DataModel dataModel = getDataModel( + new long[] {1, 2, 3, 4, 5}, + new Double[][] { + {1.0, 1.0}, + {1.0, null, 1.0}, + {null, null, 1.0, 1.0, 1.0}, + {1.0, 1.0, 1.0, 1.0, 1.0}, + {null, 1.0, 1.0, 1.0, 1.0}, + }); + + LogLikelihoodSimilarity similarity = new LogLikelihoodSimilarity(dataModel); + + assertCorrelationEquals(0.12160727029227925, similarity.itemSimilarity(1, 0)); + assertCorrelationEquals(0.12160727029227925, similarity.itemSimilarity(0, 1)); + + assertCorrelationEquals(0.5423213660693732, similarity.itemSimilarity(1, 2)); + assertCorrelationEquals(0.5423213660693732, similarity.itemSimilarity(2, 1)); + + assertCorrelationEquals(0.6905400104897509, similarity.itemSimilarity(2, 3)); + assertCorrelationEquals(0.6905400104897509, similarity.itemSimilarity(3, 2)); + + assertCorrelationEquals(0.8706358464330881, similarity.itemSimilarity(3, 4)); + assertCorrelationEquals(0.8706358464330881, similarity.itemSimilarity(4, 3)); + } + + @Test + public void testNoSimilarity() throws Exception { + + DataModel dataModel = getDataModel( + new long[] {1, 2, 3, 4}, + new Double[][] { + {1.0, null, 1.0, 1.0}, + {1.0, null, 1.0, 1.0}, + {null, 1.0, 1.0, 1.0}, + {null, 1.0, 1.0, 1.0}, + }); + + LogLikelihoodSimilarity similarity = new LogLikelihoodSimilarity(dataModel); + + assertCorrelationEquals(Double.NaN, similarity.itemSimilarity(1, 0)); + assertCorrelationEquals(Double.NaN, similarity.itemSimilarity(0, 1)); + + assertCorrelationEquals(0.0, similarity.itemSimilarity(2, 3)); + assertCorrelationEquals(0.0, similarity.itemSimilarity(3, 2)); + } + + @Test + public void testRefresh() { + // Make sure this doesn't throw an exception + new 
LogLikelihoodSimilarity(getDataModel()).refresh(null); + } + +} \ No newline at end of file http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/test/java/org/apache/mahout/cf/taste/impl/similarity/PearsonCorrelationSimilarityTest.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/test/java/org/apache/mahout/cf/taste/impl/similarity/PearsonCorrelationSimilarityTest.java b/community/mahout-mr/src/test/java/org/apache/mahout/cf/taste/impl/similarity/PearsonCorrelationSimilarityTest.java new file mode 100644 index 0000000..bb3ad3e --- /dev/null +++ b/community/mahout-mr/src/test/java/org/apache/mahout/cf/taste/impl/similarity/PearsonCorrelationSimilarityTest.java @@ -0,0 +1,265 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.cf.taste.impl.similarity; + +import java.util.Collection; + +import org.apache.mahout.cf.taste.common.Refreshable; +import org.apache.mahout.cf.taste.common.Weighting; +import org.apache.mahout.cf.taste.model.DataModel; +import org.apache.mahout.cf.taste.similarity.ItemSimilarity; +import org.apache.mahout.cf.taste.similarity.PreferenceInferrer; +import org.apache.mahout.cf.taste.similarity.UserSimilarity; +import org.junit.Test; + +/** <p>Tests {@link PearsonCorrelationSimilarity}.</p> */ +public final class PearsonCorrelationSimilarityTest extends SimilarityTestCase { + + @Test + public void testFullCorrelation1() throws Exception { + DataModel dataModel = getDataModel( + new long[] {1, 2}, + new Double[][] { + {3.0, -2.0}, + {3.0, -2.0}, + }); + double correlation = new PearsonCorrelationSimilarity(dataModel).userSimilarity(1, 2); + assertCorrelationEquals(1.0, correlation); + } + + @Test + public void testFullCorrelation1Weighted() throws Exception { + DataModel dataModel = getDataModel( + new long[] {1, 2}, + new Double[][] { + {3.0, -2.0}, + {3.0, -2.0}, + }); + double correlation = new PearsonCorrelationSimilarity(dataModel, Weighting.WEIGHTED).userSimilarity(1, 2); + assertCorrelationEquals(1.0, correlation); + } + + @Test + public void testFullCorrelation2() throws Exception { + DataModel dataModel = getDataModel( + new long[] {1, 2}, + new Double[][] { + {3.0, 3.0}, + {3.0, 3.0}, + }); + double correlation = new PearsonCorrelationSimilarity(dataModel).userSimilarity(1, 2); + // Yeah, undefined in this case + assertTrue(Double.isNaN(correlation)); + } + + @Test + public void testNoCorrelation1() throws Exception { + DataModel dataModel = getDataModel( + new long[] {1, 2}, + new Double[][] { + {3.0, -2.0}, + {-3.0, 2.0}, + }); + double correlation = new PearsonCorrelationSimilarity(dataModel).userSimilarity(1, 2); + assertCorrelationEquals(-1.0, correlation); + } + + @Test + public void testNoCorrelation1Weighted() throws 
Exception { + DataModel dataModel = getDataModel( + new long[] {1, 2}, + new Double[][] { + {3.0, -2.0}, + {-3.0, 2.0}, + }); + double correlation = new PearsonCorrelationSimilarity(dataModel, Weighting.WEIGHTED).userSimilarity(1, 2); + assertCorrelationEquals(-1.0, correlation); + } + + @Test + public void testNoCorrelation2() throws Exception { + DataModel dataModel = getDataModel( + new long[] {1, 2}, + new Double[][] { + {null, 1.0, null}, + {null, null, 1.0}, + }); + double correlation = new PearsonCorrelationSimilarity(dataModel).userSimilarity(1, 2); + assertTrue(Double.isNaN(correlation)); + } + + @Test + public void testNoCorrelation3() throws Exception { + DataModel dataModel = getDataModel( + new long[] {1, 2}, + new Double[][] { + {90.0, 80.0, 70.0}, + {70.0, 80.0, 90.0}, + }); + double correlation = new PearsonCorrelationSimilarity(dataModel).userSimilarity(1, 2); + assertCorrelationEquals(-1.0, correlation); + } + + @Test + public void testSimple() throws Exception { + DataModel dataModel = getDataModel( + new long[] {1, 2}, + new Double[][] { + {1.0, 2.0, 3.0}, + {2.0, 5.0, 6.0}, + }); + double correlation = new PearsonCorrelationSimilarity(dataModel).userSimilarity(1, 2); + assertCorrelationEquals(0.9607689228305227, correlation); + } + + @Test + public void testSimpleWeighted() throws Exception { + DataModel dataModel = getDataModel( + new long[] {1, 2}, + new Double[][] { + {1.0, 2.0, 3.0}, + {2.0, 5.0, 6.0}, + }); + double correlation = new PearsonCorrelationSimilarity(dataModel, Weighting.WEIGHTED).userSimilarity(1, 2); + assertCorrelationEquals(0.9901922307076306, correlation); + } + + @Test + public void testFullItemCorrelation1() throws Exception { + DataModel dataModel = getDataModel( + new long[] {1, 2}, + new Double[][] { + {3.0, 3.0}, + {-2.0, -2.0}, + }); + double correlation = + new PearsonCorrelationSimilarity(dataModel).itemSimilarity(0, 1); + assertCorrelationEquals(1.0, correlation); + } + + @Test + public void 
testFullItemCorrelation2() throws Exception { + DataModel dataModel = getDataModel( + new long[] {1, 2}, + new Double[][] { + {3.0, 3.0}, + {3.0, 3.0}, + }); + double correlation = + new PearsonCorrelationSimilarity(dataModel).itemSimilarity(0, 1); + // Yeah, undefined in this case + assertTrue(Double.isNaN(correlation)); + } + + @Test + public void testNoItemCorrelation1() throws Exception { + DataModel dataModel = getDataModel( + new long[] {1, 2}, + new Double[][] { + {3.0, -3.0}, + {2.0, -2.0}, + }); + double correlation = + new PearsonCorrelationSimilarity(dataModel).itemSimilarity(0, 1); + assertCorrelationEquals(-1.0, correlation); + } + + @Test + public void testNoItemCorrelation2() throws Exception { + DataModel dataModel = getDataModel( + new long[] {1, 2}, + new Double[][] { + {null, 1.0, null}, + {null, null, 1.0}, + }); + double correlation = + new PearsonCorrelationSimilarity(dataModel).itemSimilarity(1, 2); + assertTrue(Double.isNaN(correlation)); + } + + @Test + public void testNoItemCorrelation3() throws Exception { + DataModel dataModel = getDataModel( + new long[] {1, 2, 3}, + new Double[][] { + {90.0, 70.0}, + {80.0, 80.0}, + {70.0, 90.0}, + }); + double correlation = + new PearsonCorrelationSimilarity(dataModel).itemSimilarity(0, 1); + assertCorrelationEquals(-1.0, correlation); + } + + @Test + public void testSimpleItem() throws Exception { + DataModel dataModel = getDataModel( + new long[] {1, 2, 3}, + new Double[][] { + {1.0, 2.0}, + {2.0, 5.0}, + {3.0, 6.0}, + }); + double correlation = + new PearsonCorrelationSimilarity(dataModel).itemSimilarity(0, 1); + assertCorrelationEquals(0.9607689228305227, correlation); + } + + @Test + public void testSimpleItemWeighted() throws Exception { + DataModel dataModel = getDataModel( + new long[] {1, 2, 3}, + new Double[][] { + {1.0, 2.0}, + {2.0, 5.0}, + {3.0, 6.0}, + }); + ItemSimilarity itemSimilarity = new PearsonCorrelationSimilarity(dataModel, Weighting.WEIGHTED); + double correlation = 
itemSimilarity.itemSimilarity(0, 1); + assertCorrelationEquals(0.9901922307076306, correlation); + } + + @Test + public void testRefresh() throws Exception { + // Make sure this doesn't throw an exception + new PearsonCorrelationSimilarity(getDataModel()).refresh(null); + } + + @Test + public void testInferrer() throws Exception { + DataModel dataModel = getDataModel( + new long[] {1, 2}, + new Double[][] { + {null, 1.0, 2.0, null, null, 6.0}, + {1.0, 8.0, null, 3.0, 4.0, null}, + }); + UserSimilarity similarity = new PearsonCorrelationSimilarity(dataModel); + similarity.setPreferenceInferrer(new PreferenceInferrer() { + @Override + public float inferPreference(long userID, long itemID) { + return 1.0f; + } + @Override + public void refresh(Collection<Refreshable> alreadyRefreshed) { + } + }); + + assertEquals(-0.435285750066007, similarity.userSimilarity(1L, 2L), EPSILON); + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/test/java/org/apache/mahout/cf/taste/impl/similarity/SimilarityTestCase.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/test/java/org/apache/mahout/cf/taste/impl/similarity/SimilarityTestCase.java b/community/mahout-mr/src/test/java/org/apache/mahout/cf/taste/impl/similarity/SimilarityTestCase.java new file mode 100644 index 0000000..ad1e4b7 --- /dev/null +++ b/community/mahout-mr/src/test/java/org/apache/mahout/cf/taste/impl/similarity/SimilarityTestCase.java @@ -0,0 +1,35 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.cf.taste.impl.similarity; + +import org.apache.mahout.cf.taste.impl.TasteTestCase; + +abstract class SimilarityTestCase extends TasteTestCase { + + static void assertCorrelationEquals(double expected, double actual) { + if (Double.isNaN(expected)) { + assertTrue("Correlation is not NaN", Double.isNaN(actual)); + } else { + assertTrue("Correlation is NaN", !Double.isNaN(actual)); + assertTrue("Correlation > 1.0", actual <= 1.0); + assertTrue("Correlation < -1.0", actual >= -1.0); + assertEquals(expected, actual, EPSILON); + } + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/test/java/org/apache/mahout/cf/taste/impl/similarity/SpearmanCorrelationSimilarityTest.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/test/java/org/apache/mahout/cf/taste/impl/similarity/SpearmanCorrelationSimilarityTest.java b/community/mahout-mr/src/test/java/org/apache/mahout/cf/taste/impl/similarity/SpearmanCorrelationSimilarityTest.java new file mode 100644 index 0000000..6034f4b --- /dev/null +++ b/community/mahout-mr/src/test/java/org/apache/mahout/cf/taste/impl/similarity/SpearmanCorrelationSimilarityTest.java @@ -0,0 +1,80 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.cf.taste.impl.similarity; + +import org.apache.mahout.cf.taste.model.DataModel; +import org.junit.Test; + +/** <p>Tests {@link SpearmanCorrelationSimilarity}.</p> */ +public final class SpearmanCorrelationSimilarityTest extends SimilarityTestCase { + + @Test + public void testFullCorrelation1() throws Exception { + DataModel dataModel = getDataModel( + new long[] {1, 2}, + new Double[][] { + {1.0, 2.0, 3.0}, + {1.0, 2.0, 3.0}, + }); + double correlation = new SpearmanCorrelationSimilarity(dataModel).userSimilarity(1, 2); + assertCorrelationEquals(1.0, correlation); + } + + @Test + public void testFullCorrelation2() throws Exception { + DataModel dataModel = getDataModel( + new long[] {1, 2}, + new Double[][] { + {1.0, 2.0, 3.0}, + {4.0, 5.0, 6.0}, + }); + double correlation = new SpearmanCorrelationSimilarity(dataModel).userSimilarity(1, 2); + assertCorrelationEquals(1.0, correlation); + } + + @Test + public void testAnticorrelation() throws Exception { + DataModel dataModel = getDataModel( + new long[] {1, 2}, + new Double[][] { + {1.0, 2.0, 3.0}, + {3.0, 2.0, 1.0}, + }); + double correlation = new SpearmanCorrelationSimilarity(dataModel).userSimilarity(1, 2); + assertCorrelationEquals(-1.0, correlation); + } + + @Test + public void testSimple() throws Exception { + DataModel dataModel = getDataModel( + new long[] {1, 2}, + new Double[][] { + {1.0, 2.0, 
3.0}, + {2.0, 3.0, 1.0}, + }); + double correlation = new SpearmanCorrelationSimilarity(dataModel).userSimilarity(1, 2); + assertCorrelationEquals(-0.5, correlation); + } + + @Test + public void testRefresh() { + // Make sure this doesn't throw an exception + new SpearmanCorrelationSimilarity(getDataModel()).refresh(null); + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/test/java/org/apache/mahout/cf/taste/impl/similarity/TanimotoCoefficientSimilarityTest.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/test/java/org/apache/mahout/cf/taste/impl/similarity/TanimotoCoefficientSimilarityTest.java b/community/mahout-mr/src/test/java/org/apache/mahout/cf/taste/impl/similarity/TanimotoCoefficientSimilarityTest.java new file mode 100644 index 0000000..87f82b9 --- /dev/null +++ b/community/mahout-mr/src/test/java/org/apache/mahout/cf/taste/impl/similarity/TanimotoCoefficientSimilarityTest.java @@ -0,0 +1,121 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.cf.taste.impl.similarity; + +import org.apache.mahout.cf.taste.model.DataModel; +import org.junit.Test; + +/** <p>Tests {@link TanimotoCoefficientSimilarity}.</p> */ +public final class TanimotoCoefficientSimilarityTest extends SimilarityTestCase { + + @Test + public void testNoCorrelation() throws Exception { + DataModel dataModel = getDataModel( + new long[] {1, 2}, + new Double[][] { + {null, 2.0, 3.0}, + {1.0}, + }); + double correlation = new TanimotoCoefficientSimilarity(dataModel).userSimilarity(1, 2); + assertCorrelationEquals(Double.NaN, correlation); + } + + @Test + public void testFullCorrelation1() throws Exception { + DataModel dataModel = getDataModel( + new long[] {1, 2}, + new Double[][] { + {1.0}, + {1.0}, + }); + double correlation = new TanimotoCoefficientSimilarity(dataModel).userSimilarity(1, 2); + assertCorrelationEquals(1.0, correlation); + } + + @Test + public void testFullCorrelation2() throws Exception { + DataModel dataModel = getDataModel( + new long[] {1, 2}, + new Double[][] { + {1.0, 2.0, 3.0}, + {1.0}, + }); + double correlation = new TanimotoCoefficientSimilarity(dataModel).userSimilarity(1, 2); + assertCorrelationEquals(0.3333333333333333, correlation); + } + + @Test + public void testCorrelation1() throws Exception { + DataModel dataModel = getDataModel( + new long[] {1, 2}, + new Double[][] { + {null, 2.0, 3.0}, + {1.0, 1.0}, + }); + double correlation = new TanimotoCoefficientSimilarity(dataModel).userSimilarity(1, 2); + assertEquals(0.3333333333333333, correlation, EPSILON); + } + + @Test + public void testCorrelation2() throws Exception { + DataModel dataModel = getDataModel( + new long[] {1, 2}, + new Double[][] { + {null, 2.0, 3.0, 1.0}, + {1.0, 1.0, null, 0.0}, + }); + double correlation = new TanimotoCoefficientSimilarity(dataModel).userSimilarity(1, 2); + assertEquals(0.5, correlation, EPSILON); + } + + @Test + public void testRefresh() { + // Make sure this doesn't throw an exception + 
new TanimotoCoefficientSimilarity(getDataModel()).refresh(null); + } + + @Test + public void testReturnNaNDoubleWhenNoSimilaritiesForTwoItems() throws Exception { + DataModel dataModel = getDataModel( + new long[] {1, 2}, + new Double[][] { + {null, null, 3.0}, + {1.0, 1.0, null}, + }); + Double similarity = new TanimotoCoefficientSimilarity(dataModel).itemSimilarity(1, 2); + assertEquals(Double.NaN, similarity, EPSILON); + } + + @Test + public void testItemsSimilarities() throws Exception { + DataModel dataModel = getDataModel( + new long[] {1, 2}, + new Double[][] { + {2.0, null, 2.0}, + {1.0, 1.0, 1.0}, + }); + TanimotoCoefficientSimilarity tCS = new TanimotoCoefficientSimilarity(dataModel); + assertEquals(0.5, tCS.itemSimilarity(0, 1), EPSILON); + assertEquals(1, tCS.itemSimilarity(0, 2), EPSILON); + + double[] similarities = tCS.itemSimilarities(0, new long [] {1, 2}); + assertEquals(0.5, similarities[0], EPSILON); + assertEquals(1, similarities[1], EPSILON); + } + +} \ No newline at end of file http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/test/java/org/apache/mahout/cf/taste/impl/similarity/file/FileItemSimilarityTest.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/test/java/org/apache/mahout/cf/taste/impl/similarity/file/FileItemSimilarityTest.java b/community/mahout-mr/src/test/java/org/apache/mahout/cf/taste/impl/similarity/file/FileItemSimilarityTest.java new file mode 100644 index 0000000..d9d28ab --- /dev/null +++ b/community/mahout-mr/src/test/java/org/apache/mahout/cf/taste/impl/similarity/file/FileItemSimilarityTest.java @@ -0,0 +1,142 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.cf.taste.impl.similarity.file; + +import java.io.File; + +import org.apache.mahout.cf.taste.impl.TasteTestCase; +import org.apache.mahout.cf.taste.impl.similarity.GenericItemSimilarity; +import org.apache.mahout.cf.taste.impl.similarity.GenericItemSimilarity.ItemItemSimilarity; +import org.apache.mahout.cf.taste.similarity.ItemSimilarity; +import org.junit.Before; +import org.junit.Test; + +/** <p>Tests {@link FileItemSimilarity}.</p> */ +public final class FileItemSimilarityTest extends TasteTestCase { + + private static final String[] data = { + "1,5,0.125", + "1,7,0.5" }; + + private static final String[] changedData = { + "1,5,0.125", + "1,7,0.9", + "7,8,0.112" }; + + private File testFile; + + @Override + @Before + public void setUp() throws Exception { + super.setUp(); + testFile = getTestTempFile("test.txt"); + writeLines(testFile, data); + } + + @Test + public void testLoadFromFile() throws Exception { + ItemSimilarity similarity = new FileItemSimilarity(testFile); + + assertEquals(0.125, similarity.itemSimilarity(1L, 5L), EPSILON); + assertEquals(0.125, similarity.itemSimilarity(5L, 1L), EPSILON); + assertEquals(0.5, similarity.itemSimilarity(1L, 7L), EPSILON); + assertEquals(0.5, similarity.itemSimilarity(7L, 1L), EPSILON); + + assertTrue(Double.isNaN(similarity.itemSimilarity(7L, 8L))); + + double[] valuesForOne = similarity.itemSimilarities(1L, 
new long[] { 5L, 7L }); + assertNotNull(valuesForOne); + assertEquals(2, valuesForOne.length); + assertEquals(0.125, valuesForOne[0], EPSILON); + assertEquals(0.5, valuesForOne[1], EPSILON); + } + + @Test + public void testNoRefreshAfterFileUpdate() throws Exception { + ItemSimilarity similarity = new FileItemSimilarity(testFile, 0L); + + /* call a method to make sure the original file is loaded*/ + similarity.itemSimilarity(1L, 5L); + + /* change the underlying file, + * we have to wait at least a second to see the change in the file's lastModified timestamp */ + Thread.sleep(2000L); + writeLines(testFile, changedData); + + /* we shouldn't see any changes in the data as we have not yet refreshed */ + assertEquals(0.5, similarity.itemSimilarity(1L, 7L), EPSILON); + assertEquals(0.5, similarity.itemSimilarity(7L, 1L), EPSILON); + assertTrue(Double.isNaN(similarity.itemSimilarity(7L, 8L))); + } + + @Test + public void testRefreshAfterFileUpdate() throws Exception { + ItemSimilarity similarity = new FileItemSimilarity(testFile, 0L); + + /* call a method to make sure the original file is loaded */ + similarity.itemSimilarity(1L, 5L); + + /* change the underlying file, + * we have to wait at least a second to see the change in the file's lastModified timestamp */ + Thread.sleep(2000L); + writeLines(testFile, changedData); + + similarity.refresh(null); + + /* we should now see the changes in the data */ + assertEquals(0.9, similarity.itemSimilarity(1L, 7L), EPSILON); + assertEquals(0.9, similarity.itemSimilarity(7L, 1L), EPSILON); + assertEquals(0.125, similarity.itemSimilarity(1L, 5L), EPSILON); + assertEquals(0.125, similarity.itemSimilarity(5L, 1L), EPSILON); + + assertFalse(Double.isNaN(similarity.itemSimilarity(7L, 8L))); + assertEquals(0.112, similarity.itemSimilarity(7L, 8L), EPSILON); + assertEquals(0.112, similarity.itemSimilarity(8L, 7L), EPSILON); + } + + @Test(expected = IllegalArgumentException.class) + public void 
testFileNotFoundExceptionForNonExistingFile() throws Exception { + new FileItemSimilarity(new File("xKsdfksdfsdf")); + } + + @Test + public void testFileItemItemSimilarityIterable() throws Exception { + Iterable<ItemItemSimilarity> similarityIterable = new FileItemItemSimilarityIterable(testFile); + GenericItemSimilarity similarity = new GenericItemSimilarity(similarityIterable); + + assertEquals(0.125, similarity.itemSimilarity(1L, 5L), EPSILON); + assertEquals(0.125, similarity.itemSimilarity(5L, 1L), EPSILON); + assertEquals(0.5, similarity.itemSimilarity(1L, 7L), EPSILON); + assertEquals(0.5, similarity.itemSimilarity(7L, 1L), EPSILON); + + assertTrue(Double.isNaN(similarity.itemSimilarity(7L, 8L))); + + double[] valuesForOne = similarity.itemSimilarities(1L, new long[] { 5L, 7L }); + assertNotNull(valuesForOne); + assertEquals(2, valuesForOne.length); + assertEquals(0.125, valuesForOne[0], EPSILON); + assertEquals(0.5, valuesForOne[1], EPSILON); + } + + @Test + public void testToString() throws Exception { + ItemSimilarity similarity = new FileItemSimilarity(testFile); + assertFalse(similarity.toString().isEmpty()); // toString() must produce a non-empty description + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/test/java/org/apache/mahout/cf/taste/impl/similarity/precompute/MultithreadedBatchItemSimilaritiesTest.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/test/java/org/apache/mahout/cf/taste/impl/similarity/precompute/MultithreadedBatchItemSimilaritiesTest.java b/community/mahout-mr/src/test/java/org/apache/mahout/cf/taste/impl/similarity/precompute/MultithreadedBatchItemSimilaritiesTest.java new file mode 100644 index 0000000..868e41a --- /dev/null +++ b/community/mahout-mr/src/test/java/org/apache/mahout/cf/taste/impl/similarity/precompute/MultithreadedBatchItemSimilaritiesTest.java @@ -0,0 +1,98 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor 
license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * <p/> + * http://www.apache.org/licenses/LICENSE-2.0 + * <p/> + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.cf.taste.impl.similarity.precompute; + +import java.io.IOException; +import org.apache.mahout.cf.taste.impl.common.FastByIDMap; +import org.apache.mahout.cf.taste.impl.model.GenericDataModel; +import org.apache.mahout.cf.taste.impl.model.GenericPreference; +import org.apache.mahout.cf.taste.impl.model.GenericUserPreferenceArray; +import org.apache.mahout.cf.taste.impl.recommender.GenericItemBasedRecommender; +import org.apache.mahout.cf.taste.impl.similarity.TanimotoCoefficientSimilarity; +import org.apache.mahout.cf.taste.model.DataModel; +import org.apache.mahout.cf.taste.model.PreferenceArray; +import org.apache.mahout.cf.taste.recommender.ItemBasedRecommender; +import org.apache.mahout.cf.taste.similarity.precompute.BatchItemSimilarities; +import org.apache.mahout.cf.taste.similarity.precompute.SimilarItemsWriter; +import org.junit.Test; + +import java.util.Arrays; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.fail; +import static org.mockito.Mockito.mock; + +public class MultithreadedBatchItemSimilaritiesTest { + + @Test + public void lessItemsThanBatchSize() throws Exception { + + FastByIDMap<PreferenceArray> userData = new FastByIDMap<>(); + userData.put(1, new 
GenericUserPreferenceArray(Arrays.asList(new GenericPreference(1, 1, 1), + new GenericPreference(1, 2, 1), new GenericPreference(1, 3, 1)))); + userData.put(2, new GenericUserPreferenceArray(Arrays.asList(new GenericPreference(2, 1, 1), + new GenericPreference(2, 2, 1), new GenericPreference(2, 4, 1)))); + + DataModel dataModel = new GenericDataModel(userData); + ItemBasedRecommender recommender = + new GenericItemBasedRecommender(dataModel, new TanimotoCoefficientSimilarity(dataModel)); + + BatchItemSimilarities batchSimilarities = new MultithreadedBatchItemSimilarities(recommender, 10); + + batchSimilarities.computeItemSimilarities(1, 1, mock(SimilarItemsWriter.class)); + } + + @Test(expected = IOException.class) + public void higherDegreeOfParallelismThanBatches() throws Exception { + + FastByIDMap<PreferenceArray> userData = new FastByIDMap<>(); + userData.put(1, new GenericUserPreferenceArray(Arrays.asList(new GenericPreference(1, 1, 1), + new GenericPreference(1, 2, 1), new GenericPreference(1, 3, 1)))); + userData.put(2, new GenericUserPreferenceArray(Arrays.asList(new GenericPreference(2, 1, 1), + new GenericPreference(2, 2, 1), new GenericPreference(2, 4, 1)))); + + DataModel dataModel = new GenericDataModel(userData); + ItemBasedRecommender recommender = + new GenericItemBasedRecommender(dataModel, new TanimotoCoefficientSimilarity(dataModel)); + + BatchItemSimilarities batchSimilarities = new MultithreadedBatchItemSimilarities(recommender, 10); + + // Batch size is 100, so we only get 1 batch from 3 items, but we use a degreeOfParallelism of 2 + batchSimilarities.computeItemSimilarities(2, 1, mock(SimilarItemsWriter.class)); + fail(); + } + + @Test + public void testCorrectNumberOfOutputSimilarities() throws Exception { + FastByIDMap<PreferenceArray> userData = new FastByIDMap<>(); + userData.put(1, new GenericUserPreferenceArray(Arrays.asList(new GenericPreference(1, 1, 1), + new GenericPreference(1, 2, 1), new GenericPreference(1, 3, 1)))); + 
userData.put(2, new GenericUserPreferenceArray(Arrays.asList(new GenericPreference(2, 1, 1), + new GenericPreference(2, 2, 1), new GenericPreference(2, 4, 1)))); + + DataModel dataModel = new GenericDataModel(userData); + ItemBasedRecommender recommender = + new GenericItemBasedRecommender(dataModel, new TanimotoCoefficientSimilarity(dataModel)); + + BatchItemSimilarities batchSimilarities = new MultithreadedBatchItemSimilarities(recommender, 10, 2); + + int numOutputSimilarities = batchSimilarities.computeItemSimilarities(2, 1, mock(SimilarItemsWriter.class)); + assertEquals(numOutputSimilarities, 10); + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/test/java/org/apache/mahout/cf/taste/similarity/precompute/SimilarItemsTest.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/test/java/org/apache/mahout/cf/taste/similarity/precompute/SimilarItemsTest.java b/community/mahout-mr/src/test/java/org/apache/mahout/cf/taste/similarity/precompute/SimilarItemsTest.java new file mode 100644 index 0000000..afce3cf --- /dev/null +++ b/community/mahout-mr/src/test/java/org/apache/mahout/cf/taste/similarity/precompute/SimilarItemsTest.java @@ -0,0 +1,50 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.cf.taste.similarity.precompute; + +import java.util.ArrayList; +import java.util.List; + +import org.apache.mahout.cf.taste.impl.TasteTestCase; +import org.apache.mahout.cf.taste.impl.recommender.GenericRecommendedItem; +import org.apache.mahout.cf.taste.recommender.RecommendedItem; +import org.hamcrest.Matchers; +import org.junit.Test; + +public class SimilarItemsTest extends TasteTestCase { + + @Test + public void testIterator() { + List<RecommendedItem> recommendedItems = new ArrayList<>(); + for (long itemId = 2; itemId < 10; itemId++) { + recommendedItems.add(new GenericRecommendedItem(itemId, itemId)); + } + + SimilarItems similarItems = new SimilarItems(1, recommendedItems); + + assertThat(similarItems.getSimilarItems(), Matchers.<SimilarItem> iterableWithSize(recommendedItems.size())); + + int byHandIndex = 0; + for (SimilarItem simItem : similarItems.getSimilarItems()) { + RecommendedItem recItem = recommendedItems.get(byHandIndex++); + assertEquals(simItem.getItemID(), recItem.getItemID()); + assertEquals(simItem.getSimilarity(), recItem.getValue(), EPSILON); + } + + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/test/java/org/apache/mahout/classifier/ClassifierData.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/test/java/org/apache/mahout/classifier/ClassifierData.java b/community/mahout-mr/src/test/java/org/apache/mahout/classifier/ClassifierData.java new file mode 100644 index 0000000..f037209 --- /dev/null +++ b/community/mahout-mr/src/test/java/org/apache/mahout/classifier/ClassifierData.java @@ -0,0 +1,102 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.classifier; + +/** + * Class containing sample docs from ASF websites under mahout, lucene and spamassasin projects + * + */ +public final class ClassifierData { + + public static final String[][] DATA = { + { + "mahout", + "Mahout's goal is to build scalable machine learning libraries. With scalable we mean: " + + "Scalable to reasonably large data sets. Our core algorithms for clustering," + + " classfication and batch based collaborative filtering are implemented on top " + + "of Apache Hadoop using the map/reduce paradigm. However we do not restrict " + + "contributions to Hadoop based implementations: Contributions that run on"}, + { + "mahout", + " a single node or on a non-Hadoop cluster are welcome as well. The core" + + " libraries are highly optimized to allow for good performance also for" + + " non-distributed algorithms. Scalable to support your business case. " + + "Mahout is distributed under a commercially friendly Apache Software license. " + + "Scalable community. The goal of Mahout is to build a vibrant, responsive, "}, + { + "mahout", + "diverse community to facilitate discussions not only on the project itself" + + " but also on potential use cases. Come to the mailing lists to find out more." 
+ + " Currently Mahout supports mainly four use cases: Recommendation mining takes " + + "users' behavior and from that tries to find items users might like. Clustering "}, + { + "mahout", + "takes e.g. text documents and groups them into groups of topically related documents." + + " Classification learns from exisiting categorized documents what documents of" + + " a specific category look like and is able to assign unlabelled documents to " + + "the (hopefully) correct category. Frequent itemset mining takes a set of item" + + " groups (terms in a query session, shopping cart content) and identifies, which" + + " individual items usually appear together."}, + { + "lucene", + "Apache Lucene is a high-performance, full-featured text search engine library" + + " written entirely in Java. It is a technology suitable for nearly any application " + + "that requires full-text search, especially cross-platform. Apache Lucene is an open source" + + " project available for free download. Please use the links on the left to access Lucene. " + + "The new version is mostly a cleanup release without any new features. "}, + { + "lucene", + "All deprecations targeted to be removed in version 3.0 were removed. If you " + + "are upgrading from version 2.9.1 of Lucene, you have to fix all deprecation warnings" + + " in your code base to be able to recompile against this version. This is the first Lucene"}, + { + "lucene", + " release with Java 5 as a minimum requirement. The API was cleaned up to make use of Java 5's " + + "generics, varargs, enums, and autoboxing. New users of Lucene are advised to use this version " + + "for new developments, because it has a clean, type safe new API. Upgrading users can now remove"}, + { + "lucene", + " unnecessary casts and add generics to their code, too. 
If you have not upgraded your installation " + + "to Java 5, please read the file JRE_VERSION_MIGRATION.txt (please note that this is not related to" + + " Lucene 3.0, it will also happen with any previous release when you upgrade your Java environment)."}, + { + "spamassasin", + "SpamAssassin is a mail filter to identify spam. It is an intelligent email filter which uses a diverse " + + "range of tests to identify unsolicited bulk email, more commonly known as Spam. These tests are applied " + + "to email headers and content to classify email using advanced statistical methods. In addition, "}, + { + "spamassasin", + "SpamAssassin has a modular architecture that allows other technologies to be quickly wielded against spam" + + " and is designed for easy integration into virtually any email system." + + "SpamAssassin's practical multi-technique approach, modularity, and extensibility continue to give it an "}, + { + "spamassasin", + "advantage over other anti-spam systems. Due to these advantages, SpamAssassin is widely used in all aspects " + + "of email management. You can readily find SpamAssassin in use in both email clients and servers, on many " + + "different operating systems, filtering incoming as well as outgoing email, and implementing a " + + "very broad range "}, + { + "spamassasin", + "of policy actions. These installations include service providers, businesses, not-for-profit and " + + "educational organizations, and end-user systems. 
SpamAssassin also forms the basis for numerous " + + "commercial anti-spam products available on the market today."}}; + + + private ClassifierData() { } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/test/java/org/apache/mahout/classifier/ConfusionMatrixTest.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/test/java/org/apache/mahout/classifier/ConfusionMatrixTest.java b/community/mahout-mr/src/test/java/org/apache/mahout/classifier/ConfusionMatrixTest.java new file mode 100644 index 0000000..3ffff85 --- /dev/null +++ b/community/mahout-mr/src/test/java/org/apache/mahout/classifier/ConfusionMatrixTest.java @@ -0,0 +1,119 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.classifier; + +import java.util.Arrays; +import java.util.Collection; +import java.util.Map; + +import com.google.common.collect.Lists; +import org.apache.mahout.common.MahoutTestCase; +import org.apache.mahout.math.Matrix; +import org.junit.Test; + +public final class ConfusionMatrixTest extends MahoutTestCase { + + private static final int[][] VALUES = {{2, 3}, {10, 20}}; + private static final String[] LABELS = {"Label1", "Label2"}; + private static final int[] OTHER = {3, 6}; + private static final String DEFAULT_LABEL = "other"; + + @Test + public void testBuild() { + ConfusionMatrix confusionMatrix = fillConfusionMatrix(VALUES, LABELS, DEFAULT_LABEL); + checkValues(confusionMatrix); + checkAccuracy(confusionMatrix); + } + + @Test + public void testGetMatrix() { + ConfusionMatrix confusionMatrix = fillConfusionMatrix(VALUES, LABELS, DEFAULT_LABEL); + Matrix m = confusionMatrix.getMatrix(); + Map<String, Integer> rowLabels = m.getRowLabelBindings(); + + assertEquals(confusionMatrix.getLabels().size(), m.numCols()); + assertTrue(rowLabels.keySet().contains(LABELS[0])); + assertTrue(rowLabels.keySet().contains(LABELS[1])); + assertTrue(rowLabels.keySet().contains(DEFAULT_LABEL)); + assertEquals(2, confusionMatrix.getCorrect(LABELS[0])); + assertEquals(20, confusionMatrix.getCorrect(LABELS[1])); + assertEquals(0, confusionMatrix.getCorrect(DEFAULT_LABEL)); + } + + /** + * Example taken from + * http://scikit-learn.org/stable/modules/generated/sklearn.metrics.precision_recall_fscore_support.html + */ + @Test + public void testPrecisionRecallAndF1ScoreAsScikitLearn() { + Collection<String> labelList = Arrays.asList("0", "1", "2"); + + ConfusionMatrix confusionMatrix = new ConfusionMatrix(labelList, "DEFAULT"); + confusionMatrix.putCount("0", "0", 2); + confusionMatrix.putCount("1", "0", 1); + confusionMatrix.putCount("1", "2", 1); + confusionMatrix.putCount("2", "1", 2); + + double delta = 0.001; + assertEquals(0.222, 
confusionMatrix.getWeightedPrecision(), delta); + assertEquals(0.333, confusionMatrix.getWeightedRecall(), delta); + assertEquals(0.266, confusionMatrix.getWeightedF1score(), delta); + } + + private static void checkValues(ConfusionMatrix cm) { + int[][] counts = cm.getConfusionMatrix(); + cm.toString(); // smoke check only: toString() must not throw + assertEquals(counts.length, counts[0].length); + assertEquals(3, counts.length); + assertEquals(VALUES[0][0], counts[0][0]); + assertEquals(VALUES[0][1], counts[0][1]); + assertEquals(VALUES[1][0], counts[1][0]); + assertEquals(VALUES[1][1], counts[1][1]); + assertArrayEquals(new int[3], counts[2]); // zeros + assertEquals(OTHER[0], counts[0][2]); + assertEquals(OTHER[1], counts[1][2]); + assertEquals(3, cm.getLabels().size()); + assertTrue(cm.getLabels().contains(LABELS[0])); + assertTrue(cm.getLabels().contains(LABELS[1])); + assertTrue(cm.getLabels().contains(DEFAULT_LABEL)); + } + + private static void checkAccuracy(ConfusionMatrix cm) { + Collection<String> labelstrs = cm.getLabels(); + assertEquals(3, labelstrs.size()); + assertEquals(25.0, cm.getAccuracy(LABELS[0]), EPSILON); + assertEquals(55.5555555, cm.getAccuracy(LABELS[1]), EPSILON); + assertTrue(Double.isNaN(cm.getAccuracy(DEFAULT_LABEL))); + } + + private static ConfusionMatrix fillConfusionMatrix(int[][] values, String[] labels, String defaultLabel) { + Collection<String> labelList = Lists.newArrayList(); + labelList.add(labels[0]); + labelList.add(labels[1]); + ConfusionMatrix confusionMatrix = new ConfusionMatrix(labelList, defaultLabel); + + confusionMatrix.putCount(labels[0], labels[0], values[0][0]); // use the supplied labels rather than hard-coded names + confusionMatrix.putCount(labels[0], labels[1], values[0][1]); + confusionMatrix.putCount(labels[1], labels[0], values[1][0]); + confusionMatrix.putCount(labels[1], labels[1], values[1][1]); + confusionMatrix.putCount(labels[0], defaultLabel, OTHER[0]); + confusionMatrix.putCount(labels[1], defaultLabel, OTHER[1]); + return confusionMatrix; + } + +} 
http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/test/java/org/apache/mahout/classifier/RegressionResultAnalyzerTest.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/test/java/org/apache/mahout/classifier/RegressionResultAnalyzerTest.java b/community/mahout-mr/src/test/java/org/apache/mahout/classifier/RegressionResultAnalyzerTest.java new file mode 100644 index 0000000..86234f8 --- /dev/null +++ b/community/mahout-mr/src/test/java/org/apache/mahout/classifier/RegressionResultAnalyzerTest.java @@ -0,0 +1,128 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.classifier; + +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +import org.apache.mahout.common.MahoutTestCase; +import org.junit.Test; + +public class RegressionResultAnalyzerTest extends MahoutTestCase { + + private static final Pattern p1 = Pattern.compile("Correlation coefficient *: *(.*)\n"); + private static final Pattern p2 = Pattern.compile("Mean absolute error *: *(.*)\n"); + private static final Pattern p3 = Pattern.compile("Root mean squared error *: *(.*)\n"); + private static final Pattern p4 = Pattern.compile("Predictable Instances *: *(.*)\n"); + private static final Pattern p5 = Pattern.compile("Unpredictable Instances *: *(.*)\n"); + private static final Pattern p6 = Pattern.compile("Total Regressed Instances *: *(.*)\n"); + + private static double[] parseAnalysis(CharSequence analysis) { + double[] results = new double[3]; + Matcher m = p1.matcher(analysis); + if (m.find()) { + results[0] = Double.parseDouble(m.group(1)); + } else { + return null; + } + m = p2.matcher(analysis); + if (m.find()) { + results[1] = Double.parseDouble(m.group(1)); + } else { + return null; + } + m = p3.matcher(analysis); + if (m.find()) { + results[2] = Double.parseDouble(m.group(1)); + } else { + return null; + } + return results; + } + + private static int[] parseAnalysisCount(CharSequence analysis) { + int[] results = new int[3]; + Matcher m = p4.matcher(analysis); + if (m.find()) { + results[0] = Integer.parseInt(m.group(1)); + } + m = p5.matcher(analysis); + if (m.find()) { + results[1] = Integer.parseInt(m.group(1)); + } + m = p6.matcher(analysis); + if (m.find()) { + results[2] = Integer.parseInt(m.group(1)); + } + return results; + } + + @Test + public void testAnalyze() { + double[][] results = new double[10][2]; + + for (int i = 0; i < results.length; i++) { + results[i][0] = i; + results[i][1] = i + 1; + } + RegressionResultAnalyzer analyzer = new RegressionResultAnalyzer(); + 
analyzer.setInstances(results); + String analysis = analyzer.toString(); + assertArrayEquals(new double[]{1.0, 1.0, 1.0}, parseAnalysis(analysis), 0); + + for (int i = 0; i < results.length; i++) { + results[i][1] = Math.sqrt(i); + } + analyzer = new RegressionResultAnalyzer(); + analyzer.setInstances(results); + analysis = analyzer.toString(); + assertArrayEquals(new double[]{0.9573, 2.5694, 3.2848}, parseAnalysis(analysis), 0); + + for (int i = 0; i < results.length; i++) { + results[i][0] = results.length - i; + } + analyzer = new RegressionResultAnalyzer(); + analyzer.setInstances(results); + analysis = analyzer.toString(); + assertArrayEquals(new double[]{-0.9573, 4.1351, 5.1573}, parseAnalysis(analysis), 0); + } + + @Test + public void testUnpredictable() { + double[][] results = new double[10][2]; + + for (int i = 0; i < results.length; i++) { + results[i][0] = i; + results[i][1] = Double.NaN; + } + RegressionResultAnalyzer analyzer = new RegressionResultAnalyzer(); + analyzer.setInstances(results); + String analysis = analyzer.toString(); + assertNull(parseAnalysis(analysis)); + assertArrayEquals(new int[]{0, 10, 10}, parseAnalysisCount(analysis)); + + for (int i = 0; i < results.length - 3; i++) { + results[i][1] = Math.sqrt(i); + } + analyzer = new RegressionResultAnalyzer(); + analyzer.setInstances(results); + analysis = analyzer.toString(); + assertArrayEquals(new double[]{0.9552, 1.4526, 1.9345}, parseAnalysis(analysis), 0); + assertArrayEquals(new int[]{7, 3, 10}, parseAnalysisCount(analysis)); + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/test/java/org/apache/mahout/classifier/df/DecisionForestTest.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/test/java/org/apache/mahout/classifier/df/DecisionForestTest.java b/community/mahout-mr/src/test/java/org/apache/mahout/classifier/df/DecisionForestTest.java new file mode 100644 index 
0000000..036d473 --- /dev/null +++ b/community/mahout-mr/src/test/java/org/apache/mahout/classifier/df/DecisionForestTest.java @@ -0,0 +1,206 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.classifier.df; + +import java.util.List; +import java.util.Random; + +import org.apache.mahout.classifier.df.builder.DecisionTreeBuilder; +import org.apache.mahout.classifier.df.data.Data; +import org.apache.mahout.classifier.df.data.DataLoader; +import org.apache.mahout.classifier.df.data.Dataset; +import org.apache.mahout.classifier.df.data.DescriptorException; +import org.apache.mahout.classifier.df.data.Instance; +import org.apache.mahout.classifier.df.node.Node; +import org.apache.mahout.common.MahoutTestCase; +import org.apache.mahout.common.RandomUtils; +import org.junit.Test; + +import com.google.common.collect.Lists; +@Deprecated +public final class DecisionForestTest extends MahoutTestCase { + + private static final String[] TRAIN_DATA = {"sunny,85,85,FALSE,no", + "sunny,80,90,TRUE,no", "overcast,83,86,FALSE,yes", + "rainy,70,96,FALSE,yes", "rainy,68,80,FALSE,yes", "rainy,65,70,TRUE,no", + "overcast,64,65,TRUE,yes", "sunny,72,95,FALSE,no", + "sunny,69,70,FALSE,yes", 
"rainy,75,80,FALSE,yes", "sunny,75,70,TRUE,yes", + "overcast,72,90,TRUE,yes", "overcast,81,75,FALSE,yes", + "rainy,71,91,TRUE,no"}; + + private static final String[] TEST_DATA = {"rainy,70,96,TRUE,-", + "overcast,64,65,TRUE,-", "sunny,75,90,TRUE,-",}; + + private Random rng; + + @Override + public void setUp() throws Exception { + super.setUp(); + rng = RandomUtils.getRandom(); + } + + private static Data[] generateTrainingDataA() throws DescriptorException { + // Dataset + Dataset dataset = DataLoader.generateDataset("C N N C L", false, TRAIN_DATA); + + // Training data + Data data = DataLoader.loadData(dataset, TRAIN_DATA); + @SuppressWarnings("unchecked") + List<Instance>[] instances = new List[3]; + for (int i = 0; i < instances.length; i++) { + instances[i] = Lists.newArrayList(); + } + for (int i = 0; i < data.size(); i++) { + if (data.get(i).get(0) == 0.0d) { + instances[0].add(data.get(i)); + } else { + instances[1].add(data.get(i)); + } + } + Data[] datas = new Data[instances.length]; + for (int i = 0; i < datas.length; i++) { + datas[i] = new Data(dataset, instances[i]); + } + + return datas; + } + + private static Data[] generateTrainingDataB() throws DescriptorException { + + // Training data + String[] trainData = new String[20]; + for (int i = 0; i < trainData.length; i++) { + if (i % 3 == 0) { + trainData[i] = "A," + (40 - i) + ',' + (i + 20); + } else if (i % 3 == 1) { + trainData[i] = "B," + (i + 20) + ',' + (40 - i); + } else { + trainData[i] = "C," + (i + 20) + ',' + (i + 20); + } + } + // Dataset + Dataset dataset = DataLoader.generateDataset("C N L", true, trainData); + Data[] datas = new Data[3]; + datas[0] = DataLoader.loadData(dataset, trainData); + + // Training data + trainData = new String[20]; + for (int i = 0; i < trainData.length; i++) { + if (i % 2 == 0) { + trainData[i] = "A," + (50 - i) + ',' + (i + 10); + } else { + trainData[i] = "B," + (i + 10) + ',' + (50 - i); + } + } + datas[1] = DataLoader.loadData(dataset, trainData); + + // 
Training data + trainData = new String[10]; + for (int i = 0; i < trainData.length; i++) { + trainData[i] = "A," + (40 - i) + ',' + (i + 20); + } + datas[2] = DataLoader.loadData(dataset, trainData); + + return datas; + } + + private DecisionForest buildForest(Data[] datas) { + List<Node> trees = Lists.newArrayList(); + for (Data data : datas) { + // build tree + DecisionTreeBuilder builder = new DecisionTreeBuilder(); + builder.setM(data.getDataset().nbAttributes() - 1); + builder.setMinSplitNum(0); + builder.setComplemented(false); + trees.add(builder.build(rng, data)); + } + return new DecisionForest(trees); + } + + @Test + public void testClassify() throws DescriptorException { + // Training data + Data[] datas = generateTrainingDataA(); + // Build Forest + DecisionForest forest = buildForest(datas); + // Test data + Dataset dataset = datas[0].getDataset(); + Data testData = DataLoader.loadData(dataset, TEST_DATA); + + double noValue = dataset.valueOf(4, "no"); + double yesValue = dataset.valueOf(4, "yes"); + assertEquals(noValue, forest.classify(testData.getDataset(), rng, testData.get(0)), EPSILON); + // This one is tie-broken -- 1 is OK too + //assertEquals(yesValue, forest.classify(testData.getDataset(), rng, testData.get(1)), EPSILON); + assertEquals(noValue, forest.classify(testData.getDataset(), rng, testData.get(2)), EPSILON); + } + + @Test + public void testClassifyData() throws DescriptorException { + // Training data + Data[] datas = generateTrainingDataA(); + // Build Forest + DecisionForest forest = buildForest(datas); + // Test data + Dataset dataset = datas[0].getDataset(); + Data testData = DataLoader.loadData(dataset, TEST_DATA); + + double[][] predictions = new double[testData.size()][]; + forest.classify(testData, predictions); + double noValue = dataset.valueOf(4, "no"); + double yesValue = dataset.valueOf(4, "yes"); + assertArrayEquals(new double[][]{{noValue, Double.NaN, Double.NaN}, + {noValue, yesValue, Double.NaN}, {noValue, noValue, 
Double.NaN}}, predictions); + } + + @Test + public void testRegression() throws DescriptorException { + Data[] datas = generateTrainingDataB(); + DecisionForest[] forests = new DecisionForest[datas.length]; + for (int i = 0; i < datas.length; i++) { + Data[] subDatas = new Data[datas.length - 1]; + int k = 0; + for (int j = 0; j < datas.length; j++) { + if (j != i) { + subDatas[k] = datas[j]; + k++; + } + } + forests[i] = buildForest(subDatas); + } + + double[][] predictions = new double[datas[0].size()][]; + forests[0].classify(datas[0], predictions); + assertArrayEquals(new double[]{20.0, 20.0}, predictions[0], EPSILON); + assertArrayEquals(new double[]{39.0, 29.0}, predictions[1], EPSILON); + assertArrayEquals(new double[]{Double.NaN, 29.0}, predictions[2], EPSILON); + assertArrayEquals(new double[]{Double.NaN, 23.0}, predictions[17], EPSILON); + + predictions = new double[datas[1].size()][]; + forests[1].classify(datas[1], predictions); + assertArrayEquals(new double[]{30.0, 29.0}, predictions[19], EPSILON); + + predictions = new double[datas[2].size()][]; + forests[2].classify(datas[2], predictions); + assertArrayEquals(new double[]{29.0, 28.0}, predictions[9], EPSILON); + + assertEquals(20.0, forests[0].classify(datas[0].getDataset(), rng, datas[0].get(0)), EPSILON); + assertEquals(34.0, forests[0].classify(datas[0].getDataset(), rng, datas[0].get(1)), EPSILON); + assertEquals(29.0, forests[0].classify(datas[0].getDataset(), rng, datas[0].get(2)), EPSILON); + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/test/java/org/apache/mahout/classifier/df/builder/DecisionTreeBuilderTest.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/test/java/org/apache/mahout/classifier/df/builder/DecisionTreeBuilderTest.java b/community/mahout-mr/src/test/java/org/apache/mahout/classifier/df/builder/DecisionTreeBuilderTest.java new file mode 100644 index 0000000..56b4787 --- 
/dev/null +++ b/community/mahout-mr/src/test/java/org/apache/mahout/classifier/df/builder/DecisionTreeBuilderTest.java @@ -0,0 +1,78 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.classifier.df.builder; + +import java.lang.reflect.Method; +import java.util.Random; +import java.util.Arrays; + +import org.apache.commons.lang3.ArrayUtils; +import org.apache.mahout.common.MahoutTestCase; +import org.apache.mahout.common.RandomUtils; +import org.junit.Test; +@Deprecated +public final class DecisionTreeBuilderTest extends MahoutTestCase { + + /** + * make sure that DecisionTreeBuilder.randomAttributes() returns the correct number of attributes, that have not been + * selected yet + */ + @Test + public void testRandomAttributes() throws Exception { + Random rng = RandomUtils.getRandom(); + int nbAttributes = rng.nextInt(100) + 1; + boolean[] selected = new boolean[nbAttributes]; + + for (int nloop = 0; nloop < 100; nloop++) { + Arrays.fill(selected, false); + + // randomly select some attributes + int nbSelected = rng.nextInt(nbAttributes - 1); + for (int index = 0; index < nbSelected; index++) { + int attr; + do { + attr = rng.nextInt(nbAttributes); + } while (selected[attr]); + + 
selected[attr] = true; + } + + int m = rng.nextInt(nbAttributes); + + Method randomAttributes = DecisionTreeBuilder.class.getDeclaredMethod("randomAttributes", + Random.class, boolean[].class, int.class); + randomAttributes.setAccessible(true); + int[] attrs = (int[]) randomAttributes.invoke(null, rng, selected, m); + + assertNotNull(attrs); + assertEquals(Math.min(m, nbAttributes - nbSelected), attrs.length); + + for (int attr : attrs) { + // the attribute should not be already selected + assertFalse("an attribute has already been selected", selected[attr]); + + // each attribute should be in the range [0, nbAttributes[ + assertTrue(attr >= 0); + assertTrue(attr < nbAttributes); + + // each attribute should appear only once + assertEquals(ArrayUtils.indexOf(attrs, attr), ArrayUtils.lastIndexOf(attrs, attr)); + } + } + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/test/java/org/apache/mahout/classifier/df/builder/DefaultTreeBuilderTest.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/test/java/org/apache/mahout/classifier/df/builder/DefaultTreeBuilderTest.java b/community/mahout-mr/src/test/java/org/apache/mahout/classifier/df/builder/DefaultTreeBuilderTest.java new file mode 100644 index 0000000..87fd44b --- /dev/null +++ b/community/mahout-mr/src/test/java/org/apache/mahout/classifier/df/builder/DefaultTreeBuilderTest.java @@ -0,0 +1,74 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.classifier.df.builder; + +import java.util.Random; +import java.util.Arrays; + +import org.apache.commons.lang3.ArrayUtils; +import org.apache.mahout.common.MahoutTestCase; +import org.apache.mahout.common.RandomUtils; +import org.junit.Test; +@Deprecated +public final class DefaultTreeBuilderTest extends MahoutTestCase { + + /** + * make sure that DefaultTreeBuilder.randomAttributes() returns the correct number of attributes, that have not been + * selected yet + */ + @Test + public void testRandomAttributes() throws Exception { + Random rng = RandomUtils.getRandom(); + int nbAttributes = rng.nextInt(100) + 1; + boolean[] selected = new boolean[nbAttributes]; + + for (int nloop = 0; nloop < 100; nloop++) { + Arrays.fill(selected, false); + + // randomly select some attributes + int nbSelected = rng.nextInt(nbAttributes - 1); + for (int index = 0; index < nbSelected; index++) { + int attr; + do { + attr = rng.nextInt(nbAttributes); + } while (selected[attr]); + + selected[attr] = true; + } + + int m = rng.nextInt(nbAttributes); + + int[] attrs = DefaultTreeBuilder.randomAttributes(rng, selected, m); + + assertNotNull(attrs); + assertEquals(Math.min(m, nbAttributes - nbSelected), attrs.length); + + for (int attr : attrs) { + // the attribute should not be already selected + assertFalse("an attribute has already been selected", selected[attr]); + + // each attribute should be in the range [0, nbAttributes[ + assertTrue(attr >= 0); + assertTrue(attr < nbAttributes); + + // each attribute should appear only once + 
assertEquals(ArrayUtils.indexOf(attrs, attr), ArrayUtils.lastIndexOf(attrs, attr)); + } + } + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/test/java/org/apache/mahout/classifier/df/builder/InfiniteRecursionTest.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/test/java/org/apache/mahout/classifier/df/builder/InfiniteRecursionTest.java b/community/mahout-mr/src/test/java/org/apache/mahout/classifier/df/builder/InfiniteRecursionTest.java new file mode 100644 index 0000000..8ebc721 --- /dev/null +++ b/community/mahout-mr/src/test/java/org/apache/mahout/classifier/df/builder/InfiniteRecursionTest.java @@ -0,0 +1,60 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.classifier.df.builder; + +import org.apache.mahout.common.MahoutTestCase; +import org.apache.mahout.common.RandomUtils; +import org.apache.mahout.classifier.df.data.Data; +import org.apache.mahout.classifier.df.data.DataLoader; +import org.apache.mahout.classifier.df.data.Dataset; +import org.apache.mahout.classifier.df.data.Utils; +import org.junit.Test; + +import java.util.Random; +@Deprecated +public final class InfiniteRecursionTest extends MahoutTestCase { + + private static final double[][] dData = { + { 0.25, 0.0, 0.0, 5.143998668220409E-4, 0.019847102289905324, 3.5216524641879855E-4, 0.0, 0.6225857142857143, 4 }, + { 0.25, 0.0, 0.0, 0.0010504411519893459, 0.005462138323171171, 0.0026130744829756746, 0.0, 0.4964857142857143, 3 }, + { 0.25, 0.0, 0.0, 0.0010504411519893459, 0.005462138323171171, 0.0026130744829756746, 0.0, 0.4964857142857143, 4 }, + { 0.25, 0.0, 0.0, 5.143998668220409E-4, 0.019847102289905324, 3.5216524641879855E-4, 0.0, 0.6225857142857143, 3 } + }; + + /** + * make sure DecisionTreeBuilder.build() does not throw a StackOverflowException + */ + @Test + public void testBuild() throws Exception { + Random rng = RandomUtils.getRandom(); + + String[] source = Utils.double2String(dData); + String descriptor = "N N N N N N N N L"; + + Dataset dataset = DataLoader.generateDataset(descriptor, false, source); + Data data = DataLoader.loadData(dataset, source); + TreeBuilder builder = new DecisionTreeBuilder(); + builder.build(rng, data); + + // regression + dataset = DataLoader.generateDataset(descriptor, true, source); + data = DataLoader.loadData(dataset, source); + builder = new DecisionTreeBuilder(); + builder.build(rng, data); + } +}
