jessicapriebe commented on code in PR #1983: URL: https://github.com/apache/systemds/pull/1983#discussion_r1460442035
########## src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixFourier.java: ########## @@ -0,0 +1,236 @@ +package org.apache.sysds.runtime.matrix.data; + +import org.apache.commons.math3.util.FastMath; +import java.util.Arrays; + +public class LibMatrixFourier { + + /** + * Function to perform Fast Fourier Transformation + */ + + public static MatrixBlock[] fft(MatrixBlock re, MatrixBlock im){ + + int rows = re.getNumRows(); + int cols = re.getNumColumns(); + + double[][][] in = new double[2][rows][cols]; + in[0] = convertToArray(re); + in[1] = convertToArray(im); Review Comment: resolved ########## src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixFourier.java: ########## @@ -0,0 +1,236 @@ +package org.apache.sysds.runtime.matrix.data; + +import org.apache.commons.math3.util.FastMath; +import java.util.Arrays; + +public class LibMatrixFourier { + + /** + * Function to perform Fast Fourier Transformation + */ + + public static MatrixBlock[] fft(MatrixBlock re, MatrixBlock im){ + + int rows = re.getNumRows(); + int cols = re.getNumColumns(); + + double[][][] in = new double[2][rows][cols]; + in[0] = convertToArray(re); + in[1] = convertToArray(im); + + double[][][] res = fft(in, false); + + return convertToMatrixBlocks(res); + } + + public static MatrixBlock[] ifft(MatrixBlock re, MatrixBlock im){ + + int rows = re.getNumRows(); + int cols = re.getNumColumns(); + + double[][][] in = new double[2][rows][cols]; + in[0] = convertToArray(re); + in[1] = convertToArray(im); + + double[][][] res = fft(in, true); + + return convertToMatrixBlocks(res); + } + + public static double[][][] fft(double[][][] in, boolean calcInv){ + + int rows = in[0].length; + int cols = in[0][0].length; + + double[][][] res = new double[2][rows][cols]; + + for(int i = 0; i < rows; i++){ + // use fft or ifft on each row + double[][] res_row = calcInv? ifft_one_dim(get_complex_row(in, i)) : fft_one_dim(get_complex_row(in, i)); + + // set res row + for (int j = 0; j < cols; j++){ + for( int k = 0; k < 2; k++){ + res[k][i][j] = res_row[k][j]; + } + } + } + + if(rows == 1) return res; + + for(int j = 0; j < cols; j++){ + // use fft on each col + double[][] res_col = calcInv? ifft_one_dim(get_complex_col(res, j)) : fft_one_dim(get_complex_col(res, j)); + + // set res col + for (int i = 0; i < rows; i++){ + for( int k = 0; k < 2; k++){ + res[k][i][j] = res_col[k][i]; + } + } + } + + return res; + } + + public static double[][] fft_one_dim(double[][] in){ + // 1st row real part, 2nd row imaginary part + if(in == null || in.length != 2 || in[0].length != in[1].length) throw new RuntimeException("in false dimensions"); + + int cols = in[0].length; + if(cols == 1) return in; + + double angle = -2*FastMath.PI/cols; + + // split values depending on index + double[][] even = new double[2][cols/2]; + double[][] odd = new double[2][cols/2]; + + for(int i = 0; i < 2; i++){ + for (int j = 0; j < cols/2; j++){ + even[i][j] = in[i][j*2]; + odd[i][j] = in[i][j*2+1]; + } + } + double[][] res_even = fft_one_dim(even); + double[][] res_odd = fft_one_dim(odd); + + double[][] res = new double[2][cols]; Review Comment: resolved -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: dev-unsubscr...@systemds.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org