Author: yxjiang
Date: Fri Aug 16 14:40:36 2013
New Revision: 1514736
URL: http://svn.apache.org/r1514736
Log:
HAMA-796: Add Vector multiply Matrix for DoubleVector as well as
DenseDoubleVector
Modified:
hama/trunk/CHANGES.txt
hama/trunk/ml/src/main/java/org/apache/hama/ml/math/DenseDoubleVector.java
hama/trunk/ml/src/main/java/org/apache/hama/ml/math/DoubleVector.java
hama/trunk/ml/src/test/java/org/apache/hama/ml/math/TestDenseDoubleVector.java
Modified: hama/trunk/CHANGES.txt
URL: http://svn.apache.org/viewvc/hama/trunk/CHANGES.txt?rev=1514736&r1=1514735&r2=1514736&view=diff
==============================================================================
--- hama/trunk/CHANGES.txt (original)
+++ hama/trunk/CHANGES.txt Fri Aug 16 14:40:36 2013
@@ -20,6 +20,7 @@ Release 0.6.3 (unreleased changes)
IMPROVEMENTS
+ HAMA-796: Add Vector multiply Matrix for DoubleVector as well as DenseDoubleVector. (Yexi Jiang)
HAMA-770: Use a unified model to represent linear regression, logistic
regression, MLP, autoencoder, and deepNets (Yexi Jiang)
HAMA-671: Clean up Maven build scripts (edwardyoon)
HAMA-765: Add apply method to Vector/Matrix (Yexi Jiang)
Modified: hama/trunk/ml/src/main/java/org/apache/hama/ml/math/DenseDoubleVector.java
URL: http://svn.apache.org/viewvc/hama/trunk/ml/src/main/java/org/apache/hama/ml/math/DenseDoubleVector.java?rev=1514736&r1=1514735&r2=1514736&view=diff
==============================================================================
--- hama/trunk/ml/src/main/java/org/apache/hama/ml/math/DenseDoubleVector.java (original)
+++ hama/trunk/ml/src/main/java/org/apache/hama/ml/math/DenseDoubleVector.java Fri Aug 16 14:40:36 2013
@@ -245,6 +245,22 @@ public final class DenseDoubleVector imp
return v;
}
+ @Override
+ public DoubleVector multiply(DoubleMatrix matrix) {
+ Preconditions.checkArgument(this.vector.length == matrix.getRowCount(),
+ "Dimension mismatch when multiply a vector to a matrix.");
+ return this.multiplyUnsafe(matrix);
+ }
+
+ @Override
+ public DoubleVector multiplyUnsafe(DoubleMatrix matrix) {
+ DoubleVector vec = new DenseDoubleVector(matrix.getColumnCount());
+ for (int i = 0; i < vec.getDimension(); ++i) {
+ vec.set(i, this.multiplyUnsafe(matrix.getColumnVector(i)).sum());
+ }
+ return vec;
+ }
+
/*
* (non-Javadoc)
* @see de.jungblut.math.DoubleVector#divide(double)
@@ -356,12 +372,12 @@ public final class DenseDoubleVector imp
public DoubleVector slice(int length) {
// Keeps the first `length` elements (indices 0..length-1), so the end index
// handed on to slice(int, int) is inclusive.
return slice(0, length - 1);
}
-
+
@Override
public DoubleVector sliceUnsafe(int length) {
// Unchecked counterpart of slice(int): forwards to sliceUnsafe(0, length - 1)
// and performs no validation of its own.
return sliceUnsafe(0, length - 1);
}
-
+
/*
* (non-Javadoc)
* @see de.jungblut.math.DoubleVector#slice(int, int)
@@ -373,7 +389,7 @@ public final class DenseDoubleVector imp
return sliceUnsafe(start, end);
}
-
+
/**
* {@inheritDoc}
*/
Modified: hama/trunk/ml/src/main/java/org/apache/hama/ml/math/DoubleVector.java
URL: http://svn.apache.org/viewvc/hama/trunk/ml/src/main/java/org/apache/hama/ml/math/DoubleVector.java?rev=1514736&r1=1514735&r2=1514736&view=diff
==============================================================================
--- hama/trunk/ml/src/main/java/org/apache/hama/ml/math/DoubleVector.java (original)
+++ hama/trunk/ml/src/main/java/org/apache/hama/ml/math/DoubleVector.java Fri Aug 16 14:40:36 2013
@@ -183,6 +183,23 @@ public interface DoubleVector {
public DoubleVector multiply(DoubleVector vector);
/**
+ * Validates the input and multiplies this vector, treated as a row vector, by
+ * the given {@link DoubleMatrix} (vector * matrix).
+ *
+ * @param matrix the right-hand operand; its row count must equal the
+ *          dimension of this vector.
+ * @return a vector whose dimension equals the column count of the matrix;
+ *         entry i is the dot product of this vector with column i.
+ * @throws IllegalArgumentException if the matrix row count differs from the
+ *           dimension of this vector (this is what DenseDoubleVector throws).
+ */
+ public DoubleVector multiply(DoubleMatrix matrix);
+
+ /**
+ * Multiplies this vector, treated as a row vector, by the given
+ * {@link DoubleMatrix} (vector * matrix) without validating the dimensions.
+ * The caller is responsible for ensuring that the matrix row count equals
+ * the dimension of this vector.
+ *
+ * @param matrix the right-hand operand.
+ * @return a vector whose dimension equals the column count of the matrix.
+ */
+ public DoubleVector multiplyUnsafe(DoubleMatrix matrix);
+
+ /**
* Divides this vector by the given scalar. (= vector/scalar).
*
* @param scalar the given scalar.
@@ -243,13 +260,14 @@ public interface DoubleVector {
public double dot(DoubleVector vector);
/**
- * Validates the input and slices this vector from index 0 to the given length.
+ * Validates the input and slices this vector from index 0 to the given
+ * length.
*
* @param length must be > 0 and smaller than the dimension of the vector.
* @return a new vector that is only length long.
*/
public DoubleVector slice(int length);
-
+
/**
* Slices this vector from index 0 to the given length.
*
Modified: hama/trunk/ml/src/test/java/org/apache/hama/ml/math/TestDenseDoubleVector.java
URL: http://svn.apache.org/viewvc/hama/trunk/ml/src/test/java/org/apache/hama/ml/math/TestDenseDoubleVector.java?rev=1514736&r1=1514735&r2=1514736&view=diff
==============================================================================
--- hama/trunk/ml/src/test/java/org/apache/hama/ml/math/TestDenseDoubleVector.java (original)
+++ hama/trunk/ml/src/test/java/org/apache/hama/ml/math/TestDenseDoubleVector.java Fri Aug 16 14:40:36 2013
@@ -185,4 +185,24 @@ public class TestDenseDoubleVector {
DoubleVector vec = new DenseDoubleVector(arr1);
vec.slice(4, 3);
}
+
+  @Test
+  public void testVectorMultiplyMatrix() {
+    // Row vector times a 3x4 matrix: entry i is the dot product of the vector
+    // with column i, e.g. 1*1 + 2*5 + 3*9 = 38.
+    DoubleVector vec = new DenseDoubleVector(new double[] { 1, 2, 3 });
+    DoubleMatrix mat = new DenseDoubleMatrix(new double[][] {
+        { 1, 2, 3, 4 }, { 5, 6, 7, 8 }, { 9, 10, 11, 12 } });
+    double[] expectedRes = new double[] { 38, 44, 50, 56 };
+
+    assertArrayEquals(expectedRes, vec.multiply(mat).toArray(), 0.000001);
+    // With compatible dimensions the unchecked variant must agree.
+    assertArrayEquals(expectedRes, vec.multiplyUnsafe(mat).toArray(), 0.000001);
+  }
+
+  @Test(expected = IllegalArgumentException.class)
+  public void testVectorMultiplyMatrixAbnormal() {
+    // A 3-dimensional vector cannot be multiplied by a matrix with 4 rows.
+    DoubleVector vec = new DenseDoubleVector(new double[] { 1, 2, 3 });
+    DoubleMatrix mat = new DenseDoubleMatrix(new double[][] {
+        { 1, 2, 3, 4 }, { 5, 6, 7, 8 }, { 9, 10, 11, 12 }, { 13, 14, 15, 16 } });
+    vec.multiply(mat);
+  }
}