jessicapriebe commented on code in PR #1983: URL: https://github.com/apache/systemds/pull/1983#discussion_r1460445086
########## src/test/java/org/apache/sysds/test/component/matrix/FourierTestWithFiles.java: ########## @@ -0,0 +1,377 @@ +package org.apache.sysds.test.component.matrix; + +import org.apache.sysds.runtime.matrix.data.MatrixBlock; +import org.junit.Test; + +import java.io.BufferedReader; +import java.io.FileReader; +import java.io.IOException; + +import static org.apache.sysds.runtime.matrix.data.LibMatrixFourier.*; +import static org.junit.Assert.assertTrue; + +public class FourierTestWithFiles { + int progressInterval = 5000; + + // prior to executing the following tests it is necessary to run the Numpy Script in FourierTestData.py + // and add the generated files to the root of the project. + @Test + public void testFftWithNumpyData() throws IOException { + String filename = "fft_data.csv"; // Path to your CSV file + BufferedReader reader = new BufferedReader(new FileReader(filename)); + String line; + int lineNumber = 0; + long totalTime = 0; // Total time for all FFT computations + int numCalculations = 0; // Number of FFT computations + + while ((line = reader.readLine()) != null) { + lineNumber++; + + String[] values = line.split(","); + int n = values.length / 3; + double[][][] input = new double[2][1][n]; + double[][] expected = new double[2][n]; // First row for real, second row for imaginary parts + + for (int i = 0; i < n; i++) { + input[0][0][i] = Double.parseDouble(values[i]); + expected[0][i] = Double.parseDouble(values[n + i]); // Real part + expected[1][i] = Double.parseDouble(values[n * 2 + i]); // Imaginary part + } + + long startTime = System.nanoTime(); + MatrixBlock[] actualBlocks = fft(input); + long endTime = System.nanoTime(); + + if(lineNumber > 1000){ + totalTime += (endTime - startTime); + numCalculations++; + + if (numCalculations % progressInterval == 0) { + double averageTime = (totalTime / 1e6) / numCalculations; // Average time in milliseconds + System.out.println("fft(double[][][] in): Average execution time after " + numCalculations 
+ " calculations: " + String.format("%.8f", averageTime/1000) + " s"); + } + } + + // Validate the FFT results + validateFftResults(expected, actualBlocks, lineNumber); + } + + reader.close(); + + } + + private void validateFftResults(double[][] expected, MatrixBlock[] actualBlocks, int lineNumber) { + int length = expected[0].length; + for (int i = 0; i < length; i++) { + double realActual = actualBlocks[0].getValueDenseUnsafe(0, i); + double imagActual = actualBlocks[1].getValueDenseUnsafe(0, i); + assertEquals("Mismatch in real part at index " + i + " in line " + lineNumber, expected[0][i], realActual, 1e-9); + assertEquals("Mismatch in imaginary part at index " + i + " in line " + lineNumber, expected[1][i], imagActual, 1e-9); + } + if(lineNumber % progressInterval == 0){ + System.out.println("fft(double[][][] in): Finished processing line " + lineNumber); + } + + } + + @Test + public void testFftExecutionTime() throws IOException { + String filename = "fft_data.csv"; // Path to your CSV file + BufferedReader reader = new BufferedReader(new FileReader(filename)); + String line; + int lineNumber = 0; + long totalTime = 0; // Total time for all FFT computations + int numCalculations = 0; // Number of FFT computations + + while ((line = reader.readLine()) != null) { + lineNumber++; + String[] values = line.split(","); + int n = values.length / 3; + double[][][] input = new double[2][1][n]; + + for (int i = 0; i < n; i++) { + input[0][0][i] = Double.parseDouble(values[i]); // Real part + input[1][0][i] = Double.parseDouble(values[n + i]); // Imaginary part + } + + long startTime = System.nanoTime(); + fft(input, false); + long endTime = System.nanoTime(); + if(lineNumber > 1000){ + totalTime += (endTime - startTime); + numCalculations++; + + if (numCalculations % progressInterval == 0) { + double averageTime = (totalTime / 1e6) / numCalculations; // Average time in milliseconds + System.out.println("fft(double[][][] in, boolean calcInv) Average execution time after 
" + numCalculations + " calculations: " + String.format("%.8f", averageTime/1000) + " s"); + } + } + } + + reader.close(); + } + + @Test + public void testFftExecutionTimeOfOneDimFFT() throws IOException { + String filename = "fft_data.csv"; // Path to your CSV file + BufferedReader reader = new BufferedReader(new FileReader(filename)); + String line; + int lineNumber = 0; + long totalTime = 0; // Total time for all FFT computations + int numCalculations = 0; // Number of FFT computations + + while ((line = reader.readLine()) != null) { + lineNumber++; + String[] values = line.split(","); + int n = values.length / 2; + double[][] input = new double[2][n]; // First row for real, second row for imaginary parts + + for (int i = 0; i < n; i++) { + input[0][i] = Double.parseDouble(values[i]); // Real part + input[1][i] = Double.parseDouble(values[n + i]); // Imaginary part + } + + long startTime = System.nanoTime(); + fft_one_dim(input); + long endTime = System.nanoTime(); + if(lineNumber > 1000){ + totalTime += (endTime - startTime); + numCalculations++; + + if (numCalculations % progressInterval == 0) { + double averageTime = (totalTime / 1e6) / numCalculations; // Average time in milliseconds + System.out.println("fft_one_dim: Average execution time after " + numCalculations + " calculations: " + String.format("%.8f", averageTime/1000) + " s "); + } + } + } + + reader.close(); + } + + + // prior to executing this test it is necessary to run the Numpy Script in FourierTestData.py and add the generated file to the root of the project. 
+ @Test + public void testIfftWithRealNumpyData() throws IOException { + String filename = "ifft_data.csv"; // Path to your CSV file + BufferedReader reader = new BufferedReader(new FileReader(filename)); + String line; + int lineNumber = 0; + + while ((line = reader.readLine()) != null) { + lineNumber++; + String[] values = line.split(","); + int n = values.length / 3; + double[][][] input = new double[2][1][n]; + double[][] expected = new double[2][n]; // First row for real, second row for imaginary parts + + for (int i = 0; i < n; i++) { + input[0][0][i] = Double.parseDouble(values[i]); // Real part of input + // Imaginary part of input is assumed to be 0 + expected[0][i] = Double.parseDouble(values[n + i]); // Real part of expected output + expected[1][i] = Double.parseDouble(values[n * 2 + i]); // Imaginary part of expected output + } + + double[][][] actualResult = fft(input, true); // Perform IFFT + + // Validate the IFFT results + validateFftResults(expected, actualResult, lineNumber); + } + + reader.close(); + } + + private void validateFftResults(double[][] expected, double[][][] actualResult, int lineNumber) { + int length = expected[0].length; + for (int i = 0; i < length; i++) { + double realActual = actualResult[0][0][i]; + double imagActual = actualResult[1][0][i]; + assertEquals("Mismatch in real part at index " + i + " in line " + lineNumber, expected[0][i], realActual, 1e-9); + assertEquals("Mismatch in imaginary part at index " + i + " in line " + lineNumber, expected[1][i], imagActual, 1e-9); + } + if(lineNumber % progressInterval == 0){ + System.out.println("ifft(real input): Finished processing line " + lineNumber); + } + + } + + + + @Test +public void testIfftWithComplexNumpyData() throws IOException { Review Comment: resolved -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. 
To unsubscribe, e-mail: dev-unsubscribe@systemds.apache.org. For queries about this service, please contact Infrastructure at: users@infra.apache.org