Repository: mahout Updated Branches: refs/heads/branch-0.14.0 410ed16af -> 49ad8cb45
http://git-wip-us.apache.org/repos/asf/mahout/blob/49ad8cb4/core/src/test/java/org/apache/mahout/math/VectorBinaryAssignCostTest.java ---------------------------------------------------------------------- diff --git a/core/src/test/java/org/apache/mahout/math/VectorBinaryAssignCostTest.java b/core/src/test/java/org/apache/mahout/math/VectorBinaryAssignCostTest.java index 61404be..04d7f72 100644 --- a/core/src/test/java/org/apache/mahout/math/VectorBinaryAssignCostTest.java +++ b/core/src/test/java/org/apache/mahout/math/VectorBinaryAssignCostTest.java @@ -1,245 +1,243 @@ -// -///* -// * Licensed to the Apache Software Foundation (ASF) under one or more contributor license -// * agreements. See the NOTICE file distributed with this work for additional information regarding -// * copyright ownership. The ASF licenses this file to You under the Apache License, Version 2.0 (the -// * "License"); you may not use this file except in compliance with the License. You may obtain a -// * copy of the License at -// * -// * http://www.apache.org/licenses/LICENSE-2.0 -// * -// * Unless required by applicable law or agreed to in writing, software distributed under the License -// * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express -// * or implied. See the License for the specific language governing permissions and limitations under -// * the License. -// */ -// -// TODO; Do we need to bring this back? -// -//package org.apache.mahout.math; -// -//import org.apache.mahout.math.function.Functions; -//import org.easymock.EasyMock; -//import org.junit.Before; -//import org.junit.Test; -//import org.junit.runner.RunWith; -//import org.junit.runners.JUnit4; -// -//import static org.easymock.EasyMock.expect; -//import static org.easymock.EasyMock.replay; -//import static org.junit.Assert.assertEquals; -// -//@RunWith(JUnit4.class) -//public final class VectorBinaryAssignCostTest { -// RandomAccessSparseVector realRasv = new RandomAccessSparseVector(1000000); -// SequentialAccessSparseVector realSasv = new SequentialAccessSparseVector(1000000); -// DenseVector realDense = new DenseVector(1000000); -// -// Vector rasv = EasyMock.createMock(Vector.class); -// Vector sasv = EasyMock.createMock(Vector.class); -// Vector dense = EasyMock.createMock(Vector.class); -// -// private static void createStubs(Vector v, Vector realV) { -// expect(v.getLookupCost()) -// .andStubReturn(realV instanceof SequentialAccessSparseVector -// ? Math.round(Math.log(1000)) : realV.getLookupCost()); -// expect(v.getIteratorAdvanceCost()) -// .andStubReturn(realV.getIteratorAdvanceCost()); -// expect(v.isAddConstantTime()) -// .andStubReturn(realV.isAddConstantTime()); -// expect(v.isSequentialAccess()) -// .andStubReturn(realV.isSequentialAccess()); -// expect(v.isDense()) -// .andStubReturn(realV.isDense()); -// expect(v.getNumNondefaultElements()) -// .andStubReturn(realV.isDense() ? realV.size() : 1000); -// expect(v.size()) -// .andStubReturn(realV.size()); -// } -// -// @Before -// public void setUpStubs() { -// createStubs(dense, realDense); -// createStubs(sasv, realSasv); -// createStubs(rasv, realRasv); -// } -// -// @Test -// public void denseInteractions() { -// replayAll(); -// assertEquals(VectorBinaryAssign.AssignNonzerosIterateThatLookupThisInplaceUpdates.class, -// VectorBinaryAssign.getBestOperation(dense, dense, Functions.PLUS).getClass()); -// -// assertEquals(VectorBinaryAssign.AssignNonzerosIterateThatLookupThisInplaceUpdates.class, -// VectorBinaryAssign.getBestOperation(dense, dense, Functions.MINUS).getClass()); -// -// assertEquals(VectorBinaryAssign.AssignNonzerosIterateThisLookupThat.class, -// VectorBinaryAssign.getBestOperation(dense, dense, Functions.MULT).getClass()); -// -// assertEquals(VectorBinaryAssign.AssignAllLoopInplaceUpdates.class, -// VectorBinaryAssign.getBestOperation(dense, dense, Functions.DIV).getClass()); -// -// assertEquals(VectorBinaryAssign.AssignNonzerosIterateThatLookupThisInplaceUpdates.class, -// VectorBinaryAssign.getBestOperation(dense, dense, Functions.SECOND_LEFT_ZERO).getClass()); -// } -// -// @Test -// public void sasvInteractions() { -// replayAll(); -// -// assertEquals(VectorBinaryAssign.AssignIterateUnionSequentialMergeUpdates.class, -// VectorBinaryAssign.getBestOperation(sasv, sasv, Functions.PLUS).getClass()); -// -// assertEquals(VectorBinaryAssign.AssignIterateUnionSequentialMergeUpdates.class, -// VectorBinaryAssign.getBestOperation(sasv, sasv, Functions.MINUS).getClass()); -// -// assertEquals(VectorBinaryAssign.AssignIterateUnionSequentialMergeUpdates.class, -// VectorBinaryAssign.getBestOperation(sasv, sasv, Functions.MULT).getClass()); -// -// assertEquals(VectorBinaryAssign.AssignAllIterateSequentialMergeUpdates.class, -// VectorBinaryAssign.getBestOperation(sasv, sasv, Functions.DIV).getClass()); -// -// assertEquals(VectorBinaryAssign.AssignIterateUnionSequentialMergeUpdates.class, -// VectorBinaryAssign.getBestOperation(sasv, sasv, Functions.SECOND_LEFT_ZERO).getClass()); -// } -// -// @Test -// public void rasvInteractions() { -// replayAll(); -// -// assertEquals(VectorBinaryAssign.AssignNonzerosIterateThatLookupThisInplaceUpdates.class, -// VectorBinaryAssign.getBestOperation(rasv, rasv, Functions.PLUS).getClass()); -// -// assertEquals(VectorBinaryAssign.AssignNonzerosIterateThatLookupThisInplaceUpdates.class, -// VectorBinaryAssign.getBestOperation(rasv, rasv, Functions.MINUS).getClass()); -// -// assertEquals(VectorBinaryAssign.AssignNonzerosIterateThisLookupThat.class, -// VectorBinaryAssign.getBestOperation(rasv, rasv, Functions.MULT).getClass()); -// -// assertEquals(VectorBinaryAssign.AssignAllLoopInplaceUpdates.class, -// VectorBinaryAssign.getBestOperation(rasv, rasv, Functions.DIV).getClass()); -// -// assertEquals(VectorBinaryAssign.AssignNonzerosIterateThatLookupThisInplaceUpdates.class, -// VectorBinaryAssign.getBestOperation(rasv, rasv, Functions.SECOND_LEFT_ZERO).getClass()); -// } -// -// @Test -// public void sasvDenseInteractions() { -// replayAll(); -// -// assertEquals(VectorBinaryAssign.AssignNonzerosIterateThatLookupThisMergeUpdates.class, -// VectorBinaryAssign.getBestOperation(sasv, dense, Functions.PLUS).getClass()); -// -// assertEquals(VectorBinaryAssign.AssignNonzerosIterateThatLookupThisMergeUpdates.class, -// VectorBinaryAssign.getBestOperation(sasv, dense, Functions.MINUS).getClass()); -// -// assertEquals(VectorBinaryAssign.AssignNonzerosIterateThisLookupThat.class, -// VectorBinaryAssign.getBestOperation(sasv, dense, Functions.MULT).getClass()); -// -// assertEquals(VectorBinaryAssign.AssignAllIterateThisLookupThatMergeUpdates.class, -// VectorBinaryAssign.getBestOperation(sasv, dense, Functions.DIV).getClass()); -// -// assertEquals(VectorBinaryAssign.AssignNonzerosIterateThatLookupThisMergeUpdates.class, -// VectorBinaryAssign.getBestOperation(sasv, dense, Functions.SECOND_LEFT_ZERO).getClass()); -// } -// -// @Test -// public void denseSasvInteractions() { -// replayAll(); -// assertEquals(VectorBinaryAssign.AssignNonzerosIterateThatLookupThisInplaceUpdates.class, -// VectorBinaryAssign.getBestOperation(dense, sasv, Functions.PLUS).getClass()); -// -// assertEquals(VectorBinaryAssign.AssignNonzerosIterateThatLookupThisInplaceUpdates.class, -// VectorBinaryAssign.getBestOperation(dense, sasv, Functions.MINUS).getClass()); -// -// assertEquals(VectorBinaryAssign.AssignIterateUnionSequentialInplaceUpdates.class, -// VectorBinaryAssign.getBestOperation(dense, sasv, Functions.MULT).getClass()); -// -// assertEquals(VectorBinaryAssign.AssignAllIterateThatLookupThisInplaceUpdates.class, -// VectorBinaryAssign.getBestOperation(dense, sasv, Functions.DIV).getClass()); -// -// assertEquals(VectorBinaryAssign.AssignNonzerosIterateThatLookupThisInplaceUpdates.class, -// VectorBinaryAssign.getBestOperation(dense, sasv, Functions.SECOND_LEFT_ZERO).getClass()); -// } -// -// @Test -// public void denseRasvInteractions() { -// replayAll(); -// assertEquals(VectorBinaryAssign.AssignNonzerosIterateThatLookupThisInplaceUpdates.class, -// VectorBinaryAssign.getBestOperation(dense, rasv, Functions.PLUS).getClass()); -// -// assertEquals(VectorBinaryAssign.AssignNonzerosIterateThatLookupThisInplaceUpdates.class, -// VectorBinaryAssign.getBestOperation(dense, rasv, Functions.MINUS).getClass()); -// -// assertEquals(VectorBinaryAssign.AssignNonzerosIterateThisLookupThat.class, -// VectorBinaryAssign.getBestOperation(dense, rasv, Functions.MULT).getClass()); -// -// assertEquals(VectorBinaryAssign.AssignAllLoopInplaceUpdates.class, -// VectorBinaryAssign.getBestOperation(dense, rasv, Functions.DIV).getClass()); -// -// assertEquals(VectorBinaryAssign.AssignNonzerosIterateThatLookupThisInplaceUpdates.class, -// VectorBinaryAssign.getBestOperation(dense, rasv, Functions.SECOND_LEFT_ZERO).getClass()); -// } -// -// @Test -// public void rasvDenseInteractions() { -// replayAll(); -// assertEquals(VectorBinaryAssign.AssignNonzerosIterateThatLookupThisInplaceUpdates.class, -// VectorBinaryAssign.getBestOperation(rasv, dense, Functions.PLUS).getClass()); -// -// assertEquals(VectorBinaryAssign.AssignNonzerosIterateThatLookupThisInplaceUpdates.class, -// VectorBinaryAssign.getBestOperation(rasv, dense, Functions.MINUS).getClass()); -// -// assertEquals(VectorBinaryAssign.AssignNonzerosIterateThisLookupThat.class, -// VectorBinaryAssign.getBestOperation(rasv, dense, Functions.MULT).getClass()); -// -// assertEquals(VectorBinaryAssign.AssignAllLoopInplaceUpdates.class, -// VectorBinaryAssign.getBestOperation(rasv, dense, Functions.DIV).getClass()); -// -// assertEquals(VectorBinaryAssign.AssignNonzerosIterateThatLookupThisInplaceUpdates.class, -// VectorBinaryAssign.getBestOperation(rasv, dense, Functions.SECOND_LEFT_ZERO).getClass()); -// } -// -// @Test -// public void sasvRasvInteractions() { -// replayAll(); -// assertEquals(VectorBinaryAssign.AssignNonzerosIterateThatLookupThisInplaceUpdates.class, -// VectorBinaryAssign.getBestOperation(sasv, rasv, Functions.PLUS).getClass()); -// -// assertEquals(VectorBinaryAssign.AssignNonzerosIterateThatLookupThisInplaceUpdates.class, -// VectorBinaryAssign.getBestOperation(sasv, rasv, Functions.MINUS).getClass()); -// -// assertEquals(VectorBinaryAssign.AssignNonzerosIterateThisLookupThat.class, -// VectorBinaryAssign.getBestOperation(sasv, rasv, Functions.MULT).getClass()); -// -// assertEquals(VectorBinaryAssign.AssignAllIterateThisLookupThatMergeUpdates.class, -// VectorBinaryAssign.getBestOperation(sasv, rasv, Functions.DIV).getClass()); -// -// assertEquals(VectorBinaryAssign.AssignNonzerosIterateThatLookupThisInplaceUpdates.class, -// VectorBinaryAssign.getBestOperation(sasv, rasv, Functions.SECOND_LEFT_ZERO).getClass()); -// } -// -// @Test -// public void rasvSasvInteractions() { -// replayAll(); -// assertEquals(VectorBinaryAssign.AssignNonzerosIterateThatLookupThisInplaceUpdates.class, -// VectorBinaryAssign.getBestOperation(rasv, sasv, Functions.PLUS).getClass()); -// -// assertEquals(VectorBinaryAssign.AssignNonzerosIterateThatLookupThisInplaceUpdates.class, -// VectorBinaryAssign.getBestOperation(rasv, sasv, Functions.MINUS).getClass()); -// -// assertEquals(VectorBinaryAssign.AssignNonzerosIterateThisLookupThat.class, -// VectorBinaryAssign.getBestOperation(rasv, sasv, Functions.MULT).getClass()); -// -// assertEquals(VectorBinaryAssign.AssignAllIterateThatLookupThisInplaceUpdates.class, -// VectorBinaryAssign.getBestOperation(rasv, sasv, Functions.DIV).getClass()); -// -// assertEquals(VectorBinaryAssign.AssignNonzerosIterateThatLookupThisInplaceUpdates.class, -// VectorBinaryAssign.getBestOperation(rasv, sasv, Functions.SECOND_LEFT_ZERO).getClass()); -// } -// -// -// private void replayAll() { -// replay(dense, sasv, rasv); -// } -//} + +/* + * Licensed to the Apache Software Foundation (ASF) under one or more contributor license + * agreements. See the NOTICE file distributed with this work for additional information regarding + * copyright ownership. The ASF licenses this file to You under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance with the License. You may obtain a + * copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + */ + +package org.apache.mahout.math; + +import org.apache.mahout.math.function.Functions; +import org.easymock.EasyMock; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import static org.easymock.EasyMock.expect; +import static org.easymock.EasyMock.replay; +import static org.junit.Assert.assertEquals; + +@RunWith(JUnit4.class) +public final class VectorBinaryAssignCostTest { + RandomAccessSparseVector realRasv = new RandomAccessSparseVector(1000000); + SequentialAccessSparseVector realSasv = new SequentialAccessSparseVector(1000000); + DenseVector realDense = new DenseVector(1000000); + + Vector rasv = EasyMock.createMock(Vector.class); + Vector sasv = EasyMock.createMock(Vector.class); + Vector dense = EasyMock.createMock(Vector.class); + + private static void createStubs(Vector v, Vector realV) { + expect(v.getLookupCost()) + .andStubReturn(realV instanceof SequentialAccessSparseVector + ? Math.round(Math.log(1000)) : realV.getLookupCost()); + expect(v.getIteratorAdvanceCost()) + .andStubReturn(realV.getIteratorAdvanceCost()); + expect(v.isAddConstantTime()) + .andStubReturn(realV.isAddConstantTime()); + expect(v.isSequentialAccess()) + .andStubReturn(realV.isSequentialAccess()); + expect(v.isDense()) + .andStubReturn(realV.isDense()); + expect(v.getNumNondefaultElements()) + .andStubReturn(realV.isDense() ? realV.size() : 1000); + expect(v.size()) + .andStubReturn(realV.size()); + } + + @Before + public void setUpStubs() { + createStubs(dense, realDense); + createStubs(sasv, realSasv); + createStubs(rasv, realRasv); + } + + @Test + public void denseInteractions() { + replayAll(); + assertEquals(VectorBinaryAssign.AssignNonzerosIterateThatLookupThisInplaceUpdates.class, + VectorBinaryAssign.getBestOperation(dense, dense, Functions.PLUS).getClass()); + + assertEquals(VectorBinaryAssign.AssignNonzerosIterateThatLookupThisInplaceUpdates.class, + VectorBinaryAssign.getBestOperation(dense, dense, Functions.MINUS).getClass()); + + assertEquals(VectorBinaryAssign.AssignNonzerosIterateThisLookupThat.class, + VectorBinaryAssign.getBestOperation(dense, dense, Functions.MULT).getClass()); + + assertEquals(VectorBinaryAssign.AssignAllLoopInplaceUpdates.class, + VectorBinaryAssign.getBestOperation(dense, dense, Functions.DIV).getClass()); + + assertEquals(VectorBinaryAssign.AssignNonzerosIterateThatLookupThisInplaceUpdates.class, + VectorBinaryAssign.getBestOperation(dense, dense, Functions.SECOND_LEFT_ZERO).getClass()); + } + + @Test + public void sasvInteractions() { + replayAll(); + + assertEquals(VectorBinaryAssign.AssignIterateUnionSequentialMergeUpdates.class, + VectorBinaryAssign.getBestOperation(sasv, sasv, Functions.PLUS).getClass()); + + assertEquals(VectorBinaryAssign.AssignIterateUnionSequentialMergeUpdates.class, + VectorBinaryAssign.getBestOperation(sasv, sasv, Functions.MINUS).getClass()); + + assertEquals(VectorBinaryAssign.AssignIterateUnionSequentialMergeUpdates.class, + VectorBinaryAssign.getBestOperation(sasv, sasv, Functions.MULT).getClass()); + + assertEquals(VectorBinaryAssign.AssignAllIterateSequentialMergeUpdates.class, + VectorBinaryAssign.getBestOperation(sasv, sasv, Functions.DIV).getClass()); + + assertEquals(VectorBinaryAssign.AssignIterateUnionSequentialMergeUpdates.class, + VectorBinaryAssign.getBestOperation(sasv, sasv, Functions.SECOND_LEFT_ZERO).getClass()); + } + + @Test + public void rasvInteractions() { + replayAll(); + + assertEquals(VectorBinaryAssign.AssignNonzerosIterateThatLookupThisInplaceUpdates.class, + VectorBinaryAssign.getBestOperation(rasv, rasv, Functions.PLUS).getClass()); + + assertEquals(VectorBinaryAssign.AssignNonzerosIterateThatLookupThisInplaceUpdates.class, + VectorBinaryAssign.getBestOperation(rasv, rasv, Functions.MINUS).getClass()); + + assertEquals(VectorBinaryAssign.AssignNonzerosIterateThisLookupThat.class, + VectorBinaryAssign.getBestOperation(rasv, rasv, Functions.MULT).getClass()); + + assertEquals(VectorBinaryAssign.AssignAllLoopInplaceUpdates.class, + VectorBinaryAssign.getBestOperation(rasv, rasv, Functions.DIV).getClass()); + + assertEquals(VectorBinaryAssign.AssignNonzerosIterateThatLookupThisInplaceUpdates.class, + VectorBinaryAssign.getBestOperation(rasv, rasv, Functions.SECOND_LEFT_ZERO).getClass()); + } + + @Test + public void sasvDenseInteractions() { + replayAll(); + + assertEquals(VectorBinaryAssign.AssignNonzerosIterateThatLookupThisMergeUpdates.class, + VectorBinaryAssign.getBestOperation(sasv, dense, Functions.PLUS).getClass()); + + assertEquals(VectorBinaryAssign.AssignNonzerosIterateThatLookupThisMergeUpdates.class, + VectorBinaryAssign.getBestOperation(sasv, dense, Functions.MINUS).getClass()); + + assertEquals(VectorBinaryAssign.AssignNonzerosIterateThisLookupThat.class, + VectorBinaryAssign.getBestOperation(sasv, dense, Functions.MULT).getClass()); + + assertEquals(VectorBinaryAssign.AssignAllIterateThisLookupThatMergeUpdates.class, + VectorBinaryAssign.getBestOperation(sasv, dense, Functions.DIV).getClass()); + + assertEquals(VectorBinaryAssign.AssignNonzerosIterateThatLookupThisMergeUpdates.class, + VectorBinaryAssign.getBestOperation(sasv, dense, Functions.SECOND_LEFT_ZERO).getClass()); + } + + @Test + public void denseSasvInteractions() { + replayAll(); + assertEquals(VectorBinaryAssign.AssignNonzerosIterateThatLookupThisInplaceUpdates.class, + VectorBinaryAssign.getBestOperation(dense, sasv, Functions.PLUS).getClass()); + + assertEquals(VectorBinaryAssign.AssignNonzerosIterateThatLookupThisInplaceUpdates.class, + VectorBinaryAssign.getBestOperation(dense, sasv, Functions.MINUS).getClass()); + + assertEquals(VectorBinaryAssign.AssignIterateUnionSequentialInplaceUpdates.class, + VectorBinaryAssign.getBestOperation(dense, sasv, Functions.MULT).getClass()); + + assertEquals(VectorBinaryAssign.AssignAllIterateThatLookupThisInplaceUpdates.class, + VectorBinaryAssign.getBestOperation(dense, sasv, Functions.DIV).getClass()); + + assertEquals(VectorBinaryAssign.AssignNonzerosIterateThatLookupThisInplaceUpdates.class, + VectorBinaryAssign.getBestOperation(dense, sasv, Functions.SECOND_LEFT_ZERO).getClass()); + } + + @Test + public void denseRasvInteractions() { + replayAll(); + assertEquals(VectorBinaryAssign.AssignNonzerosIterateThatLookupThisInplaceUpdates.class, + VectorBinaryAssign.getBestOperation(dense, rasv, Functions.PLUS).getClass()); + + assertEquals(VectorBinaryAssign.AssignNonzerosIterateThatLookupThisInplaceUpdates.class, + VectorBinaryAssign.getBestOperation(dense, rasv, Functions.MINUS).getClass()); + + assertEquals(VectorBinaryAssign.AssignNonzerosIterateThisLookupThat.class, + VectorBinaryAssign.getBestOperation(dense, rasv, Functions.MULT).getClass()); + + assertEquals(VectorBinaryAssign.AssignAllLoopInplaceUpdates.class, + VectorBinaryAssign.getBestOperation(dense, rasv, Functions.DIV).getClass()); + + assertEquals(VectorBinaryAssign.AssignNonzerosIterateThatLookupThisInplaceUpdates.class, + VectorBinaryAssign.getBestOperation(dense, rasv, Functions.SECOND_LEFT_ZERO).getClass()); + } + + @Test + public void rasvDenseInteractions() { + replayAll(); + assertEquals(VectorBinaryAssign.AssignNonzerosIterateThatLookupThisInplaceUpdates.class, + VectorBinaryAssign.getBestOperation(rasv, dense, Functions.PLUS).getClass()); + + assertEquals(VectorBinaryAssign.AssignNonzerosIterateThatLookupThisInplaceUpdates.class, + VectorBinaryAssign.getBestOperation(rasv, dense, Functions.MINUS).getClass()); + + assertEquals(VectorBinaryAssign.AssignNonzerosIterateThisLookupThat.class, + VectorBinaryAssign.getBestOperation(rasv, dense, Functions.MULT).getClass()); + + assertEquals(VectorBinaryAssign.AssignAllLoopInplaceUpdates.class, + VectorBinaryAssign.getBestOperation(rasv, dense, Functions.DIV).getClass()); + + assertEquals(VectorBinaryAssign.AssignNonzerosIterateThatLookupThisInplaceUpdates.class, + VectorBinaryAssign.getBestOperation(rasv, dense, Functions.SECOND_LEFT_ZERO).getClass()); + } + + @Test + public void sasvRasvInteractions() { + replayAll(); + assertEquals(VectorBinaryAssign.AssignNonzerosIterateThatLookupThisInplaceUpdates.class, + VectorBinaryAssign.getBestOperation(sasv, rasv, Functions.PLUS).getClass()); + + assertEquals(VectorBinaryAssign.AssignNonzerosIterateThatLookupThisInplaceUpdates.class, + VectorBinaryAssign.getBestOperation(sasv, rasv, Functions.MINUS).getClass()); + + assertEquals(VectorBinaryAssign.AssignNonzerosIterateThisLookupThat.class, + VectorBinaryAssign.getBestOperation(sasv, rasv, Functions.MULT).getClass()); + + assertEquals(VectorBinaryAssign.AssignAllIterateThisLookupThatMergeUpdates.class, + VectorBinaryAssign.getBestOperation(sasv, rasv, Functions.DIV).getClass()); + + assertEquals(VectorBinaryAssign.AssignNonzerosIterateThatLookupThisInplaceUpdates.class, + VectorBinaryAssign.getBestOperation(sasv, rasv, Functions.SECOND_LEFT_ZERO).getClass()); + } + + @Test + public void rasvSasvInteractions() { + replayAll(); + assertEquals(VectorBinaryAssign.AssignNonzerosIterateThatLookupThisInplaceUpdates.class, + VectorBinaryAssign.getBestOperation(rasv, sasv, Functions.PLUS).getClass()); + + assertEquals(VectorBinaryAssign.AssignNonzerosIterateThatLookupThisInplaceUpdates.class, + VectorBinaryAssign.getBestOperation(rasv, sasv, Functions.MINUS).getClass()); + + assertEquals(VectorBinaryAssign.AssignNonzerosIterateThisLookupThat.class, + VectorBinaryAssign.getBestOperation(rasv, sasv, Functions.MULT).getClass()); + + assertEquals(VectorBinaryAssign.AssignAllIterateThatLookupThisInplaceUpdates.class, + VectorBinaryAssign.getBestOperation(rasv, sasv, Functions.DIV).getClass()); + + assertEquals(VectorBinaryAssign.AssignNonzerosIterateThatLookupThisInplaceUpdates.class, + VectorBinaryAssign.getBestOperation(rasv, sasv, Functions.SECOND_LEFT_ZERO).getClass()); + } + + + private void replayAll() { + replay(dense, sasv, rasv); + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/49ad8cb4/core/src/test/java/org/apache/mahout/math/VectorTest.java ---------------------------------------------------------------------- diff --git a/core/src/test/java/org/apache/mahout/math/VectorTest.java b/core/src/test/java/org/apache/mahout/math/VectorTest.java index 2627617..7f5f66e 100644 --- a/core/src/test/java/org/apache/mahout/math/VectorTest.java +++ b/core/src/test/java/org/apache/mahout/math/VectorTest.java @@ -28,7 +28,6 @@ import org.apache.mahout.math.function.Functions; import org.junit.Test; import com.google.common.collect.Sets; -import static org.junit.Assert.*; public final class VectorTest extends MahoutTestCase { http://git-wip-us.apache.org/repos/asf/mahout/blob/49ad8cb4/core/src/test/java/org/apache/mahout/math/WeightedVectorTest.java ---------------------------------------------------------------------- diff --git a/core/src/test/java/org/apache/mahout/math/WeightedVectorTest.java b/core/src/test/java/org/apache/mahout/math/WeightedVectorTest.java index e430123..a22014b 100644 --- a/core/src/test/java/org/apache/mahout/math/WeightedVectorTest.java +++ b/core/src/test/java/org/apache/mahout/math/WeightedVectorTest.java @@ -19,7 +19,7 @@ package org.apache.mahout.math; import org.apache.mahout.math.function.Functions; import org.junit.Test; -import static org.junit.Assert.*; + public class WeightedVectorTest extends AbstractVectorTest { @Test http://git-wip-us.apache.org/repos/asf/mahout/blob/49ad8cb4/core/src/test/java/org/apache/mahout/math/jet/random/ExponentialTest.java ---------------------------------------------------------------------- diff --git a/core/src/test/java/org/apache/mahout/math/jet/random/ExponentialTest.java b/core/src/test/java/org/apache/mahout/math/jet/random/ExponentialTest.java index 2804843..7d1e8c8 100644 --- a/core/src/test/java/org/apache/mahout/math/jet/random/ExponentialTest.java +++ b/core/src/test/java/org/apache/mahout/math/jet/random/ExponentialTest.java @@ -22,7 +22,6 @@ import org.apache.mahout.math.MahoutTestCase; import org.junit.Test; import java.util.Arrays; -import static org.junit.Assert.*; public final class ExponentialTest extends MahoutTestCase { http://git-wip-us.apache.org/repos/asf/mahout/blob/49ad8cb4/core/src/test/java/org/apache/mahout/math/jet/random/GammaTest.java ---------------------------------------------------------------------- diff --git a/core/src/test/java/org/apache/mahout/math/jet/random/GammaTest.java b/core/src/test/java/org/apache/mahout/math/jet/random/GammaTest.java index 6ef9ff4..55e4d3a 100644 --- a/core/src/test/java/org/apache/mahout/math/jet/random/GammaTest.java +++ b/core/src/test/java/org/apache/mahout/math/jet/random/GammaTest.java @@ -24,7 +24,6 @@ import org.junit.Test; import java.util.Arrays; import java.util.Locale; import java.util.Random; -import static org.junit.Assert.*; public final class GammaTest extends MahoutTestCase { http://git-wip-us.apache.org/repos/asf/mahout/blob/49ad8cb4/core/src/test/java/org/apache/mahout/math/jet/random/NegativeBinomialTest.java ---------------------------------------------------------------------- diff --git a/core/src/test/java/org/apache/mahout/math/jet/random/NegativeBinomialTest.java b/core/src/test/java/org/apache/mahout/math/jet/random/NegativeBinomialTest.java index 85aaa78..66d0029 100644 --- a/core/src/test/java/org/apache/mahout/math/jet/random/NegativeBinomialTest.java +++ b/core/src/test/java/org/apache/mahout/math/jet/random/NegativeBinomialTest.java @@ -28,7 +28,6 @@ import org.apache.mahout.math.MahoutTestCase; import org.junit.Test; import java.io.InputStreamReader; -import static org.junit.Assert.*; public final class NegativeBinomialTest extends MahoutTestCase { http://git-wip-us.apache.org/repos/asf/mahout/blob/49ad8cb4/core/src/test/java/org/apache/mahout/math/jet/random/NormalTest.java ---------------------------------------------------------------------- diff --git a/core/src/test/java/org/apache/mahout/math/jet/random/NormalTest.java b/core/src/test/java/org/apache/mahout/math/jet/random/NormalTest.java index a18bb2f..d61a799 100644 --- a/core/src/test/java/org/apache/mahout/math/jet/random/NormalTest.java +++ b/core/src/test/java/org/apache/mahout/math/jet/random/NormalTest.java @@ -22,7 +22,6 @@ import org.apache.mahout.math.MahoutTestCase; import org.junit.Test; import java.util.Random; -import static org.junit.Assert.*; public final class NormalTest extends MahoutTestCase { http://git-wip-us.apache.org/repos/asf/mahout/blob/49ad8cb4/core/src/test/java/org/apache/mahout/math/jet/random/engine/MersenneTwisterTest.java ---------------------------------------------------------------------- diff --git a/core/src/test/java/org/apache/mahout/math/jet/random/engine/MersenneTwisterTest.java b/core/src/test/java/org/apache/mahout/math/jet/random/engine/MersenneTwisterTest.java index ff58087..497cb90 100644 --- a/core/src/test/java/org/apache/mahout/math/jet/random/engine/MersenneTwisterTest.java +++ b/core/src/test/java/org/apache/mahout/math/jet/random/engine/MersenneTwisterTest.java @@ -21,7 +21,6 @@ import org.apache.mahout.math.MahoutTestCase; import org.junit.Test; import java.util.Date; -import static org.junit.Assert.*; /** * Tests the Mersenne Twister against the reference implementation 991029/mt19937-2.c which can be http://git-wip-us.apache.org/repos/asf/mahout/blob/49ad8cb4/core/src/test/java/org/apache/mahout/math/list/ObjectArrayListTest.java ---------------------------------------------------------------------- diff --git a/core/src/test/java/org/apache/mahout/math/list/ObjectArrayListTest.java b/core/src/test/java/org/apache/mahout/math/list/ObjectArrayListTest.java index 3d8e622..5709c6e 100644 --- a/core/src/test/java/org/apache/mahout/math/list/ObjectArrayListTest.java +++ b/core/src/test/java/org/apache/mahout/math/list/ObjectArrayListTest.java @@ -19,7 +19,6 @@ package org.apache.mahout.math.list; import org.apache.mahout.math.MahoutTestCase; import org.junit.Test; -import static org.junit.Assert.*; /** tests for {@link ObjectArrayList}*/ public class ObjectArrayListTest extends MahoutTestCase { http://git-wip-us.apache.org/repos/asf/mahout/blob/49ad8cb4/core/src/test/java/org/apache/mahout/math/random/EmpiricalTest.java ---------------------------------------------------------------------- diff --git a/core/src/test/java/org/apache/mahout/math/random/EmpiricalTest.java b/core/src/test/java/org/apache/mahout/math/random/EmpiricalTest.java index d242272..b155f04 100644 --- a/core/src/test/java/org/apache/mahout/math/random/EmpiricalTest.java +++ b/core/src/test/java/org/apache/mahout/math/random/EmpiricalTest.java @@ -1,78 +1,78 @@ -///* -// * Licensed to the Apache Software Foundation (ASF) under one or more -// * contributor license agreements. See the NOTICE file distributed with -// * this work for additional information regarding copyright ownership. -// * The ASF licenses this file to You under the Apache License, Version 2.0 -// * (the "License"); you may not use this file except in compliance with -// * the License. You may obtain a copy of the License at -// * -// * http://www.apache.org/licenses/LICENSE-2.0 -// * -// * Unless required by applicable law or agreed to in writing, software -// * distributed under the License is distributed on an "AS IS" BASIS, -// * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// * See the License for the specific language governing permissions and -// * limitations under the License. -// */ -// -//package org.apache.mahout.math.random; -// -//import com.google.common.collect.Lists; -//import org.apache.mahout.common.RandomUtils; -//import org.apache.mahout.math.MahoutTestCase; -//import org.junit.Assert; -//import org.junit.Test; -// -//import java.util.Collections; -//import java.util.List; -// -//public class EmpiricalTest extends MahoutTestCase { -// @Test -// public void testSimpleDist() { -// RandomUtils.useTestSeed(); -// -// Empirical z = new Empirical(true, true, 3, 0, 1, 0.5, 2, 1, 3.0); -// List<Double> r = Lists.newArrayList(); -// for (int i = 0; i < 10001; i++) { -// r.add(z.sample()); -// } -// Collections.sort(r); -// assertEquals(2.0, r.get(5000), 0.15); -// } -// -// @Test -// public void testZeros() { -// Empirical z = new Empirical(true, true, 3, 0, 1, 0.5, 2, 1, 3.0); -// assertEquals(-16.52, z.sample(0), 1.0e-2); -// assertEquals(20.47, z.sample(1), 1.0e-2); -// } -// -// @Test -// public void testBadArguments() { -// try { -// new Empirical(true, false, 20, 0, 1, 0.5, 2, 0.9, 9, 0.99, 10.0); -// Assert.fail("Should have caught that"); -// } catch (IllegalArgumentException e) { -// } -// try { -// new Empirical(false, true, 20, 0.1, 1, 0.5, 2, 0.9, 9, 1, 10.0); -// Assert.fail("Should have caught that"); -// } catch (IllegalArgumentException e) { -// } -// try { -// new Empirical(true, true, 20, -0.1, 1, 0.5, 2, 0.9, 9, 1, 10.0); -// Assert.fail("Should have caught that"); -// } catch (IllegalArgumentException e) { -// } -// try { -// new Empirical(true, true, 20, 0, 1, 0.5, 2, 0.9, 9, 1.2, 10.0); -// Assert.fail("Should have caught that"); -// } catch (IllegalArgumentException e) { -// } -// try { -// new Empirical(true, true, 20, 0, 1, 0.5, 2, 0.4, 9, 1, 10.0); -// Assert.fail("Should have caught that"); -// } catch (IllegalArgumentException e) { -// } -// } -//} +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.math.random; + +import com.google.common.collect.Lists; +import org.apache.mahout.common.RandomUtils; +import org.apache.mahout.math.MahoutTestCase; +import org.junit.Assert; +import org.junit.Test; + +import java.util.Collections; +import java.util.List; + +public class EmpiricalTest extends MahoutTestCase { + @Test + public void testSimpleDist() { + RandomUtils.useTestSeed(); + + Empirical z = new Empirical(true, true, 3, 0, 1, 0.5, 2, 1, 3.0); + List<Double> r = Lists.newArrayList(); + for (int i = 0; i < 10001; i++) { + r.add(z.sample()); + } + Collections.sort(r); + assertEquals(2.0, r.get(5000), 0.15); + } + + @Test + public void testZeros() { + Empirical z = new Empirical(true, true, 3, 0, 1, 0.5, 2, 1, 3.0); + assertEquals(-16.52, z.sample(0), 1.0e-2); + assertEquals(20.47, z.sample(1), 1.0e-2); + } + + @Test + public void testBadArguments() { + try { + new Empirical(true, false, 20, 0, 1, 0.5, 2, 0.9, 9, 0.99, 10.0); + Assert.fail("Should have caught that"); + } catch (IllegalArgumentException e) { + } + try { + new Empirical(false, true, 20, 0.1, 1, 0.5, 2, 0.9, 9, 1, 10.0); + Assert.fail("Should have caught that"); + } catch (IllegalArgumentException e) { + } + try { + new Empirical(true, true, 20, -0.1, 1, 0.5, 2, 0.9, 9, 1, 10.0); + Assert.fail("Should have caught that"); + } catch (IllegalArgumentException e) { + } + try { + new Empirical(true, true, 20, 0, 1, 0.5, 2, 0.9, 9, 1.2, 10.0); + Assert.fail("Should have caught that"); + } catch (IllegalArgumentException e) { + } + try { + new Empirical(true, true, 20, 0, 1, 0.5, 2, 0.4, 9, 1, 10.0); + Assert.fail("Should have caught that"); + } catch (IllegalArgumentException e) { + } + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/49ad8cb4/core/src/test/java/org/apache/mahout/math/random/IndianBuffetTest.java ---------------------------------------------------------------------- diff --git a/core/src/test/java/org/apache/mahout/math/random/IndianBuffetTest.java b/core/src/test/java/org/apache/mahout/math/random/IndianBuffetTest.java index f89981e..6b349c7 100644 --- a/core/src/test/java/org/apache/mahout/math/random/IndianBuffetTest.java +++ b/core/src/test/java/org/apache/mahout/math/random/IndianBuffetTest.java @@ -1,43 +1,43 @@ -///* -// * Licensed to the Apache Software Foundation (ASF) under one or more -// * contributor license agreements. See the NOTICE file distributed with -// * this work for additional information regarding copyright ownership. -// * The ASF licenses this file to You under the Apache License, Version 2.0 -// * (the "License"); you may not use this file except in compliance with -// * the License. You may obtain a copy of the License at -// * -// * http://www.apache.org/licenses/LICENSE-2.0 -// * -// * Unless required by applicable law or agreed to in writing, software -// * distributed under the License is distributed on an "AS IS" BASIS, -// * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// * See the License for the specific language governing permissions and -// * limitations under the License. -// */ -// -//package org.apache.mahout.math.random; -// -//import com.google.common.collect.HashMultiset; -//import com.google.common.collect.Multiset; -//import org.apache.mahout.common.RandomUtils; -//import org.junit.Test; -// -//import java.util.List; -// -//public class IndianBuffetTest { -// @Test -// public void testBasicText() { -// RandomUtils.useTestSeed(); -// IndianBuffet<String> sampler = IndianBuffet.createTextDocumentSampler(30); -// Multiset<String> counts = HashMultiset.create(); -// int[] lengths = new int[100]; -// for (int i = 0; i < 30; i++) { -// final List<String> doc = sampler.sample(); -// lengths[doc.size()]++; -// for (String w : doc) { -// counts.add(w); -// } -// System.out.printf("%s\n", doc); -// } -// } -//} +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.math.random; + +import com.google.common.collect.HashMultiset; +import com.google.common.collect.Multiset; +import org.apache.mahout.common.RandomUtils; +import org.junit.Test; + +import java.util.List; + +public class IndianBuffetTest { + @Test + public void testBasicText() { + RandomUtils.useTestSeed(); + IndianBuffet<String> sampler = IndianBuffet.createTextDocumentSampler(30); + Multiset<String> counts = HashMultiset.create(); + int[] lengths = new int[100]; + for (int i = 0; i < 30; i++) { + final List<String> doc = sampler.sample(); + lengths[doc.size()]++; + for (String w : doc) { + counts.add(w); + } + System.out.printf("%s\n", doc); + } + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/49ad8cb4/core/src/test/java/org/apache/mahout/math/random/MultiNormalTest.java ---------------------------------------------------------------------- diff --git a/core/src/test/java/org/apache/mahout/math/random/MultiNormalTest.java b/core/src/test/java/org/apache/mahout/math/random/MultiNormalTest.java index 2d20811..1d41dce 100644 --- a/core/src/test/java/org/apache/mahout/math/random/MultiNormalTest.java +++ b/core/src/test/java/org/apache/mahout/math/random/MultiNormalTest.java @@ -1,81 +1,81 @@ -///* -// * Licensed to the Apache Software Foundation (ASF) under one or more -// * contributor license agreements. See the NOTICE file distributed with -// * this work for additional information regarding copyright ownership. -// * The ASF licenses this file to You under the Apache License, Version 2.0 -// * (the "License"); you may not use this file except in compliance with -// * the License. You may obtain a copy of the License at -// * -// * http://www.apache.org/licenses/LICENSE-2.0 -// * -// * Unless required by applicable law or agreed to in writing, software -// * distributed under the License is distributed on an "AS IS" BASIS, -// * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// * See the License for the specific language governing permissions and -// * limitations under the License. -// */ -// -//package org.apache.mahout.math.random; -// -//import org.apache.mahout.common.RandomUtils; -//import org.apache.mahout.math.DenseVector; -//import org.apache.mahout.math.MahoutTestCase; -//import org.apache.mahout.math.Vector; -//import org.apache.mahout.math.stats.OnlineSummarizer; -//import org.junit.Before; -//import org.junit.Test; -// -//public class MultiNormalTest extends MahoutTestCase { -// @Override -// @Before -// public void setUp() { -// RandomUtils.useTestSeed(); -// } -// -// @Test -// public void testDiagonal() { -// DenseVector offset = new DenseVector(new double[]{6, 3, 0}); -// MultiNormal n = new MultiNormal( -// new DenseVector(new double[]{1, 2, 5}), offset); -// -// OnlineSummarizer[] s = { -// new OnlineSummarizer(), -// new OnlineSummarizer(), -// new OnlineSummarizer() -// }; -// -// OnlineSummarizer[] cross = { -// new OnlineSummarizer(), -// new OnlineSummarizer(), -// new OnlineSummarizer() -// }; -// -// for (int i = 0; i < 10000; i++) { -// Vector v = n.sample(); -// for (int j = 0; j < 3; j++) { -// s[j].add(v.get(j) - offset.get(j)); -// int k1 = j % 2; -// int k2 = (j + 1) / 2 + 1; -// cross[j].add((v.get(k1) - offset.get(k1)) * (v.get(k2) - offset.get(k2))); -// } -// } -// -// for (int j = 0; j < 3; j++) { -// assertEquals(0, s[j].getMean() / s[j].getSD(), 0.04); -// assertEquals(0, cross[j].getMean() / cross[j].getSD(), 0.04); -// } -// } -// -// -// @Test -// public void testRadius() { -// MultiNormal gen = new MultiNormal(0.1, new DenseVector(10)); -// OnlineSummarizer s = new OnlineSummarizer(); -// for (int i = 0; i < 10000; i++) { -// double x = gen.sample().norm(2) / Math.sqrt(10); -// s.add(x); -// } -// assertEquals(0.1, s.getMean(), 0.01); -// -// } -//} +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.math.random; + +import org.apache.mahout.common.RandomUtils; +import org.apache.mahout.math.DenseVector; +import org.apache.mahout.math.MahoutTestCase; +import org.apache.mahout.math.Vector; +import org.apache.mahout.math.stats.OnlineSummarizer; +import org.junit.Before; +import org.junit.Test; + +public class MultiNormalTest extends MahoutTestCase { + @Override + @Before + public void setUp() { + RandomUtils.useTestSeed(); + } + + @Test + public void testDiagonal() { + DenseVector offset = new DenseVector(new double[]{6, 3, 0}); + MultiNormal n = new MultiNormal( + new DenseVector(new double[]{1, 2, 5}), offset); + + OnlineSummarizer[] s = { + new OnlineSummarizer(), + new OnlineSummarizer(), + new OnlineSummarizer() + }; + + OnlineSummarizer[] cross = { + new OnlineSummarizer(), + new OnlineSummarizer(), + new OnlineSummarizer() + }; + + for (int i = 0; i < 10000; i++) { + Vector v = n.sample(); + for (int j = 0; j < 3; j++) { + s[j].add(v.get(j) - offset.get(j)); + int k1 = j % 2; + int k2 = (j + 1) / 2 + 1; + cross[j].add((v.get(k1) - offset.get(k1)) * (v.get(k2) - offset.get(k2))); + } + } + + for (int j = 0; j < 3; j++) { + assertEquals(0, s[j].getMean() / s[j].getSD(), 0.04); + assertEquals(0, cross[j].getMean() / cross[j].getSD(), 0.04); + } + } + + + @Test + public void testRadius() { + MultiNormal gen = new MultiNormal(0.1, new DenseVector(10)); + OnlineSummarizer s = new OnlineSummarizer(); + for (int i = 0; i < 10000; i++) { + double x = gen.sample().norm(2) / Math.sqrt(10); + s.add(x); + } + assertEquals(0.1, s.getMean(), 0.01); + + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/49ad8cb4/core/src/test/java/org/apache/mahout/math/random/MultinomialTest.java ---------------------------------------------------------------------- diff --git a/core/src/test/java/org/apache/mahout/math/random/MultinomialTest.java b/core/src/test/java/org/apache/mahout/math/random/MultinomialTest.java index b0f84e8..fe4543f 100644 --- a/core/src/test/java/org/apache/mahout/math/random/MultinomialTest.java +++ b/core/src/test/java/org/apache/mahout/math/random/MultinomialTest.java @@ -1,269 +1,269 @@ -///* -// * Licensed to the Apache Software Foundation (ASF) under one or more -// * contributor license agreements. See the NOTICE file distributed with -// * this work for additional information regarding copyright ownership. -// * The ASF licenses this file to You under the Apache License, Version 2.0 -// * (the "License"); you may not use this file except in compliance with -// * the License. You may obtain a copy of the License at -// * -// * http://www.apache.org/licenses/LICENSE-2.0 -// * -// * Unless required by applicable law or agreed to in writing, software -// * distributed under the License is distributed on an "AS IS" BASIS, -// * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// * See the License for the specific language governing permissions and -// * limitations under the License. -// */ -// -//package org.apache.mahout.math.random; -// -//import java.util.List; -//import java.util.Map; -//import java.util.Random; -// -//import com.google.common.collect.HashMultiset; -//import com.google.common.collect.ImmutableMap; -//import com.google.common.collect.Lists; -//import com.google.common.collect.Multiset; -//import org.apache.mahout.common.RandomUtils; -//import org.apache.mahout.math.MahoutTestCase; -//import org.junit.Before; -//import org.junit.Test; -// -//public class MultinomialTest extends MahoutTestCase { -// @Override -// @Before -// public void setUp() { -// RandomUtils.useTestSeed(); -// } -// -// @Test(expected = IllegalArgumentException.class) -// public void testNoValues() { -// Multiset<String> emptySet = HashMultiset.create(); -// new Multinomial<>(emptySet); -// } -// -// @Test -// public void testSingleton() { -// Multiset<String> oneThing = HashMultiset.create(); -// oneThing.add("one"); -// Multinomial<String> s = new Multinomial<>(oneThing); -// assertEquals("one", s.sample(0)); -// assertEquals("one", s.sample(0.1)); -// assertEquals("one", s.sample(1)); -// } -// -// @Test -// public void testEvenSplit() { -// Multiset<String> stuff = HashMultiset.create(); -// for (int i = 0; i < 5; i++) { -// stuff.add(String.valueOf(i)); -// } -// Multinomial<String> s = new Multinomial<>(stuff); -// double EPSILON = 1.0e-15; -// -// Multiset<String> cnt = HashMultiset.create(); -// for (int i = 0; i < 5; i++) { -// cnt.add(s.sample(i * 0.2)); -// cnt.add(s.sample(i * 0.2 + EPSILON)); -// cnt.add(s.sample((i + 1) * 0.2 - EPSILON)); -// } -// -// assertEquals(5, cnt.elementSet().size()); -// for (String v : cnt.elementSet()) { -// assertEquals(3, cnt.count(v), 1.01); -// } -// assertTrue(cnt.contains(s.sample(1))); -// assertEquals(s.sample(1 - EPSILON), s.sample(1)); -// } -// -// @Test -// public void testPrime() { -// List<String> data = Lists.newArrayList(); -// for (int i = 0; i < 17; i++) { -// String s = "0"; -// if ((i & 1) != 0) { -// s = "1"; -// } -// if ((i & 2) != 0) { -// s = "2"; -// } -// if ((i & 4) != 0) { -// s = "3"; -// } -// if ((i & 8) != 0) { -// s = "4"; -// } -// data.add(s); -// } -// -// Multiset<String> stuff = HashMultiset.create(); -// -// for (String x : data) { -// stuff.add(x); -// } -// -// Multinomial<String> s0 = new Multinomial<>(stuff); -// Multinomial<String> s1 = new Multinomial<>(stuff); -// Multinomial<String> s2 = new Multinomial<>(stuff); -// double EPSILON = 1.0e-15; -// -// Multiset<String> cnt = HashMultiset.create(); -// for (int i = 0; i < 50; i++) { -// double p0 = i * 0.02; -// double p1 = (i + 1) * 0.02; -// cnt.add(s0.sample(p0)); -// cnt.add(s0.sample(p0 + EPSILON)); -// cnt.add(s0.sample(p1 - EPSILON)); -// -// assertEquals(s0.sample(p0), s1.sample(p0)); -// assertEquals(s0.sample(p0 + EPSILON), s1.sample(p0 + EPSILON)); -// assertEquals(s0.sample(p1 - EPSILON), s1.sample(p1 - EPSILON)); -// -// assertEquals(s0.sample(p0), s2.sample(p0)); -// assertEquals(s0.sample(p0 + EPSILON), s2.sample(p0 + EPSILON)); -// assertEquals(s0.sample(p1 - EPSILON), s2.sample(p1 - EPSILON)); -// } -// -// assertEquals(s0.sample(0), s1.sample(0)); -// assertEquals(s0.sample(0 + EPSILON), s1.sample(0 + EPSILON)); -// assertEquals(s0.sample(1 - EPSILON), s1.sample(1 - EPSILON)); -// assertEquals(s0.sample(1), s1.sample(1)); -// -// assertEquals(s0.sample(0), s2.sample(0)); -// assertEquals(s0.sample(0 + EPSILON), s2.sample(0 + EPSILON)); -// assertEquals(s0.sample(1 - EPSILON), s2.sample(1 - EPSILON)); -// assertEquals(s0.sample(1), s2.sample(1)); -// -// assertEquals(5, cnt.elementSet().size()); -// // regression test, really. These values depend on the original seed and exact algorithm. -// // the actual values should be within about 2 of these, however, almost regardless of seed -// Map<String, Integer> ref = ImmutableMap.of("3", 35, "2", 18, "1", 9, "0", 16, "4", 72); -// for (String v : cnt.elementSet()) { -// assertTrue(Math.abs(ref.get(v) - cnt.count(v)) <= 2); -// } -// -// assertTrue(cnt.contains(s0.sample(1))); -// assertEquals(s0.sample(1 - EPSILON), s0.sample(1)); -// } -// -// @Test -// public void testInsert() { -// Random rand = RandomUtils.getRandom(); -// Multinomial<Integer> table = new Multinomial<>(); -// -// double[] p = new double[10]; -// for (int i = 0; i < 10; i++) { -// p[i] = rand.nextDouble(); -// table.add(i, p[i]); -// } -// -// checkSelfConsistent(table); -// -// for (int i = 0; i < 10; i++) { -// assertEquals(p[i], table.getWeight(i), 0); -// } -// } -// -// @Test -// public void testSetZeroWhileIterating() { -// Multinomial<Integer> table = new Multinomial<>(); -// for (int i = 0; i < 10000; ++i) { -// table.add(i, i); -// } -// // Setting a sample's weight to 0 removes from the items map. -// // If that map is used when iterating (it used to be), it will -// // trigger a ConcurrentModificationException. -// for (Integer sample : table) { -// table.set(sample, 0); -// } -// } -// -// @Test(expected=NullPointerException.class) -// public void testNoNullValuesAllowed() { -// Multinomial<Integer> table = new Multinomial<>(); -// // No null values should be allowed. -// table.add(null, 1); -// } -// -// @Test -// public void testDeleteAndUpdate() { -// Random rand = RandomUtils.getRandom(); -// Multinomial<Integer> table = new Multinomial<>(); -// assertEquals(0, table.getWeight(), 1.0e-9); -// -// double total = 0; -// double[] p = new double[10]; -// for (int i = 0; i < 10; i++) { -// p[i] = rand.nextDouble(); -// table.add(i, p[i]); -// total += p[i]; -// assertEquals(total, table.getWeight(), 1.0e-9); -// } -// -// assertEquals(total, table.getWeight(), 1.0e-9); -// -// checkSelfConsistent(table); -// -// double delta = p[7] + p[8]; -// table.delete(7); -// p[7] = 0; -// -// table.set(8, 0); -// p[8] = 0; -// total -= delta; -// -// checkSelfConsistent(table); -// -// assertEquals(total, table.getWeight(), 1.0e-9); -// for (int i = 0; i < 10; i++) { -// assertEquals(p[i], table.getWeight(i), 0); -// assertEquals(p[i] / total, table.getProbability(i), 1.0e-10); -// } -// -// table.set(9, 5.1); -// total -= p[9]; -// p[9] = 5.1; -// total += 5.1; -// -// assertEquals(total , table.getWeight(), 1.0e-9); -// for (int i = 0; i < 10; i++) { -// assertEquals(p[i], table.getWeight(i), 0); -// assertEquals(p[i] / total, table.getProbability(i), 1.0e-10); -// } -// -// checkSelfConsistent(table); -// -// for (int i = 0; i < 10; i++) { -// assertEquals(p[i], table.getWeight(i), 0); -// } -// } -// -// private static void checkSelfConsistent(Multinomial<Integer> table) { -// List<Double> weights = table.getWeights(); -// -// double totalWeight = table.getWeight(); -// -// double p = 0; -// int[] k = new int[weights.size()]; -// for (double weight : weights) { -// if (weight > 0) { -// if (p > 0) { -// k[table.sample(p - 1.0e-9)]++; -// } -// k[table.sample(p + 1.0e-9)]++; -// } -// p += weight / totalWeight; -// } -// k[table.sample(p - 1.0e-9)]++; -// assertEquals(1, p, 1.0e-9); -// -// for (int i = 0; i < weights.size(); i++) { -// if (table.getWeight(i) > 0) { -// assertEquals(2, k[i]); -// } else { -// assertEquals(0, k[i]); -// } -// } -// } -//} +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.math.random; + +import java.util.List; +import java.util.Map; +import java.util.Random; + +import com.google.common.collect.HashMultiset; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.Lists; +import com.google.common.collect.Multiset; +import org.apache.mahout.common.RandomUtils; +import org.apache.mahout.math.MahoutTestCase; +import org.junit.Before; +import org.junit.Test; + +public class MultinomialTest extends MahoutTestCase { + @Override + @Before + public void setUp() { + RandomUtils.useTestSeed(); + } + + @Test(expected = IllegalArgumentException.class) + public void testNoValues() { + Multiset<String> emptySet = HashMultiset.create(); + new Multinomial<>(emptySet); + } + + @Test + public void testSingleton() { + Multiset<String> oneThing = HashMultiset.create(); + oneThing.add("one"); + Multinomial<String> s = new Multinomial<>(oneThing); + assertEquals("one", s.sample(0)); + assertEquals("one", s.sample(0.1)); + assertEquals("one", s.sample(1)); + } + + @Test + public void testEvenSplit() { + Multiset<String> stuff = HashMultiset.create(); + for (int i = 0; i < 5; i++) { + stuff.add(String.valueOf(i)); + } + Multinomial<String> s = new Multinomial<>(stuff); + double EPSILON = 1.0e-15; + + Multiset<String> cnt = HashMultiset.create(); + for (int i = 0; i < 5; i++) { + cnt.add(s.sample(i * 0.2)); + cnt.add(s.sample(i * 0.2 + EPSILON)); + cnt.add(s.sample((i + 1) * 0.2 - EPSILON)); + } + + assertEquals(5, cnt.elementSet().size()); + for (String v : cnt.elementSet()) { + assertEquals(3, cnt.count(v), 1.01); + } + assertTrue(cnt.contains(s.sample(1))); + assertEquals(s.sample(1 - EPSILON), s.sample(1)); + } + + @Test + public void testPrime() { + List<String> data = Lists.newArrayList(); + for (int i = 0; i < 17; i++) { + String s = "0"; + if ((i & 1) != 0) { + s = "1"; + } + if ((i & 2) != 0) { + s = "2"; + } + if ((i & 4) != 0) { + s = "3"; + } + if ((i & 8) != 0) { + s = "4"; + } + data.add(s); + } + + Multiset<String> stuff = HashMultiset.create(); + + for (String x : data) { + stuff.add(x); + } + + Multinomial<String> s0 = new Multinomial<>(stuff); + Multinomial<String> s1 = new Multinomial<>(stuff); + Multinomial<String> s2 = new Multinomial<>(stuff); + double EPSILON = 1.0e-15; + + Multiset<String> cnt = HashMultiset.create(); + for (int i = 0; i < 50; i++) { + double p0 = i * 0.02; + double p1 = (i + 1) * 0.02; + cnt.add(s0.sample(p0)); + cnt.add(s0.sample(p0 + EPSILON)); + cnt.add(s0.sample(p1 - EPSILON)); + + assertEquals(s0.sample(p0), s1.sample(p0)); + assertEquals(s0.sample(p0 + EPSILON), s1.sample(p0 + EPSILON)); + assertEquals(s0.sample(p1 - EPSILON), s1.sample(p1 - EPSILON)); + + assertEquals(s0.sample(p0), s2.sample(p0)); + assertEquals(s0.sample(p0 + EPSILON), s2.sample(p0 + EPSILON)); + assertEquals(s0.sample(p1 - EPSILON), s2.sample(p1 - EPSILON)); + } + + assertEquals(s0.sample(0), s1.sample(0)); + assertEquals(s0.sample(0 + EPSILON), s1.sample(0 + EPSILON)); + assertEquals(s0.sample(1 - EPSILON), s1.sample(1 - EPSILON)); + assertEquals(s0.sample(1), s1.sample(1)); + + assertEquals(s0.sample(0), s2.sample(0)); + assertEquals(s0.sample(0 + EPSILON), s2.sample(0 + EPSILON)); + assertEquals(s0.sample(1 - EPSILON), s2.sample(1 - EPSILON)); + assertEquals(s0.sample(1), s2.sample(1)); + + assertEquals(5, cnt.elementSet().size()); + // regression test, really. These values depend on the original seed and exact algorithm. + // the actual values should be within about 2 of these, however, almost regardless of seed + Map<String, Integer> ref = ImmutableMap.of("3", 35, "2", 18, "1", 9, "0", 16, "4", 72); + for (String v : cnt.elementSet()) { + assertTrue(Math.abs(ref.get(v) - cnt.count(v)) <= 2); + } + + assertTrue(cnt.contains(s0.sample(1))); + assertEquals(s0.sample(1 - EPSILON), s0.sample(1)); + } + + @Test + public void testInsert() { + Random rand = RandomUtils.getRandom(); + Multinomial<Integer> table = new Multinomial<>(); + + double[] p = new double[10]; + for (int i = 0; i < 10; i++) { + p[i] = rand.nextDouble(); + table.add(i, p[i]); + } + + checkSelfConsistent(table); + + for (int i = 0; i < 10; i++) { + assertEquals(p[i], table.getWeight(i), 0); + } + } + + @Test + public void testSetZeroWhileIterating() { + Multinomial<Integer> table = new Multinomial<>(); + for (int i = 0; i < 10000; ++i) { + table.add(i, i); + } + // Setting a sample's weight to 0 removes from the items map. + // If that map is used when iterating (it used to be), it will + // trigger a ConcurrentModificationException. + for (Integer sample : table) { + table.set(sample, 0); + } + } + + @Test(expected=NullPointerException.class) + public void testNoNullValuesAllowed() { + Multinomial<Integer> table = new Multinomial<>(); + // No null values should be allowed. + table.add(null, 1); + } + + @Test + public void testDeleteAndUpdate() { + Random rand = RandomUtils.getRandom(); + Multinomial<Integer> table = new Multinomial<>(); + assertEquals(0, table.getWeight(), 1.0e-9); + + double total = 0; + double[] p = new double[10]; + for (int i = 0; i < 10; i++) { + p[i] = rand.nextDouble(); + table.add(i, p[i]); + total += p[i]; + assertEquals(total, table.getWeight(), 1.0e-9); + } + + assertEquals(total, table.getWeight(), 1.0e-9); + + checkSelfConsistent(table); + + double delta = p[7] + p[8]; + table.delete(7); + p[7] = 0; + + table.set(8, 0); + p[8] = 0; + total -= delta; + + checkSelfConsistent(table); + + assertEquals(total, table.getWeight(), 1.0e-9); + for (int i = 0; i < 10; i++) { + assertEquals(p[i], table.getWeight(i), 0); + assertEquals(p[i] / total, table.getProbability(i), 1.0e-10); + } + + table.set(9, 5.1); + total -= p[9]; + p[9] = 5.1; + total += 5.1; + + assertEquals(total , table.getWeight(), 1.0e-9); + for (int i = 0; i < 10; i++) { + assertEquals(p[i], table.getWeight(i), 0); + assertEquals(p[i] / total, table.getProbability(i), 1.0e-10); + } + + checkSelfConsistent(table); + + for (int i = 0; i < 10; i++) { + assertEquals(p[i], table.getWeight(i), 0); + } + } + + private static void checkSelfConsistent(Multinomial<Integer> table) { + List<Double> weights = table.getWeights(); + + double totalWeight = table.getWeight(); + + double p = 0; + int[] k = new int[weights.size()]; + for (double weight : weights) { + if (weight > 0) { + if (p > 0) { + k[table.sample(p - 1.0e-9)]++; + } + k[table.sample(p + 1.0e-9)]++; + } + p += weight / totalWeight; + } + k[table.sample(p - 1.0e-9)]++; + assertEquals(1, p, 1.0e-9); + + for (int i = 0; i < weights.size(); i++) { + if (table.getWeight(i) > 0) { + assertEquals(2, k[i]); + } else { + assertEquals(0, k[i]); + } + } + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/49ad8cb4/core/src/test/java/org/apache/mahout/math/random/NormalTest.java ---------------------------------------------------------------------- diff --git a/core/src/test/java/org/apache/mahout/math/random/NormalTest.java b/core/src/test/java/org/apache/mahout/math/random/NormalTest.java index e96ef53..6263672 100644 --- a/core/src/test/java/org/apache/mahout/math/random/NormalTest.java +++ b/core/src/test/java/org/apache/mahout/math/random/NormalTest.java @@ -1,62 +1,62 @@ -///* -// * Licensed to the Apache Software Foundation (ASF) under one or more -// * contributor license agreements. See the NOTICE file distributed with -// * this work for additional information regarding copyright ownership. -// * The ASF licenses this file to You under the Apache License, Version 2.0 -// * (the "License"); you may not use this file except in compliance with -// * the License. You may obtain a copy of the License at -// * -// * http://www.apache.org/licenses/LICENSE-2.0 -// * -// * Unless required by applicable law or agreed to in writing, software -// * distributed under the License is distributed on an "AS IS" BASIS, -// * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// * See the License for the specific language governing permissions and -// * limitations under the License. -// */ -// -//package org.apache.mahout.math.random; -// -//import org.apache.commons.math3.distribution.NormalDistribution; -//import org.apache.mahout.common.RandomUtils; -//import org.apache.mahout.math.MahoutTestCase; -//import org.apache.mahout.math.stats.OnlineSummarizer; -//import org.junit.Before; -//import org.junit.Test; -// -//import java.util.Arrays; -// -//public final class NormalTest extends MahoutTestCase { -// -// @Override -// @Before -// public void setUp() { -// RandomUtils.useTestSeed(); -// } -// -// @Test -// public void testOffset() { -// OnlineSummarizer s = new OnlineSummarizer(); -// Sampler<Double> sampler = new Normal(2, 5); -// for (int i = 0; i < 10001; i++) { -// s.add(sampler.sample()); -// } -// assertEquals(String.format("m = %.3f, sd = %.3f", s.getMean(), s.getSD()), 2, s.getMean(), 0.04 * s.getSD()); -// assertEquals(5, s.getSD(), 0.12); -// } -// -// @Test -// public void testSample() throws Exception { -// double[] data = new double[10001]; -// Sampler<Double> sampler = new Normal(); -// for (int i = 0; i < data.length; i++) { -// data[i] = sampler.sample(); -// } -// Arrays.sort(data); -// -// NormalDistribution reference = new NormalDistribution(RandomUtils.getRandom().getRandomGenerator(), -// 0, 1, -// NormalDistribution.DEFAULT_INVERSE_ABSOLUTE_ACCURACY); -// assertEquals("Median", reference.inverseCumulativeProbability(0.5), data[5000], 0.04); -// } -//} +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.math.random; + +import org.apache.commons.math3.distribution.NormalDistribution; +import org.apache.mahout.common.RandomUtils; +import org.apache.mahout.math.MahoutTestCase; +import org.apache.mahout.math.stats.OnlineSummarizer; +import org.junit.Before; +import org.junit.Test; + +import java.util.Arrays; + +public final class NormalTest extends MahoutTestCase { + + @Override + @Before + public void setUp() { + RandomUtils.useTestSeed(); + } + + @Test + public void testOffset() { + OnlineSummarizer s = new OnlineSummarizer(); + Sampler<Double> sampler = new Normal(2, 5); + for (int i = 0; i < 10001; i++) { + s.add(sampler.sample()); + } + assertEquals(String.format("m = %.3f, sd = %.3f", s.getMean(), s.getSD()), 2, s.getMean(), 0.04 * s.getSD()); + assertEquals(5, s.getSD(), 0.12); + } + + @Test + public void testSample() throws Exception { + double[] data = new double[10001]; + Sampler<Double> sampler = new Normal(); + for (int i = 0; i < data.length; i++) { + data[i] = sampler.sample(); + } + Arrays.sort(data); + + NormalDistribution reference = new NormalDistribution(RandomUtils.getRandom().getRandomGenerator(), + 0, 1, + NormalDistribution.DEFAULT_INVERSE_ABSOLUTE_ACCURACY); + assertEquals("Median", reference.inverseCumulativeProbability(0.5), data[5000], 0.04); + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/49ad8cb4/core/src/test/java/org/apache/mahout/math/random/PoissonSamplerTest.java ---------------------------------------------------------------------- diff --git a/core/src/test/java/org/apache/mahout/math/random/PoissonSamplerTest.java b/core/src/test/java/org/apache/mahout/math/random/PoissonSamplerTest.java index 2a88529..d4612f7 100644 --- a/core/src/test/java/org/apache/mahout/math/random/PoissonSamplerTest.java +++ b/core/src/test/java/org/apache/mahout/math/random/PoissonSamplerTest.java @@ -1,56 +1,56 @@ -///* -// * Licensed to the Apache Software Foundation (ASF) under one or more -// * contributor license agreements. See the NOTICE file distributed with -// * this work for additional information regarding copyright ownership. -// * The ASF licenses this file to You under the Apache License, Version 2.0 -// * (the "License"); you may not use this file except in compliance with -// * the License. You may obtain a copy of the License at -// * -// * http://www.apache.org/licenses/LICENSE-2.0 -// * -// * Unless required by applicable law or agreed to in writing, software -// * distributed under the License is distributed on an "AS IS" BASIS, -// * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// * See the License for the specific language governing permissions and -// * limitations under the License. -// */ -// -//package org.apache.mahout.math.random; -// -//import org.apache.commons.math3.distribution.IntegerDistribution; -//import org.apache.commons.math3.distribution.PoissonDistribution; -//import org.apache.mahout.common.RandomUtils; -//import org.apache.mahout.math.MahoutTestCase; -//import org.junit.Before; -//import org.junit.Test; -// -//public final class PoissonSamplerTest extends MahoutTestCase { -// -// @Override -// @Before -// public void setUp() { -// RandomUtils.useTestSeed(); -// } -// -// @Test -// public void testBasics() { -// for (double alpha : new double[]{0.1, 1, 10, 100}) { -// checkDistribution(new PoissonSampler(alpha), alpha); -// } -// } -// -// private static void checkDistribution(Sampler<Double> pd, double alpha) { -// int[] count = new int[(int) Math.max(10, 5 * alpha)]; -// for (int i = 0; i < 10000; i++) { -// count[pd.sample().intValue()]++; -// } -// -// IntegerDistribution ref = new PoissonDistribution(RandomUtils.getRandom().getRandomGenerator(), -// alpha, -// PoissonDistribution.DEFAULT_EPSILON, -// PoissonDistribution.DEFAULT_MAX_ITERATIONS); -// for (int i = 0; i < count.length; i++) { -// assertEquals(ref.probability(i), count[i] / 10000.0, 2.0e-2); -// } -// } -//} +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.math.random; + +import org.apache.commons.math3.distribution.IntegerDistribution; +import org.apache.commons.math3.distribution.PoissonDistribution; +import org.apache.mahout.common.RandomUtils; +import org.apache.mahout.math.MahoutTestCase; +import org.junit.Before; +import org.junit.Test; + +public final class PoissonSamplerTest extends MahoutTestCase { + + @Override + @Before + public void setUp() { + RandomUtils.useTestSeed(); + } + + @Test + public void testBasics() { + for (double alpha : new double[]{0.1, 1, 10, 100}) { + checkDistribution(new PoissonSampler(alpha), alpha); + } + } + + private static void checkDistribution(Sampler<Double> pd, double alpha) { + int[] count = new int[(int) Math.max(10, 5 * alpha)]; + for (int i = 0; i < 10000; i++) { + count[pd.sample().intValue()]++; + } + + IntegerDistribution ref = new PoissonDistribution(RandomUtils.getRandom().getRandomGenerator(), + alpha, + PoissonDistribution.DEFAULT_EPSILON, + PoissonDistribution.DEFAULT_MAX_ITERATIONS); + for (int i = 0; i < count.length; i++) { + assertEquals(ref.probability(i), count[i] / 10000.0, 2.0e-2); + } + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/49ad8cb4/engine/hdfs/src/main/java/org/apache/mahout/math/MatrixWritable.java ---------------------------------------------------------------------- diff --git a/engine/hdfs/src/main/java/org/apache/mahout/math/MatrixWritable.java b/engine/hdfs/src/main/java/org/apache/mahout/math/MatrixWritable.java index 04b7415..b8fc461 100644 --- a/engine/hdfs/src/main/java/org/apache/mahout/math/MatrixWritable.java +++ b/engine/hdfs/src/main/java/org/apache/mahout/math/MatrixWritable.java @@ -18,9 +18,8 @@ package org.apache.mahout.math; import com.google.common.base.Preconditions; -import it.unimi.dsi.fastutil.ints.IntArrayList; import org.apache.hadoop.io.Writable; -//import org.apache.mahout.math.list.IntArrayList; +import org.apache.mahout.math.list.IntArrayList; import java.io.DataInput; import java.io.DataOutput; @@ -190,7 +189,7 @@ public class MatrixWritable implements Writable { int numNonZeroRows = rowIndices.size(); out.writeInt(numNonZeroRows); for (int i = 0; i < numNonZeroRows; i++) { - int rowIndex = rowIndices.getInt(i); //.getQuick(i); + int rowIndex = rowIndices.getQuick(i); out.writeInt(rowIndex); VectorWritable.writeVectorContents(out, matrix.viewRow(rowIndex), vectorFlags); }
