Added unit tests for ocw.dataset_loader.DatasetLoader
"""Unit tests for ocw.dataset_loader.DatasetLoader.

A small temporary netCDF file is created for each test and removed again in
tearDown; a fictitious 'foo' data source exercises the custom-loader hook.
"""
import unittest
import os
import copy
import tempfile
import netCDF4
import numpy as np
from ocw.dataset import Dataset
from ocw.dataset_loader import DatasetLoader


class TestDatasetLoader(unittest.TestCase):
    def setUp(self):
        """Create the temporary netCDF file and the loader configurations."""
        self.file_path = create_netcdf_object()
        self.netCDF_file = netCDF4.Dataset(self.file_path, 'r')
        self.latitudes = self.netCDF_file.variables['latitude'][:]
        self.longitudes = self.netCDF_file.variables['longitude'][:]
        self.times = self.netCDF_file.variables['time'][:]
        self.alt_lats = self.netCDF_file.variables['alt_lat'][:]
        self.alt_lons = self.netCDF_file.variables['alt_lon'][:]
        self.values = self.netCDF_file.variables['value'][:]
        self.values2 = self.values + 1

        # Loader configurations shared by the tests below.
        self.reference_config = {'data_source': 'local',
                                 'file_path': self.file_path,
                                 'variable_name': 'value'}
        self.target_config = copy.deepcopy(self.reference_config)
        self.no_data_source_config = {'file_path': self.file_path,
                                      'variable_name': 'value'}
        self.new_data_source_config = {'data_source': 'foo',
                                       'lats': self.latitudes,
                                       'lons': self.longitudes,
                                       'times': self.times,
                                       'values': self.values2,
                                       'variable': 'value'}

    def tearDown(self):
        # Close the open handle before deleting the temporary file so the
        # netCDF library does not keep a dangling descriptor.
        self.netCDF_file.close()
        os.remove(self.file_path)

    def testInputHasDataSource(self):
        '''
        Make sure input data source is specified for each dataset to be loaded
        '''
        with self.assertRaises(KeyError):
            self.loader = DatasetLoader(self.reference_config,
                                        self.no_data_source_config)

    def testReferenceHasDataSource(self):
        '''
        Make sure ref data source is specified for each dataset to be loaded
        '''
        with self.assertRaises(KeyError):
            self.loader = DatasetLoader(self.reference_config,
                                        self.target_config)
            self.loader.set_reference(**self.no_data_source_config)

    def testTargetHasDataSource(self):
        '''
        Make sure target data source is specified for each dataset to be loaded
        '''
        with self.assertRaises(KeyError):
            self.loader = DatasetLoader(self.reference_config,
                                        self.target_config)
            self.loader.add_target(**self.no_data_source_config)

    def testNewDataSource(self):
        '''
        Ensures that custom data source loaders can be added
        '''
        self.loader = DatasetLoader(self.new_data_source_config,
                                    self.target_config)

        # Here the data_source "foo" represents the Dataset constructor
        self.loader.add_source_loader('foo', build_dataset)
        self.loader.load_datasets()
        self.assertEqual(self.loader.reference_dataset.origin['source'],
                         'foo')
        np.testing.assert_array_equal(self.loader.reference_dataset.values,
                                      self.values2)

    def testExistingDataSource(self):
        '''
        Ensures that existing data source loaders can be added
        '''
        self.loader = DatasetLoader(self.reference_config,
                                    self.target_config)
        self.loader.load_datasets()
        self.assertEqual(self.loader.reference_dataset.origin['source'],
                         'local')
        np.testing.assert_array_equal(self.loader.reference_dataset.values,
                                      self.values)

    def testMultipleTargets(self):
        '''
        Test for when multiple target dataset configs are specified
        '''
        self.loader = DatasetLoader(self.reference_config,
                                    [self.target_config,
                                     self.new_data_source_config])

        # Here the data_source "foo" represents the Dataset constructor
        self.loader.add_source_loader('foo', build_dataset)
        self.loader.load_datasets()
        self.assertEqual(self.loader.target_datasets[0].origin['source'],
                         'local')
        self.assertEqual(self.loader.target_datasets[1].origin['source'],
                         'foo')
        np.testing.assert_array_equal(self.loader.target_datasets[0].values,
                                      self.values)
        np.testing.assert_array_equal(self.loader.target_datasets[1].values,
                                      self.values2)


def build_dataset(*args, **kwargs):
    '''
    Wrapper to Dataset constructor from fictitious 'foo' data_source.
    '''
    origin = {'source': 'foo'}
    return Dataset(*args, origin=origin, **kwargs)


def create_netcdf_object():
    '''
    Write a small temporary netCDF file and return its path.

    The file holds a 3 x 5 x 5 (time, lat, lon) 'value' variable plus
    alternate coordinate variables with unusual names, used to exercise the
    optional arguments of the Dataset constructor.
    '''
    # Use tempfile rather than a hard-coded /tmp path so the test is
    # portable and safe to run concurrently.
    handle, file_path = tempfile.mkstemp(suffix='.nc')
    os.close(handle)
    netCDF_file = netCDF4.Dataset(file_path, 'w', format='NETCDF4')
    # Dimensions
    netCDF_file.createDimension('lat_dim', 5)
    netCDF_file.createDimension('lon_dim', 5)
    netCDF_file.createDimension('time_dim', 3)
    # Variables
    latitudes = netCDF_file.createVariable('latitude', 'd', ('lat_dim',))
    longitudes = netCDF_file.createVariable('longitude', 'd', ('lon_dim',))
    times = netCDF_file.createVariable('time', 'd', ('time_dim',))
    # unusual variable names to test optional arguments for Dataset constructor
    alt_lats = netCDF_file.createVariable('alt_lat', 'd', ('lat_dim',))
    alt_lons = netCDF_file.createVariable('alt_lon', 'd', ('lon_dim',))
    alt_times = netCDF_file.createVariable('alt_time', 'd', ('time_dim',))
    values = netCDF_file.createVariable('value', 'd',
                                        ('time_dim',
                                         'lat_dim',
                                         'lon_dim')
                                        )

    # Latitude and longitude coordinates for the five grid points
    latitudes_data = np.arange(5.)
    longitudes_data = np.arange(150., 155.)
    # Three months of data.
    times_data = np.arange(3)
    # 75 values: 3 times x 5 lats x 5 lons
    values_data = np.arange(75.)
    # Reshape values to a 3D (time, lats, lons) array
    values_data = values_data.reshape(len(times_data), len(latitudes_data),
                                      len(longitudes_data))

    # Ingest values to netCDF file
    latitudes[:] = latitudes_data
    longitudes[:] = longitudes_data
    times[:] = times_data
    alt_lats[:] = latitudes_data + 10
    alt_lons[:] = longitudes_data - 10
    alt_times[:] = times_data
    values[:] = values_data
    # Assign time info to time variable
    netCDF_file.variables['time'].units = 'months since 2001-01-01 00:00:00'
    netCDF_file.variables['alt_time'].units = 'months since 2001-04-01 00:00:00'
    netCDF_file.variables['value'].units = 'foo_units'
    netCDF_file.close()
    return file_path


if __name__ == '__main__':
    unittest.main()