jessicapriebe commented on code in PR #1983:
URL: https://github.com/apache/systemds/pull/1983#discussion_r1460442273


##########
src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixFourier.java:
##########
@@ -0,0 +1,236 @@
+package org.apache.sysds.runtime.matrix.data;
+
+import org.apache.commons.math3.util.FastMath;
+import java.util.Arrays;
+
+public class LibMatrixFourier {
+
+    /**
+     * Performs a Fast Fourier Transformation on a complex matrix whose real and
+     * imaginary parts are passed as two separate matrix blocks.
+     *
+     * @param re matrix block holding the real part of the input
+     * @param im matrix block holding the imaginary part of the input
+     * @return the transformed matrix as produced by convertToMatrixBlocks
+     *         (presumably real part first, imaginary part second - confirm against that helper)
+     */
+    public static MatrixBlock[] fft(MatrixBlock re, MatrixBlock im){
+
+        // index 0 holds the real part, index 1 the imaginary part; the converted arrays
+        // are used directly instead of being assigned over a freshly allocated buffer
+        double[][][] in = new double[][][]{convertToArray(re), convertToArray(im)};
+
+        // false selects the forward transformation
+        double[][][] res = fft(in, false);
+
+        return convertToMatrixBlocks(res);
+    }
+
+    /**
+     * Performs an inverse Fast Fourier Transformation on a complex matrix whose real
+     * and imaginary parts are passed as two separate matrix blocks.
+     *
+     * @param re matrix block holding the real part of the input
+     * @param im matrix block holding the imaginary part of the input
+     * @return the transformed matrix as produced by convertToMatrixBlocks
+     *         (presumably real part first, imaginary part second - confirm against that helper)
+     */
+    public static MatrixBlock[] ifft(MatrixBlock re, MatrixBlock im){
+
+        // index 0 holds the real part, index 1 the imaginary part; the converted arrays
+        // are used directly instead of being assigned over a freshly allocated buffer
+        double[][][] in = new double[][][]{convertToArray(re), convertToArray(im)};
+
+        // true selects the inverse transformation
+        double[][][] res = fft(in, true);
+
+        return convertToMatrixBlocks(res);
+    }
+
+    /**
+     * Applies the (inverse) FFT to a complex matrix by transforming every row and then,
+     * unless the matrix has only a single row, every column of the intermediate result.
+     *
+     * @param in      complex matrix; in[0] is the real part and in[1] the imaginary part,
+     *                both indexed as [row][column]
+     * @param calcInv true to compute the inverse transformation, false for the forward one
+     * @return newly allocated array in the same layout holding the transformed values;
+     *         an input without rows or without columns yields an equally empty result
+     */
+    public static double[][][] fft(double[][][] in, boolean calcInv){
+
+        int rows = in[0].length;
+        // only look at the first row when there is one
+        int cols = rows == 0 ? 0 : in[0][0].length;
+
+        double[][][] res = new double[2][rows][cols];
+
+        // nothing to transform; also keeps zero-length vectors away from the 1D routines
+        if(rows == 0 || cols == 0) return res;
+
+        for(int i = 0; i < rows; i++){
+            // use fft or ifft on each row
+            double[][] row = get_complex_row(in, i);
+            double[][] res_row = calcInv? ifft_one_dim(row) : fft_one_dim(row);
+
+            // set res row (k = 0 real part, k = 1 imaginary part)
+            for( int k = 0; k < 2; k++){
+                System.arraycopy(res_row[k], 0, res[k][i], 0, cols);
+            }
+        }
+
+        // a single row has no second dimension left to transform
+        if(rows == 1) return res;
+
+        for(int j = 0; j < cols; j++){
+            // use fft or ifft on each col of the row-transformed result
+            double[][] col = get_complex_col(res, j);
+            double[][] res_col = calcInv? ifft_one_dim(col) : fft_one_dim(col);
+
+            // set res col
+            for (int i = 0; i < rows; i++){
+                for( int k = 0; k < 2; k++){
+                    res[k][i][j] = res_col[k][i];
+                }
+            }
+        }
+
+        return res;
+    }
+
+    /**
+     * Computes the one-dimensional FFT of a complex vector by recursively splitting it
+     * into its even- and odd-indexed elements and recombining the two half-size results.
+     *
+     * @param in complex vector; in[0] is the real part, in[1] the imaginary part,
+     *           both of the same power-of-two length
+     * @return the transformed vector in the same layout; for a vector of length 1 the
+     *         input array itself is returned
+     * @throws RuntimeException if in is null, does not have exactly two rows, or the
+     *         rows differ in length
+     * @throws IllegalArgumentException if the length is not a power of two
+     */
+    public static double[][] fft_one_dim(double[][] in){
+        // 1st row real part, 2nd row imaginary part
+        if(in == null || in.length != 2 || in[0].length != in[1].length)
+            throw new RuntimeException("in false dimensions");
+
+        int cols = in[0].length;
+
+        // the even/odd split halves the length on every level: length 0 would never reach
+        // the base case and any other non power of two would silently drop elements
+        if(cols == 0 || (cols & (cols - 1)) != 0)
+            throw new IllegalArgumentException("fft_one_dim requires a power-of-two length, but got " + cols);
+
+        if(cols == 1) return in;
+
+        int half = cols/2;
+        double angle = -2*FastMath.PI/cols;
+
+        // split values depending on index
+        double[][] even = new double[2][half];
+        double[][] odd = new double[2][half];
+
+        for(int i = 0; i < 2; i++){
+            for (int j = 0; j < half; j++){
+                even[i][j] = in[i][j*2];
+                odd[i][j] = in[i][j*2+1];
+            }
+        }
+        double[][] res_even = fft_one_dim(even);
+        double[][] res_odd = fft_one_dim(odd);
+
+        double[][] res = new double[2][cols];
+
+        for(int j=0; j < half; j++){
+            // omega^j = cos(j*angle) + i*sin(j*angle)
+            double[] omega_pow = new double[]{FastMath.cos(j*angle), FastMath.sin(j*angle)};
+
+            // m = omega * res_odd[j] (complex multiplication)
+            double[] m = new double[]{
+                    omega_pow[0] * res_odd[0][j] - omega_pow[1] * res_odd[1][j],
+                    omega_pow[0] * res_odd[1][j] + omega_pow[1] * res_odd[0][j]};
+
+            // res[j] = res_even + m;
+            // res[j+cols/2] = res_even - m;
+            for(int i = 0; i < 2; i++){
+                res[i][j] = res_even[i][j] + m[i];
+                res[i][j+half] = res_even[i][j] - m[i];
+            }
+        }
+
+        return res;
+    }
+
+    /**
+     * Computes the one-dimensional inverse FFT by conjugating the input, applying the
+     * forward transformation and then conjugating the result and dividing it by the
+     * vector length. The passed array is not modified.
+     *
+     * @param in complex vector; in[0] is the real part, in[1] the imaginary part
+     * @return newly allocated array in the same layout holding the inverse transform
+     * @throws RuntimeException if in is null, does not have exactly two rows, or the
+     *         rows differ in length
+     */
+    public static double[][] ifft_one_dim(double[][] in) {
+
+        // validate before touching in[0]/in[1]
+        if(in == null || in.length != 2 || in[0].length != in[1].length)
+            throw new RuntimeException("in false dimensions");
+
+        // in[0] is real part, in[1] is imaginary part
+        int cols = in[0].length;
+
+        // conjugate input into a separate array so the caller's data stays untouched
+        double[] im_conj = new double[cols];
+        for(int j = 0; j < cols; j++){
+            im_conj[j] = -in[1][j];
+        }
+
+        // apply fft
+        double[][] fwd = fft_one_dim(new double[][]{in[0], im_conj});
+
+        // conjugate and scale result; written into new arrays because for a single
+        // element fft_one_dim hands back its argument, whose first row is in[0]
+        double[][] res = new double[2][cols];
+        for(int j = 0; j < cols; j++){
+            res[0][j] = fwd[0][j]/cols;
+            res[1][j] = -fwd[1][j]/cols;
+        }
+
+        return res;
+    }
+
+    private static MatrixBlock[] convertToMatrixBlocks(double[][][] in){
+
+        int cols = in[0][0].length;
+        int rows = in[0].length;
+
+        double[] flattened_re = 
Arrays.stream(in[0]).flatMapToDouble(Arrays::stream).toArray();

Review Comment:
   resolved



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: dev-unsubscr...@systemds.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org

Reply via email to