Author: yxjiang
Date: Fri Jun 21 14:53:45 2013
New Revision: 1495461
URL: http://svn.apache.org/r1495461
Log:
HAMA-765: Add apply method to Vector/Matrix
Added:
hama/trunk/ml/src/main/java/org/apache/hama/ml/math/CrossEntropy.java
hama/trunk/ml/src/main/java/org/apache/hama/ml/math/DoubleDoubleFunction.java
hama/trunk/ml/src/main/java/org/apache/hama/ml/math/DoubleFunction.java
hama/trunk/ml/src/main/java/org/apache/hama/ml/math/Function.java
hama/trunk/ml/src/main/java/org/apache/hama/ml/math/FunctionFactory.java
hama/trunk/ml/src/main/java/org/apache/hama/ml/math/Sigmoid.java
hama/trunk/ml/src/main/java/org/apache/hama/ml/math/SquaredError.java
hama/trunk/ml/src/main/java/org/apache/hama/ml/math/Tanh.java
hama/trunk/ml/src/test/java/org/apache/hama/ml/math/
hama/trunk/ml/src/test/java/org/apache/hama/ml/math/TestDenseDoubleMatrix.java
hama/trunk/ml/src/test/java/org/apache/hama/ml/math/TestDenseDoubleVector.java
hama/trunk/ml/src/test/java/org/apache/hama/ml/math/TestFunctionFactory.java
Removed:
hama/trunk/ml/src/main/java/org/apache/hama/ml/perception/CostFunction.java
hama/trunk/ml/src/main/java/org/apache/hama/ml/perception/CostFunctionFactory.java
hama/trunk/ml/src/main/java/org/apache/hama/ml/perception/CrossEntropy.java
hama/trunk/ml/src/main/java/org/apache/hama/ml/perception/LogisticCostFunction.java
hama/trunk/ml/src/main/java/org/apache/hama/ml/perception/Sigmoid.java
hama/trunk/ml/src/main/java/org/apache/hama/ml/perception/SquaredError.java
hama/trunk/ml/src/main/java/org/apache/hama/ml/perception/SquashingFunction.java
hama/trunk/ml/src/main/java/org/apache/hama/ml/perception/SquashingFunctionFactory.java
hama/trunk/ml/src/main/java/org/apache/hama/ml/perception/Tanh.java
Modified:
hama/trunk/CHANGES.txt
hama/trunk/ml/src/main/java/org/apache/hama/ml/math/DenseDoubleMatrix.java
hama/trunk/ml/src/main/java/org/apache/hama/ml/math/DenseDoubleVector.java
hama/trunk/ml/src/main/java/org/apache/hama/ml/math/DoubleDoubleVectorFunction.java
hama/trunk/ml/src/main/java/org/apache/hama/ml/math/DoubleMatrix.java
hama/trunk/ml/src/main/java/org/apache/hama/ml/math/DoubleVector.java
hama/trunk/ml/src/main/java/org/apache/hama/ml/math/DoubleVectorFunction.java
hama/trunk/ml/src/main/java/org/apache/hama/ml/perception/MultiLayerPerceptron.java
hama/trunk/ml/src/main/java/org/apache/hama/ml/perception/SmallMultiLayerPerceptron.java
hama/trunk/ml/src/test/java/org/apache/hama/ml/perception/TestSmallMultiLayerPerceptron.java
Modified: hama/trunk/CHANGES.txt
URL:
http://svn.apache.org/viewvc/hama/trunk/CHANGES.txt?rev=1495461&r1=1495460&r2=1495461&view=diff
==============================================================================
--- hama/trunk/CHANGES.txt (original)
+++ hama/trunk/CHANGES.txt Fri Jun 21 14:53:45 2013
@@ -22,6 +22,7 @@ Release 0.7 (unreleased changes)
HAMA-754: PartitioningRunner should write raw records to partition files
(edwardyoon)
HAMA-707: BSPMessageBundle should be able to encapsulate messages
serialized in ByteBuffer (surajsmenon)
HAMA-722: Messaging queue should construct sender and receiver queue
(surajsmenon)
+ HAMA-765: Add apply method to Vector/Matrix (Yexi Jiang)
Release 0.6.1 - April 01, 2013
Added: hama/trunk/ml/src/main/java/org/apache/hama/ml/math/CrossEntropy.java
URL:
http://svn.apache.org/viewvc/hama/trunk/ml/src/main/java/org/apache/hama/ml/math/CrossEntropy.java?rev=1495461&view=auto
==============================================================================
--- hama/trunk/ml/src/main/java/org/apache/hama/ml/math/CrossEntropy.java
(added)
+++ hama/trunk/ml/src/main/java/org/apache/hama/ml/math/CrossEntropy.java Fri
Jun 21 14:53:45 2013
@@ -0,0 +1,53 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.hama.ml.math;
+
+/**
+ * The cross entropy cost function.
+ *
+ * <pre>
+ * cost(t, y) = - t * log(y) - (1 - t) * log(1 - y),
+ * where t denotes the target value, y denotes the estimated value.
+ * </pre>
+ */
+public class CrossEntropy extends DoubleDoubleFunction {
+
+ @Override
+ public double apply(double target, double actual) {
+ return -target * Math.log(actual) - (1 - target) * Math.log(1 - actual);
+ }
+
+ @Override
+ public double applyDerivative(double target, double actual) {
+ double adjustedTarget = target;
+ double adjustedActual = actual;
+ if (adjustedActual == 1) {
+ adjustedActual = 0.999;
+ } else if (actual == 0) {
+ adjustedActual = 0.001;
+ }
+ if (adjustedTarget == 1) {
+ adjustedTarget = 0.999;
+ } else if (adjustedTarget == 0) {
+ adjustedTarget = 0.001;
+ }
+ return -adjustedTarget / adjustedActual + (1 - adjustedTarget)
+ / (1 - adjustedActual);
+ }
+
+}
Modified:
hama/trunk/ml/src/main/java/org/apache/hama/ml/math/DenseDoubleMatrix.java
URL:
http://svn.apache.org/viewvc/hama/trunk/ml/src/main/java/org/apache/hama/ml/math/DenseDoubleMatrix.java?rev=1495461&r1=1495460&r2=1495461&view=diff
==============================================================================
--- hama/trunk/ml/src/main/java/org/apache/hama/ml/math/DenseDoubleMatrix.java
(original)
+++ hama/trunk/ml/src/main/java/org/apache/hama/ml/math/DenseDoubleMatrix.java
Fri Jun 21 14:53:45 2013
@@ -778,4 +778,37 @@ public final class DenseDoubleMatrix imp
return a.subtract(b).sum();
}
+ @Override
+ /**
+ * {@inheritDoc}
+ */
+ public DoubleMatrix applyToElements(DoubleFunction fun) {
+ for (int r = 0; r < this.numRows; ++r) {
+ for (int c = 0; c < this.numColumns; ++c) {
+ this.set(r, c, fun.apply(this.get(r, c)));
+ }
+ }
+ return this;
+ }
+
+ @Override
+ /**
+ * {@inheritDoc}
+ */
+ public DoubleMatrix applyToElements(DoubleMatrix other, DoubleDoubleFunction
fun) {
+ if (this.numRows != other.getRowCount()
+ || this.numColumns != other.getColumnCount()) {
+ throw new IllegalArgumentException(
+ "Cannot apply double double function to matrices with different
sizes.");
+ }
+
+ for (int r = 0; r < this.numRows; ++r) {
+ for (int c = 0; c < this.numColumns; ++c) {
+ this.set(r, c, fun.apply(this.get(r, c), other.get(r, c)));
+ }
+ }
+
+ return this;
+ }
+
}
Modified:
hama/trunk/ml/src/main/java/org/apache/hama/ml/math/DenseDoubleVector.java
URL:
http://svn.apache.org/viewvc/hama/trunk/ml/src/main/java/org/apache/hama/ml/math/DenseDoubleVector.java?rev=1495461&r1=1495460&r2=1495461&view=diff
==============================================================================
--- hama/trunk/ml/src/main/java/org/apache/hama/ml/math/DenseDoubleVector.java
(original)
+++ hama/trunk/ml/src/main/java/org/apache/hama/ml/math/DenseDoubleVector.java
Fri Jun 21 14:53:45 2013
@@ -100,11 +100,34 @@ public final class DenseDoubleVector imp
vector[index] = value;
}
+ /**
+ * {@inheritDoc}
+ */
+ @Override
+ public DoubleVector applyToElements(DoubleFunction func) {
+ for (int i = 0; i < vector.length; i++) {
+ this.vector[i] = func.apply(vector[i]);
+ }
+ return this;
+ }
+
+ /**
+ * {@inheritDoc}
+ */
+ @Override
+ public DoubleVector applyToElements(DoubleVector other, DoubleDoubleFunction
func) {
+ for (int i = 0; i < vector.length; i++) {
+ this.vector[i] = func.apply(vector[i], other.get(i));
+ }
+ return this;
+ }
+
/*
* (non-Javadoc)
* @see de.jungblut.math.DoubleVector#apply(de.jungblut.math.function.
* DoubleVectorFunction)
*/
+ @Deprecated
@Override
public DoubleVector apply(DoubleVectorFunction func) {
DenseDoubleVector newV = new DenseDoubleVector(this.vector);
@@ -119,6 +142,7 @@ public final class DenseDoubleVector imp
* @see de.jungblut.math.DoubleVector#apply(de.jungblut.math.DoubleVector,
* de.jungblut.math.function.DoubleDoubleVectorFunction)
*/
+ @Deprecated
@Override
public DoubleVector apply(DoubleVector other, DoubleDoubleVectorFunction
func) {
DenseDoubleVector newV = (DenseDoubleVector) deepCopy();
Added:
hama/trunk/ml/src/main/java/org/apache/hama/ml/math/DoubleDoubleFunction.java
URL:
http://svn.apache.org/viewvc/hama/trunk/ml/src/main/java/org/apache/hama/ml/math/DoubleDoubleFunction.java?rev=1495461&view=auto
==============================================================================
---
hama/trunk/ml/src/main/java/org/apache/hama/ml/math/DoubleDoubleFunction.java
(added)
+++
hama/trunk/ml/src/main/java/org/apache/hama/ml/math/DoubleDoubleFunction.java
Fri Jun 21 14:53:45 2013
@@ -0,0 +1,45 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.hama.ml.math;
+
+/**
+ * A double double function takes two arguments. A vector or matrix can apply
+ * the double double function to each pair of corresponding elements.
+ *
+ */
+public abstract class DoubleDoubleFunction extends Function {
+
+ /**
+ * Apply the function to the two given arguments.
+ *
+ * @param x1 The first argument.
+ * @param x2 The second argument.
+ * @return The result based on the calculation on two arguments.
+ */
+ public abstract double apply(double x1, double x2);
+
+ /**
+ * Apply the derivative of this function to two given arguments.
+ *
+ * @param x1 The first argument.
+ * @param x2 The second argument.
+ * @return The result based on the calculation on two arguments.
+ */
+ public abstract double applyDerivative(double x1, double x2);
+
+}
Modified:
hama/trunk/ml/src/main/java/org/apache/hama/ml/math/DoubleDoubleVectorFunction.java
URL:
http://svn.apache.org/viewvc/hama/trunk/ml/src/main/java/org/apache/hama/ml/math/DoubleDoubleVectorFunction.java?rev=1495461&r1=1495460&r2=1495461&view=diff
==============================================================================
---
hama/trunk/ml/src/main/java/org/apache/hama/ml/math/DoubleDoubleVectorFunction.java
(original)
+++
hama/trunk/ml/src/main/java/org/apache/hama/ml/math/DoubleDoubleVectorFunction.java
Fri Jun 21 14:53:45 2013
@@ -20,7 +20,10 @@ package org.apache.hama.ml.math;
/**
* A function that can be applied to two double vectors via {@link
DoubleVector}
* #apply({@link DoubleVector} v, {@link DoubleDoubleVectorFunction} f);
+ *
+ * This class will be replaced by {@link DoubleDoubleFunction}
*/
+@Deprecated
public interface DoubleDoubleVectorFunction {
/**
Added: hama/trunk/ml/src/main/java/org/apache/hama/ml/math/DoubleFunction.java
URL:
http://svn.apache.org/viewvc/hama/trunk/ml/src/main/java/org/apache/hama/ml/math/DoubleFunction.java?rev=1495461&view=auto
==============================================================================
--- hama/trunk/ml/src/main/java/org/apache/hama/ml/math/DoubleFunction.java
(added)
+++ hama/trunk/ml/src/main/java/org/apache/hama/ml/math/DoubleFunction.java Fri
Jun 21 14:53:45 2013
@@ -0,0 +1,43 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.hama.ml.math;
+
+/**
+ * A double function takes one argument. A vector or matrix can apply the
+ * double function to each element.
+ *
+ */
+public abstract class DoubleFunction extends Function {
+
+ /**
+ * Apply the function to an element.
+ *
+ * @param value The element that the function applies to.
+ * @return The result after applying the function.
+ */
+ public abstract double apply(double value);
+
+ /**
+ * Apply the derivative of the function.
+ *
+ * @param value The element that the derivative is applied to.
+ * @return The result after applying the derivative.
+ */
+ public abstract double applyDerivative(double value);
+
+}
Modified: hama/trunk/ml/src/main/java/org/apache/hama/ml/math/DoubleMatrix.java
URL:
http://svn.apache.org/viewvc/hama/trunk/ml/src/main/java/org/apache/hama/ml/math/DoubleMatrix.java?rev=1495461&r1=1495460&r2=1495461&view=diff
==============================================================================
--- hama/trunk/ml/src/main/java/org/apache/hama/ml/math/DoubleMatrix.java
(original)
+++ hama/trunk/ml/src/main/java/org/apache/hama/ml/math/DoubleMatrix.java Fri
Jun 21 14:53:45 2013
@@ -184,4 +184,25 @@ public interface DoubleMatrix {
*/
public DoubleMatrix slice(int rowOffset, int rowMax, int colOffset, int
colMax);
+ /**
+ * Apply a double function f(x) onto each element of the matrix. After
+ * applying, each element of the current matrix will be changed from x to
+ * f(x).
+ *
+ * @param fun The function.
+ * @return The matrix itself, to allow chained operations.
+ */
+ public DoubleMatrix applyToElements(DoubleFunction fun);
+
+ /**
+ * Apply a double double function f(x, y) onto each pair of the current
matrix
+ * elements and given matrix. After applying, each element of the current
+ * matrix will be changed from x to f(x, y).
+ *
+ * @param other The matrix contributing the second argument of the function.
+ * @param fun The function that takes two arguments.
+ * @return The matrix itself, to allow chained operations.
+ */
+ public DoubleMatrix applyToElements(DoubleMatrix other, DoubleDoubleFunction
fun);
+
}
Modified: hama/trunk/ml/src/main/java/org/apache/hama/ml/math/DoubleVector.java
URL:
http://svn.apache.org/viewvc/hama/trunk/ml/src/main/java/org/apache/hama/ml/math/DoubleVector.java?rev=1495461&r1=1495460&r2=1495461&view=diff
==============================================================================
--- hama/trunk/ml/src/main/java/org/apache/hama/ml/math/DoubleVector.java
(original)
+++ hama/trunk/ml/src/main/java/org/apache/hama/ml/math/DoubleVector.java Fri
Jun 21 14:53:45 2013
@@ -58,7 +58,7 @@ public interface DoubleVector {
* @param value the value at the index of the vector to set.
*/
public void set(int index, double value);
-
+
/**
* Apply a given {@link DoubleVectorFunction} to this vector and return a new
* one.
@@ -66,8 +66,9 @@ public interface DoubleVector {
* @param func the function to apply.
* @return a new vector with the applied function.
*/
+ @Deprecated
public DoubleVector apply(DoubleVectorFunction func);
-
+
/**
* Apply a given {@link DoubleDoubleVectorFunction} to this vector and the
* other given vector.
@@ -76,9 +77,29 @@ public interface DoubleVector {
* @param func the function to apply on this and the other vector.
* @return a new vector with the result of the function of the two vectors.
*/
+ @Deprecated
public DoubleVector apply(DoubleVector other, DoubleDoubleVectorFunction
func);
/**
+ * Apply a given {@link DoubleFunction} to each element of this vector in
+ * place.
+ *
+ * @param func the function to apply.
+ * @return this vector with the function applied, to allow chaining.
+ */
+ public DoubleVector applyToElements(DoubleFunction func);
+
+ /**
+ * Apply a given {@link DoubleDoubleFunction} to each pair of elements of this
+ * vector and the other given vector, in place.
+ *
+ * @param other the other vector.
+ * @param func the function to apply on this and the other vector.
+ * @return this vector holding the result of the function of the two vectors.
+ */
+ public DoubleVector applyToElements(DoubleVector other, DoubleDoubleFunction
func);
+
+ /**
* Adds the given {@link DoubleVector} to this vector.
*
* @param v the other vector.
Modified:
hama/trunk/ml/src/main/java/org/apache/hama/ml/math/DoubleVectorFunction.java
URL:
http://svn.apache.org/viewvc/hama/trunk/ml/src/main/java/org/apache/hama/ml/math/DoubleVectorFunction.java?rev=1495461&r1=1495460&r2=1495461&view=diff
==============================================================================
---
hama/trunk/ml/src/main/java/org/apache/hama/ml/math/DoubleVectorFunction.java
(original)
+++
hama/trunk/ml/src/main/java/org/apache/hama/ml/math/DoubleVectorFunction.java
Fri Jun 21 14:53:45 2013
@@ -20,7 +20,10 @@ package org.apache.hama.ml.math;
/**
* A function that can be applied to a double vector via {@link DoubleVector}
* #apply({@link DoubleVectorFunction} f);
+ *
+ * This class will be replaced by {@link DoubleFunction}
*/
+@Deprecated
public interface DoubleVectorFunction {
/**
Added: hama/trunk/ml/src/main/java/org/apache/hama/ml/math/Function.java
URL:
http://svn.apache.org/viewvc/hama/trunk/ml/src/main/java/org/apache/hama/ml/math/Function.java?rev=1495461&view=auto
==============================================================================
--- hama/trunk/ml/src/main/java/org/apache/hama/ml/math/Function.java (added)
+++ hama/trunk/ml/src/main/java/org/apache/hama/ml/math/Function.java Fri Jun
21 14:53:45 2013
@@ -0,0 +1,33 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.hama.ml.math;
+
+/**
+ * A generic function.
+ *
+ */
+public abstract class Function {
+ /**
+ * Get the name of the function.
+ *
+ * @return The name of the function.
+ */
+ final public String getFunctionName() {
+ return this.getClass().getSimpleName();
+ }
+}
Added: hama/trunk/ml/src/main/java/org/apache/hama/ml/math/FunctionFactory.java
URL:
http://svn.apache.org/viewvc/hama/trunk/ml/src/main/java/org/apache/hama/ml/math/FunctionFactory.java?rev=1495461&view=auto
==============================================================================
--- hama/trunk/ml/src/main/java/org/apache/hama/ml/math/FunctionFactory.java
(added)
+++ hama/trunk/ml/src/main/java/org/apache/hama/ml/math/FunctionFactory.java
Fri Jun 21 14:53:45 2013
@@ -0,0 +1,61 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.hama.ml.math;
+
+/**
+ * Factory to create the functions.
+ *
+ */
+public class FunctionFactory {
+
+ /**
+ * Create a double function with specified name.
+ *
+ * @param functionName
+ * @return
+ */
+ public static DoubleFunction createDoubleFunction(String functionName) {
+ if (functionName.equals(Sigmoid.class.getSimpleName())) {
+ return new Sigmoid();
+ } else if (functionName.equals(Tanh.class.getSimpleName())) {
+ return new Tanh();
+ }
+
+ throw new IllegalArgumentException(String.format(
+ "No double function with name '%s' exists.", functionName));
+ }
+
+ /**
+ * Create a double double function with specified name.
+ *
+ * @param functionName
+ * @return
+ */
+ public static DoubleDoubleFunction createDoubleDoubleFunction(
+ String functionName) {
+ if (functionName.equals(SquaredError.class.getSimpleName())) {
+ return new SquaredError();
+ } else if (functionName.equals(CrossEntropy.class.getSimpleName())) {
+ return new CrossEntropy();
+ }
+
+ throw new IllegalArgumentException(String.format(
+ "No double double function with name '%s' exists.", functionName));
+ }
+
+}
Added: hama/trunk/ml/src/main/java/org/apache/hama/ml/math/Sigmoid.java
URL:
http://svn.apache.org/viewvc/hama/trunk/ml/src/main/java/org/apache/hama/ml/math/Sigmoid.java?rev=1495461&view=auto
==============================================================================
--- hama/trunk/ml/src/main/java/org/apache/hama/ml/math/Sigmoid.java (added)
+++ hama/trunk/ml/src/main/java/org/apache/hama/ml/math/Sigmoid.java Fri Jun 21
14:53:45 2013
@@ -0,0 +1,39 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.hama.ml.math;
+
+/**
+ * The Sigmoid function
+ *
+ * <pre>
+ * f(x) = 1 / (1 + e^{-x})
+ * </pre>
+ */
+public class Sigmoid extends DoubleFunction {
+
+ @Override
+ public double apply(double value) {
+ return 1.0 / (1 + Math.exp(-value));
+ }
+
+ @Override
+ public double applyDerivative(double value) {
+ return value * (1 - value);
+ }
+
+}
Added: hama/trunk/ml/src/main/java/org/apache/hama/ml/math/SquaredError.java
URL:
http://svn.apache.org/viewvc/hama/trunk/ml/src/main/java/org/apache/hama/ml/math/SquaredError.java?rev=1495461&view=auto
==============================================================================
--- hama/trunk/ml/src/main/java/org/apache/hama/ml/math/SquaredError.java
(added)
+++ hama/trunk/ml/src/main/java/org/apache/hama/ml/math/SquaredError.java Fri
Jun 21 14:53:45 2013
@@ -0,0 +1,47 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.hama.ml.math;
+
+/**
+ * Squared error cost function.
+ *
+ * <pre>
+ * cost(t, y) = 0.5 * (t - y) ^ 2
+ * </pre>
+ */
+public class SquaredError extends DoubleDoubleFunction {
+
+ @Override
+ /**
+ * {@inheritDoc}
+ */
+ public double apply(double target, double actual) {
+ double diff = target - actual;
+ return 0.5 * diff * diff;
+ }
+
+ @Override
+ /**
+ * {@inheritDoc}
+ */
+ public double applyDerivative(double target, double actual) {
+ // return target - actual;
+ return actual - target;
+ }
+
+}
Added: hama/trunk/ml/src/main/java/org/apache/hama/ml/math/Tanh.java
URL:
http://svn.apache.org/viewvc/hama/trunk/ml/src/main/java/org/apache/hama/ml/math/Tanh.java?rev=1495461&view=auto
==============================================================================
--- hama/trunk/ml/src/main/java/org/apache/hama/ml/math/Tanh.java (added)
+++ hama/trunk/ml/src/main/java/org/apache/hama/ml/math/Tanh.java Fri Jun 21
14:53:45 2013
@@ -0,0 +1,36 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.hama.ml.math;
+
+/**
+ * Tanh function.
+ *
+ */
+public class Tanh extends DoubleFunction {
+
+ @Override
+ public double apply(double value) {
+ return Math.tanh(value);
+ }
+
+ @Override
+ public double applyDerivative(double value) {
+ return 1 - value * value;
+ }
+
+}
Modified:
hama/trunk/ml/src/main/java/org/apache/hama/ml/perception/MultiLayerPerceptron.java
URL:
http://svn.apache.org/viewvc/hama/trunk/ml/src/main/java/org/apache/hama/ml/perception/MultiLayerPerceptron.java?rev=1495461&r1=1495460&r2=1495461&view=diff
==============================================================================
---
hama/trunk/ml/src/main/java/org/apache/hama/ml/perception/MultiLayerPerceptron.java
(original)
+++
hama/trunk/ml/src/main/java/org/apache/hama/ml/perception/MultiLayerPerceptron.java
Fri Jun 21 14:53:45 2013
@@ -21,7 +21,10 @@ import java.io.IOException;
import java.util.Map;
import org.apache.hadoop.fs.Path;
+import org.apache.hama.ml.math.DoubleDoubleFunction;
+import org.apache.hama.ml.math.DoubleFunction;
import org.apache.hama.ml.math.DoubleVector;
+import org.apache.hama.ml.math.FunctionFactory;
/**
* PerceptronBase defines the common behavior of all the concrete perceptrons.
@@ -43,8 +46,8 @@ public abstract class MultiLayerPerceptr
protected String costFunctionName;
protected int[] layerSizeArray;
- protected CostFunction costFunction;
- protected SquashingFunction squashingFunction;
+ protected DoubleDoubleFunction costFunction;
+ protected DoubleFunction squashingFunction;
/**
* Initialize the MLP.
@@ -83,10 +86,10 @@ public abstract class MultiLayerPerceptr
this.layerSizeArray = layerSizeArray;
this.numberOfLayers = this.layerSizeArray.length;
- this.costFunction = CostFunctionFactory
- .getCostFunction(this.costFunctionName);
- this.squashingFunction = SquashingFunctionFactory
- .getSquashingFunction(this.squashingFunctionName);
+ this.costFunction = FunctionFactory
+ .createDoubleDoubleFunction(this.costFunctionName);
+ this.squashingFunction = FunctionFactory
+ .createDoubleFunction(this.squashingFunctionName);
}
/**
Modified:
hama/trunk/ml/src/main/java/org/apache/hama/ml/perception/SmallMultiLayerPerceptron.java
URL:
http://svn.apache.org/viewvc/hama/trunk/ml/src/main/java/org/apache/hama/ml/perception/SmallMultiLayerPerceptron.java?rev=1495461&r1=1495460&r2=1495461&view=diff
==============================================================================
---
hama/trunk/ml/src/main/java/org/apache/hama/ml/perception/SmallMultiLayerPerceptron.java
(original)
+++
hama/trunk/ml/src/main/java/org/apache/hama/ml/perception/SmallMultiLayerPerceptron.java
Fri Jun 21 14:53:45 2013
@@ -41,7 +41,9 @@ import org.apache.hama.HamaConfiguration
import org.apache.hama.bsp.BSPJob;
import org.apache.hama.ml.math.DenseDoubleMatrix;
import org.apache.hama.ml.math.DenseDoubleVector;
+import org.apache.hama.ml.math.DoubleFunction;
import org.apache.hama.ml.math.DoubleVector;
+import org.apache.hama.ml.math.FunctionFactory;
import org.apache.hama.ml.writable.MatrixWritable;
import org.apache.hama.ml.writable.VectorWritable;
import org.mortbay.log.Log;
@@ -102,18 +104,34 @@ public final class SmallMultiLayerPercep
private void initializeWeightMatrix() {
this.weightMatrice = new DenseDoubleMatrix[this.numberOfLayers - 1];
// each layer contains one bias neuron
- Random rnd = new Random();
for (int i = 0; i < this.numberOfLayers - 1; ++i) {
// add weights for bias
this.weightMatrice[i] = new DenseDoubleMatrix(this.layerSizeArray[i] + 1,
this.layerSizeArray[i + 1]);
- int rowCount = this.weightMatrice[i].getRowCount();
- int colCount = this.weightMatrice[i].getColumnCount();
- for (int row = 0; row < rowCount; ++row) {
- for (int col = 0; col < colCount; ++col) {
- this.weightMatrice[i].set(row, col, rnd.nextDouble() - 0.5);
+
+ this.weightMatrice[i].applyToElements(new DoubleFunction() {
+
+ private Random rnd = new Random();
+
+ @Override
+ public double apply(double value) {
+ return rnd.nextDouble() - 0.5;
}
- }
+
+ @Override
+ public double applyDerivative(double value) {
+ throw new UnsupportedOperationException("Not supported");
+ }
+
+ });
+
+// int rowCount = this.weightMatrice[i].getRowCount();
+// int colCount = this.weightMatrice[i].getColumnCount();
+// for (int row = 0; row < rowCount; ++row) {
+// for (int col = 0; col < colCount; ++col) {
+// this.weightMatrice[i].set(row, col, rnd.nextDouble() - 0.5);
+// }
+// }
}
}
@@ -199,8 +217,7 @@ public final class SmallMultiLayerPercep
prevNeuronIdx, neuronIdx) * intermediateResult[prevNeuronIdx];
}
// calculate via squashing function
- results[neuronIdx + offset] = this.squashingFunction.calculate(0,
- results[neuronIdx + offset]);
+ results[neuronIdx + offset] =
this.squashingFunction.apply(results[neuronIdx + offset]);
}
return results;
@@ -243,7 +260,7 @@ public final class SmallMultiLayerPercep
DenseDoubleMatrix prevWeightUpdateMatrix =
this.prevWeightUpdateMatrices[this.prevWeightUpdateMatrices.length - 1];
for (int j = 0; j < delta.length; ++j) {
- delta[j] = this.costFunction.calculateDerivative(trainingLabels[j],
+ delta[j] = this.costFunction.applyDerivative(trainingLabels[j],
outputLayerOutput[j]);
// add regularization term
if (this.regularization != 0.0) {
@@ -257,7 +274,7 @@ public final class SmallMultiLayerPercep
}
delta[j] *= this.squashingFunction
- .calculateDerivative(outputLayerOutput[j]);
+ .applyDerivative(outputLayerOutput[j]);
// calculate the weight update matrix between the last hidden layer and
// the output layer
@@ -307,7 +324,7 @@ public final class SmallMultiLayerPercep
delta[j] += weight * nextLayerDelta[k];
}
delta[j] *= this.squashingFunction
- .calculateDerivative(curLayerOutput[j + 1]);
+ .applyDerivative(curLayerOutput[j + 1]);
// calculate the weight update matrix between the previous layer and the
// current layer
@@ -395,10 +412,10 @@ public final class SmallMultiLayerPercep
for (int i = 0; i < numberOfLayers - 1; ++i) {
this.weightMatrice[i] = (DenseDoubleMatrix) MatrixWritable.read(input);
}
- this.squashingFunction = SquashingFunctionFactory
- .getSquashingFunction(this.squashingFunctionName);
- this.costFunction = CostFunctionFactory
- .getCostFunction(this.costFunctionName);
+ this.squashingFunction = FunctionFactory
+ .createDoubleFunction(this.squashingFunctionName);
+ this.costFunction = FunctionFactory
+ .createDoubleDoubleFunction(this.costFunctionName);
}
@Override
Added:
hama/trunk/ml/src/test/java/org/apache/hama/ml/math/TestDenseDoubleMatrix.java
URL: http://svn.apache.org/viewvc/hama/trunk/ml/src/test/java/org/apache/hama/ml/math/TestDenseDoubleMatrix.java?rev=1495461&view=auto
==============================================================================
--- hama/trunk/ml/src/test/java/org/apache/hama/ml/math/TestDenseDoubleMatrix.java (added)
+++ hama/trunk/ml/src/test/java/org/apache/hama/ml/math/TestDenseDoubleMatrix.java Fri Jun 21 14:53:45 2013
@@ -0,0 +1,86 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.hama.ml.math;
+
+import static org.junit.Assert.assertArrayEquals;
+
+import org.junit.Test;
+
+/**
+ * Test case for {@link DenseDoubleMatrix}
+ *
+ */
+public class TestDenseDoubleMatrix {
+
+ @Test
+ public void testDoubleFunction() {
+ double[][] values = new double[][] { { 1, 2, 3 }, { 4, 5, 6 }, { 7, 8, 9 } };
+
+ double[][] result = new double[][] { { 2, 3, 4 }, { 5, 6, 7 }, { 8, 9, 10 } };
+
+ DenseDoubleMatrix mat = new DenseDoubleMatrix(values);
+ mat.applyToElements(new DoubleFunction() {
+
+ @Override
+ public double apply(double value) {
+ return value + 1;
+ }
+
+ @Override
+ public double applyDerivative(double value) {
+ throw new UnsupportedOperationException();
+ }
+
+ });
+
+ double[][] actual = mat.getValues();
+ for (int i = 0; i < actual.length; ++i) {
+ assertArrayEquals(result[i], actual[i], 0.0001);
+ }
+ }
+
+ @Test
+ public void testDoubleDoubleFunction() {
+ double[][] values1 = new double[][] { { 1, 2, 3 }, { 4, 5, 6 }, { 7, 8, 9 } };
+ double[][] values2 = new double[][] { { 2, 3, 4 }, { 5, 6, 7 }, { 8, 9, 10 } };
+ double[][] result = new double[][] { {3, 5, 7}, {9, 11, 13}, {15, 17, 19}};
+
+ DenseDoubleMatrix mat1 = new DenseDoubleMatrix(values1);
+ DenseDoubleMatrix mat2 = new DenseDoubleMatrix(values2);
+
+ mat1.applyToElements(mat2, new DoubleDoubleFunction() {
+
+ @Override
+ public double apply(double x1, double x2) {
+ return x1 + x2;
+ }
+
+ @Override
+ public double applyDerivative(double x1, double x2) {
+ throw new UnsupportedOperationException();
+ }
+
+ });
+
+ double[][] actual = mat1.getValues();
+ for (int i = 0; i < actual.length; ++i) {
+ assertArrayEquals(result[i], actual[i], 0.0001);
+ }
+ }
+
+}
Added:
hama/trunk/ml/src/test/java/org/apache/hama/ml/math/TestDenseDoubleVector.java
URL: http://svn.apache.org/viewvc/hama/trunk/ml/src/test/java/org/apache/hama/ml/math/TestDenseDoubleVector.java?rev=1495461&view=auto
==============================================================================
--- hama/trunk/ml/src/test/java/org/apache/hama/ml/math/TestDenseDoubleVector.java (added)
+++ hama/trunk/ml/src/test/java/org/apache/hama/ml/math/TestDenseDoubleVector.java Fri Jun 21 14:53:45 2013
@@ -0,0 +1,80 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.hama.ml.math;
+
+import static org.junit.Assert.assertArrayEquals;
+
+import org.junit.Test;
+
+/**
+ * Testcase for {@link DenseDoubleVector}
+ *
+ */
+public class TestDenseDoubleVector {
+
+ @Test
+ public void testApplyDoubleFunction() {
+ double[] values = new double[] {1, 2, 3, 4, 5};
+ double[] result = new double[] {2, 3, 4, 5, 6};
+
+ DoubleVector vec1 = new DenseDoubleVector(values);
+
+ vec1.applyToElements(new DoubleFunction() {
+
+ @Override
+ public double apply(double value) {
+ return value + 1;
+ }
+
+ @Override
+ public double applyDerivative(double value) {
+ throw new UnsupportedOperationException("Not supported.");
+ }
+
+ });
+
+ assertArrayEquals(result, vec1.toArray(), 0.0001);
+ }
+
+ @Test
+ public void testApplyDoubleDoubleFunction() {
+ double[] values1 = new double[] {1, 2, 3, 4, 5, 6};
+ double[] values2 = new double[] {7, 8, 9, 10, 11, 12};
+ double[] result = new double[] {8, 10, 12, 14, 16, 18};
+
+ DoubleVector vec1 = new DenseDoubleVector(values1);
+ DoubleVector vec2 = new DenseDoubleVector(values2);
+
+ vec1.applyToElements(vec2, new DoubleDoubleFunction() {
+
+ @Override
+ public double apply(double x1, double x2) {
+ return x1 + x2;
+ }
+
+ @Override
+ public double applyDerivative(double x1, double x2) {
+ throw new UnsupportedOperationException("Not supported");
+ }
+
+ });
+
+ assertArrayEquals(result, vec1.toArray(), 0.0001);
+
+ }
+}
Added:
hama/trunk/ml/src/test/java/org/apache/hama/ml/math/TestFunctionFactory.java
URL: http://svn.apache.org/viewvc/hama/trunk/ml/src/test/java/org/apache/hama/ml/math/TestFunctionFactory.java?rev=1495461&view=auto
==============================================================================
--- hama/trunk/ml/src/test/java/org/apache/hama/ml/math/TestFunctionFactory.java (added)
+++ hama/trunk/ml/src/test/java/org/apache/hama/ml/math/TestFunctionFactory.java Fri Jun 21 14:53:45 2013
@@ -0,0 +1,71 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.hama.ml.math;
+
+import static org.junit.Assert.assertEquals;
+
+import org.junit.Test;
+
+/**
+ * Test case for {@link FunctionFactory}
+ *
+ */
+public class TestFunctionFactory {
+
+ @Test
+ public void testCreateDoubleFunction() {
+ double input = 0.8;
+
+ String sigmoidName = "Sigmoid";
+ DoubleFunction sigmoidFunction = FunctionFactory
+ .createDoubleFunction(sigmoidName);
+ assertEquals(sigmoidName, sigmoidFunction.getFunctionName());
+
+ double sigmoidExcepted = 0.68997448;
+ assertEquals(sigmoidExcepted, sigmoidFunction.apply(input), 0.000001);
+
+ String tanhName = "Tanh";
+ DoubleFunction tanhFunction = FunctionFactory.createDoubleFunction(tanhName);
+ assertEquals(tanhName, tanhFunction.getFunctionName());
+
+ double tanhExpected = 0.66403677;
+ assertEquals(tanhExpected, tanhFunction.apply(input), 0.00001);
+ }
+
+ @Test
+ public void testCreateDoubleDoubleFunction() {
+ double target = 0.5;
+ double output = 0.8;
+
+ String squaredErrorName = "SquaredError";
+ DoubleDoubleFunction squaredErrorFunction = FunctionFactory.createDoubleDoubleFunction(squaredErrorName);
+ assertEquals(squaredErrorName, squaredErrorFunction.getFunctionName());
+
+ double squaredErrorExpected = 0.045;
+
+ assertEquals(squaredErrorExpected, squaredErrorFunction.apply(target, output), 0.000001);
+
+ String crossEntropyName = "CrossEntropy";
+ DoubleDoubleFunction crossEntropyFunction = FunctionFactory.createDoubleDoubleFunction(crossEntropyName);
+ assertEquals(crossEntropyName, crossEntropyFunction.getFunctionName());
+
+ double crossEntropyExpected = 0.91629;
+ assertEquals(crossEntropyExpected, crossEntropyFunction.apply(target, output), 0.000001);
+ }
+
+}
Modified:
hama/trunk/ml/src/test/java/org/apache/hama/ml/perception/TestSmallMultiLayerPerceptron.java
URL: http://svn.apache.org/viewvc/hama/trunk/ml/src/test/java/org/apache/hama/ml/perception/TestSmallMultiLayerPerceptron.java?rev=1495461&r1=1495460&r2=1495461&view=diff
==============================================================================
--- hama/trunk/ml/src/test/java/org/apache/hama/ml/perception/TestSmallMultiLayerPerceptron.java (original)
+++ hama/trunk/ml/src/test/java/org/apache/hama/ml/perception/TestSmallMultiLayerPerceptron.java Fri Jun 21 14:53:45 2013
@@ -1,5 +1,5 @@
/**
- * Licensed to the Apache Software Foundation (ASF) under one
+c * Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file