Author: tommaso
Date: Mon Oct 8 15:13:27 2012
New Revision: 1395604
URL: http://svn.apache.org/viewvc?rev=1395604&view=rev
Log:
[HAMA-651] - added model classes that abstract the cost function and hypothesis
out of the gradient descent BSP
Added:
hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/CostFunction.java
(with props)
hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/HypothesisFunction.java
(with props)
hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/LinearRegressionModel.java
(with props)
hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/LogisticRegressionModel.java
(with props)
hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/RegressionModel.java
(with props)
Modified:
hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/GradientDescentBSP.java
Added:
hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/CostFunction.java
URL:
http://svn.apache.org/viewvc/hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/CostFunction.java?rev=1395604&view=auto
==============================================================================
--- hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/CostFunction.java
(added)
+++ hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/CostFunction.java
Mon Oct 8 15:13:27 2012
@@ -0,0 +1,38 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.hama.ml.regression;
+
+import org.apache.hama.ml.math.DoubleVector;
+
+/**
+ * The cost function of an optimization (minimization) problem.
+ */
+public interface CostFunction {
+
+  /**
+   * Calculates the cost of a single item (input x, observed output y) for the
+   * model defined by {@code hypothesis} parametrized by the vector theta.
+   * @param x the input vector
+   * @param y the observed (target) output for x
+   * @param theta the parameters vector theta
+   * @param hypothesis the hypothesis function that models the problem
+   * @return the calculated cost for input x and output y
+   */
+  double calculateCostForItem(DoubleVector x, double y, DoubleVector theta, HypothesisFunction hypothesis);
+
+}
Propchange:
hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/CostFunction.java
------------------------------------------------------------------------------
svn:eol-style = native
Modified:
hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/GradientDescentBSP.java
URL:
http://svn.apache.org/viewvc/hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/GradientDescentBSP.java?rev=1395604&r1=1395603&r2=1395604&view=diff
==============================================================================
---
hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/GradientDescentBSP.java
(original)
+++
hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/GradientDescentBSP.java
Mon Oct 8 15:13:27 2012
@@ -34,7 +34,7 @@ import java.io.IOException;
* A gradient descent (see
<code>http://en.wikipedia.org/wiki/Gradient_descent</code>) BSP based abstract
implementation.
* Each extending class should implement the #applyHypothesis(DoubleVector
theta, DoubleVector x) method for a specific
*/
-public abstract class GradientDescentBSP extends BSP<VectorWritable,
DoubleWritable, VectorWritable, DoubleWritable, VectorWritable> {
+public class GradientDescentBSP extends BSP<VectorWritable, DoubleWritable,
VectorWritable, DoubleWritable, VectorWritable> {
private static final Logger log =
LoggerFactory.getLogger(GradientDescentBSP.class);
static final String INITIAL_THETA_VALUES = "initial.theta.values";
@@ -45,6 +45,7 @@ public abstract class GradientDescentBSP
private double cost;
private double threshold;
private float alpha;
+ private RegressionModel regressionModel;
@Override
public void setup(BSPPeer<VectorWritable, DoubleWritable, VectorWritable,
DoubleWritable, VectorWritable> peer) throws IOException, SyncException,
InterruptedException {
@@ -52,6 +53,11 @@ public abstract class GradientDescentBSP
cost = Integer.MAX_VALUE;
threshold = peer.getConfiguration().getFloat("threashold", 0.01f);
alpha = peer.getConfiguration().getFloat(ALPHA, 0.3f);
+ try {
+ regressionModel = ((Class<? extends
RegressionModel>)peer.getConfiguration().getClass("regression.model",
LinearRegressionModel.class)).newInstance();
+ } catch (Exception e) {
+ throw new IOException(e);
+ }
}
@Override
@@ -74,7 +80,7 @@ public abstract class GradientDescentBSP
// calculate cost for given input
double y = kvp.getValue().get();
DoubleVector x = kvp.getKey().getVector();
- double costForX = calculateCostForItem(y, x, theta);
+ double costForX = regressionModel.calculateCostForItem(x, y, theta);
// adds to local cost
localCost += costForX;
@@ -127,7 +133,7 @@ public abstract class GradientDescentBSP
while ((kvp = peer.readNext()) != null) {
DoubleVector x = kvp.getKey().getVector();
double y = kvp.getValue().get();
- double difference = applyHypothesis(theta, x) - y;
+ double difference = regressionModel.applyHypothesis(theta, x) - y;
for (int j = 0; j < theta.getLength(); j++) {
thetaDelta[j] += difference * x.get(j);
}
@@ -169,25 +175,6 @@ public abstract class GradientDescentBSP
}
- /**
- * Calculates the cost function for a given item (input x, output y)
- * @param y the learned output for x
- * @param x the input vector
- * @param theta the parameters vector theta
- * @return the calculated cost for input x and output y
- */
- protected abstract double calculateCostForItem(double y, DoubleVector x,
DoubleVector theta);
-
- /**
- * Applies the applyHypothesis given a set of parameters theta to a given
input x
- *
- * @param theta the parameters vector
- * @param x the input
- * @return a <code>double</code> number
- */
- public abstract double applyHypothesis(DoubleVector theta, DoubleVector x);
-
-
public void getTheta(BSPPeer<VectorWritable, DoubleWritable, VectorWritable,
DoubleWritable, VectorWritable> peer) throws IOException, SyncException,
InterruptedException {
if (master && theta == null) {
int size = getXSize(peer);
Added:
hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/HypothesisFunction.java
URL:
http://svn.apache.org/viewvc/hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/HypothesisFunction.java?rev=1395604&view=auto
==============================================================================
---
hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/HypothesisFunction.java
(added)
+++
hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/HypothesisFunction.java
Mon Oct 8 15:13:27 2012
@@ -0,0 +1,36 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.hama.ml.regression;
+
+import org.apache.hama.ml.math.DoubleVector;
+
+/**
+ * The mathematical model (hypothesis) chosen for a specific learning problem.
+ */
+public interface HypothesisFunction {
+
+  /**
+   * Applies the hypothesis, parametrized by theta, to the given input x.
+   *
+   * @param theta the parameters vector
+   * @param x the input
+   * @return the hypothesis value (the model's predicted output) for x
+   */
+  double applyHypothesis(DoubleVector theta, DoubleVector x);
+
+}
Propchange:
hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/HypothesisFunction.java
------------------------------------------------------------------------------
svn:eol-style = native
Added:
hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/LinearRegressionModel.java
URL:
http://svn.apache.org/viewvc/hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/LinearRegressionModel.java?rev=1395604&view=auto
==============================================================================
---
hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/LinearRegressionModel.java
(added)
+++
hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/LinearRegressionModel.java
Mon Oct 8 15:13:27 2012
@@ -0,0 +1,47 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.hama.ml.regression;
+
+import org.apache.hama.ml.math.DoubleVector;
+
+/**
+ * A {@link RegressionModel} for linear regression: h(x) = theta . x, squared-error cost.
+ */
+public class LinearRegressionModel implements RegressionModel {
+  private final CostFunction costFunction;
+
+  public LinearRegressionModel() {
+    costFunction = new CostFunction() {
+      @Override
+      public double calculateCostForItem(DoubleVector x, double y, DoubleVector theta, HypothesisFunction hypothesis) {
+        // J = (h(x) - y)^2 / 2; the item cost must not be scaled by the label y
+        return Math.pow(hypothesis.applyHypothesis(theta, x) - y, 2) / 2;
+      }
+    };
+  }
+
+  @Override
+  public double applyHypothesis(DoubleVector theta, DoubleVector x) {
+    return theta.dot(x);
+  }
+
+  @Override
+  public double calculateCostForItem(DoubleVector x, double y, DoubleVector theta) {
+    return costFunction.calculateCostForItem(x, y, theta, this);
+  }
+}
Propchange:
hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/LinearRegressionModel.java
------------------------------------------------------------------------------
svn:eol-style = native
Added:
hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/LogisticRegressionModel.java
URL:
http://svn.apache.org/viewvc/hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/LogisticRegressionModel.java?rev=1395604&view=auto
==============================================================================
---
hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/LogisticRegressionModel.java
(added)
+++
hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/LogisticRegressionModel.java
Mon Oct 8 15:13:27 2012
@@ -0,0 +1,47 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.hama.ml.regression;
+
+import org.apache.hama.ml.math.DoubleVector;
+
+/**
+ * A {@link RegressionModel} for logistic regression: sigmoid hypothesis, cross-entropy cost.
+ */
+public class LogisticRegressionModel implements RegressionModel {
+  private final CostFunction costFunction;
+
+  public LogisticRegressionModel() {
+    costFunction = new CostFunction() {
+      @Override
+      public double calculateCostForItem(DoubleVector x, double y, DoubleVector theta, HypothesisFunction hypothesis) {
+        double h = hypothesis.applyHypothesis(theta, x);
+        return -y * Math.log(h) - (1 - y) * Math.log(1 - h); // J = -y log h - (1 - y) log(1 - h)
+      }
+    };
+  }
+
+  @Override
+  public double applyHypothesis(DoubleVector theta, DoubleVector x) {
+    return 1d / (1d + Math.exp(-1 * theta.dot(x)));
+  }
+
+  @Override
+  public double calculateCostForItem(DoubleVector x, double y, DoubleVector theta) {
+    return costFunction.calculateCostForItem(x, y, theta, this);
+  }
+}
Propchange:
hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/LogisticRegressionModel.java
------------------------------------------------------------------------------
svn:eol-style = native
Added:
hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/RegressionModel.java
URL:
http://svn.apache.org/viewvc/hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/RegressionModel.java?rev=1395604&view=auto
==============================================================================
---
hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/RegressionModel.java
(added)
+++
hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/RegressionModel.java
Mon Oct 8 15:13:27 2012
@@ -0,0 +1,37 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.hama.ml.regression;
+
+import org.apache.hama.ml.math.DoubleVector;
+
+/**
+ * A regression model for gradient descent: a hypothesis together with its cost function.
+ */
+public interface RegressionModel extends HypothesisFunction {
+
+  /**
+   * Calculates the cost of a single item (input x, observed output y) under
+   * this model's hypothesis parametrized by the vector theta.
+   * @param x the input vector
+   * @param y the observed (target) output for x
+   * @param theta the parameters vector theta
+   * @return the calculated cost for input x and output y
+   */
+  double calculateCostForItem(DoubleVector x, double y, DoubleVector theta);
+
+}
Propchange:
hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/RegressionModel.java
------------------------------------------------------------------------------
svn:eol-style = native