This is an automated email from the ASF dual-hosted git repository. tloubrieu pushed a commit to branch SDAP-268 in repository https://gitbox.apache.org/repos/asf/incubator-sdap-nexus.git
commit f1663a5be5acc5f5df1d279d08915170059ee085 Author: Eamon Ford <[email protected]> AuthorDate: Mon Jul 27 15:04:17 2020 -0700 pass factory method to nexuscalchandlers to create tile service in spark nodes --- analysis/webservice/algorithms/NexusCalcHandler.py | 16 +++++----- .../webservice/algorithms_spark/ClimMapSpark.py | 8 ++--- .../webservice/algorithms_spark/CorrMapSpark.py | 29 +++++++++--------- .../DailyDifferenceAverageSpark.py | 23 +++++++------- .../webservice/algorithms_spark/HofMoellerSpark.py | 35 +++++++++++++--------- .../algorithms_spark/MaximaMinimaSpark.py | 10 +++---- .../algorithms_spark/NexusCalcSparkHandler.py | 12 ++++---- .../webservice/algorithms_spark/TimeAvgMapSpark.py | 9 +++--- .../webservice/algorithms_spark/TimeSeriesSpark.py | 13 ++++---- .../webservice/algorithms_spark/VarianceSpark.py | 16 +++++----- analysis/webservice/webapp.py | 25 ++++++++++------ 11 files changed, 107 insertions(+), 89 deletions(-) diff --git a/analysis/webservice/algorithms/NexusCalcHandler.py b/analysis/webservice/algorithms/NexusCalcHandler.py index b5f220f..bea0842 100644 --- a/analysis/webservice/algorithms/NexusCalcHandler.py +++ b/analysis/webservice/algorithms/NexusCalcHandler.py @@ -22,13 +22,15 @@ class NexusCalcHandler(object): if "params" not in cls.__dict__: raise Exception("Property 'params' has not been defined") - def __init__(self, algorithm_config=None, skipCassandra=False, skipSolr=False): - self.algorithm_config = algorithm_config - self._skipCassandra = skipCassandra - self._skipSolr = skipSolr - self._tile_service = NexusTileService(skipDatastore=self._skipCassandra, - skipMetadatastore=self._skipSolr, - config=self.algorithm_config) + def __init__(self, tile_service_factory, skipCassandra=False, skipSolr=False): + # self.algorithm_config = algorithm_config + # self._skipCassandra = skipCassandra + # self._skipSolr = skipSolr + # self._tile_service = NexusTileService(skipDatastore=self._skipCassandra, + # skipMetadatastore=self._skipSolr, + # config=self.algorithm_config) + self._tile_service_factory = tile_service_factory + self._tile_service = tile_service_factory() def _get_tile_service(self): return self._tile_service diff --git a/analysis/webservice/algorithms_spark/ClimMapSpark.py b/analysis/webservice/algorithms_spark/ClimMapSpark.py index e870a2a..78f11f8 100644 --- a/analysis/webservice/algorithms_spark/ClimMapSpark.py +++ b/analysis/webservice/algorithms_spark/ClimMapSpark.py @@ -25,7 +25,7 @@ from nexustiles.nexustiles import NexusTileService from webservice.NexusHandler import nexus_handler, DEFAULT_PARAMETERS_SPEC from webservice.algorithms_spark.NexusCalcSparkHandler import NexusCalcSparkHandler from webservice.webmodel import NexusResults, NexusProcessingException, NoDataException - +from functools import partial @nexus_handler class ClimMapNexusSparkHandlerImpl(NexusCalcSparkHandler): @@ -35,14 +35,14 @@ class ClimMapNexusSparkHandlerImpl(NexusCalcSparkHandler): params = DEFAULT_PARAMETERS_SPEC @staticmethod - def _map(tile_in_spark): + def _map(tile_service_factory, tile_in_spark): tile_bounds = tile_in_spark[0] (min_lat, max_lat, min_lon, max_lon, min_y, max_y, min_x, max_x) = tile_bounds startTime = tile_in_spark[1] endTime = tile_in_spark[2] ds = tile_in_spark[3] - tile_service = NexusTileService() + tile_service = tile_service_factory() # print 'Started tile', tile_bounds # sys.stdout.flush() tile_inbounds_shape = (max_y - min_y + 1, max_x - min_x + 1) @@ -196,7 +196,7 @@ class ClimMapNexusSparkHandlerImpl(NexusCalcSparkHandler): spark_nparts = self._spark_nparts(nparts_requested) self.log.info('Using {} partitions'.format(spark_nparts)) rdd = self._sc.parallelize(nexus_tiles_spark, spark_nparts) - sum_count_part = rdd.map(self._map) + sum_count_part = rdd.map(partial(self._map, self._tile_service_factory)) sum_count = \ sum_count_part.combineByKey(lambda val: val, lambda x, val: (x[0] + val[0], diff --git a/analysis/webservice/algorithms_spark/CorrMapSpark.py b/analysis/webservice/algorithms_spark/CorrMapSpark.py index 1af8cab..4d2c4fe 100644 --- a/analysis/webservice/algorithms_spark/CorrMapSpark.py +++ b/analysis/webservice/algorithms_spark/CorrMapSpark.py @@ -13,15 +13,12 @@ # See the License for the specific language governing permissions and # limitations under the License. - import json -import math -import logging from datetime import datetime +from functools import partial + import numpy as np -from nexustiles.nexustiles import NexusTileService -# from time import time from webservice.NexusHandler import nexus_handler, DEFAULT_PARAMETERS_SPEC from webservice.algorithms_spark.NexusCalcSparkHandler import NexusCalcSparkHandler from webservice.webmodel import NexusProcessingException, NexusResults, NoDataException @@ -35,7 +32,7 @@ class CorrMapNexusSparkHandlerImpl(NexusCalcSparkHandler): params = DEFAULT_PARAMETERS_SPEC @staticmethod - def _map(tile_in): + def _map(tile_service_factory, tile_in): # Unpack input tile_bounds, start_time, end_time, ds = tile_in (min_lat, max_lat, min_lon, max_lon, @@ -60,7 +57,7 @@ class CorrMapNexusSparkHandlerImpl(NexusCalcSparkHandler): # print 'days_at_a_time = ', days_at_a_time t_incr = 86400 * days_at_a_time - tile_service = NexusTileService() + tile_service = tile_service_factory # Compute the intermediate summations needed for the Pearson # Correlation Coefficient. We use a one-pass online algorithm @@ -194,12 +191,12 @@ class CorrMapNexusSparkHandlerImpl(NexusCalcSparkHandler): self.log.debug('nlats={0}, nlons={1}'.format(self._nlats, self._nlons)) daysinrange = self._get_tile_service().find_days_in_range_asc(self._minLat, - self._maxLat, - self._minLon, - self._maxLon, - self._ds[0], - self._startTime, - self._endTime) + self._maxLat, + self._minLon, + self._maxLon, + self._ds[0], + self._startTime, + self._endTime) ndays = len(daysinrange) if ndays == 0: raise NoDataException(reason="No data found for selected timeframe") @@ -224,7 +221,9 @@ class CorrMapNexusSparkHandlerImpl(NexusCalcSparkHandler): max_time_parts = 72 num_time_parts = min(max_time_parts, ndays) - spark_part_time_ranges = np.tile(np.array([a[[0,-1]] for a in np.array_split(np.array(daysinrange), num_time_parts)]), (len(nexus_tiles_spark),1)) + spark_part_time_ranges = np.tile( + np.array([a[[0, -1]] for a in np.array_split(np.array(daysinrange), num_time_parts)]), + (len(nexus_tiles_spark), 1)) nexus_tiles_spark = np.repeat(nexus_tiles_spark, num_time_parts, axis=0) nexus_tiles_spark[:, 1:3] = spark_part_time_ranges @@ -233,7 +232,7 @@ class CorrMapNexusSparkHandlerImpl(NexusCalcSparkHandler): self.log.info('Using {} partitions'.format(spark_nparts)) rdd = self._sc.parallelize(nexus_tiles_spark, spark_nparts) - sum_tiles_part = rdd.map(self._map) + sum_tiles_part = rdd.map(partial(self._map, self._tile_service_factory)) # print "sum_tiles_part = ",sum_tiles_part.collect() sum_tiles = \ sum_tiles_part.combineByKey(lambda val: val, diff --git a/analysis/webservice/algorithms_spark/DailyDifferenceAverageSpark.py b/analysis/webservice/algorithms_spark/DailyDifferenceAverageSpark.py index 51be431..344927f 100644 --- a/analysis/webservice/algorithms_spark/DailyDifferenceAverageSpark.py +++ b/analysis/webservice/algorithms_spark/DailyDifferenceAverageSpark.py @@ -134,15 +134,18 @@ class DailyDifferenceAverageNexusImplSpark(NexusCalcSparkHandler): # Get tile ids in box tile_ids = [tile.tile_id for tile in self._get_tile_service().find_tiles_in_polygon(bounding_polygon, dataset, - start_seconds_from_epoch, end_seconds_from_epoch, - fetch_data=False, fl='id', - sort=['tile_min_time_dt asc', 'tile_min_lon asc', - 'tile_min_lat asc'], rows=5000)] + start_seconds_from_epoch, end_seconds_from_epoch, + fetch_data=False, fl='id', + sort=['tile_min_time_dt asc', 'tile_min_lon asc', + 'tile_min_lat asc'], rows=5000)] # Call spark_matchup - self.log.debug("Calling Spark Driver") try: - spark_result = spark_anomolies_driver(tile_ids, wkt.dumps(bounding_polygon), dataset, climatology, + spark_result = spark_anomalies_driver(self._tile_service_factory, + tile_ids, + wkt.dumps(bounding_polygon), + dataset, + climatology, sc=self._sc) except Exception as e: self.log.exception(e) @@ -264,7 +267,7 @@ def determine_parllelism(num_tiles): return num_partitions -def spark_anomolies_driver(tile_ids, bounding_wkt, dataset, climatology, sc=None): +def spark_anomalies_driver(tile_service_driver, tile_ids, bounding_wkt, dataset, climatology, sc=None): from functools import partial with DRIVER_LOCK: @@ -297,7 +300,7 @@ def spark_anomolies_driver(tile_ids, bounding_wkt, dataset, climatology, sc=None return sum_cnt_var_tuple[0] / sum_cnt_var_tuple[1], np.sqrt(sum_cnt_var_tuple[2]) result = rdd \ - .mapPartitions(partial(calculate_diff, bounding_wkt=bounding_wkt_b, dataset=dataset_b, + .mapPartitions(partial(calculate_diff, tile_service_driver, bounding_wkt=bounding_wkt_b, dataset=dataset_b, climatology=climatology_b)) \ .reduceByKey(add_tuple_elements) \ .mapValues(compute_avg_and_std) \ @@ -307,7 +310,7 @@ def spark_anomolies_driver(tile_ids, bounding_wkt, dataset, climatology, sc=None return result -def calculate_diff(tile_ids, bounding_wkt, dataset, climatology): +def calculate_diff(tile_service_factory, tile_ids, bounding_wkt, dataset, climatology): from itertools import chain # Construct a list of generators that yield (day, sum, count, variance) @@ -316,7 +319,7 @@ def calculate_diff(tile_ids, bounding_wkt, dataset, climatology): tile_ids = list(tile_ids) if len(tile_ids) == 0: return [] - tile_service = NexusTileService() + tile_service = tile_service_factory() for tile_id in tile_ids: # Get the dataset tile diff --git a/analysis/webservice/algorithms_spark/HofMoellerSpark.py b/analysis/webservice/algorithms_spark/HofMoellerSpark.py index c4bc019..6616ae2 100644 --- a/analysis/webservice/algorithms_spark/HofMoellerSpark.py +++ b/analysis/webservice/algorithms_spark/HofMoellerSpark.py @@ -14,7 +14,6 @@ # limitations under the License. import itertools -import logging from cStringIO import StringIO from datetime import datetime from functools import partial @@ -25,8 +24,8 @@ import numpy as np import shapely.geometry from matplotlib import cm from matplotlib.ticker import FuncFormatter -from nexustiles.nexustiles import NexusTileService from pytz import timezone + from webservice.NexusHandler import nexus_handler from webservice.algorithms_spark.NexusCalcSparkHandler import NexusCalcSparkHandler from webservice.webmodel import NexusResults, NoDataException, NexusProcessingException @@ -41,12 +40,12 @@ LONGITUDE = 1 class HofMoellerCalculator(object): @staticmethod - def hofmoeller_stats(metrics_callback, tile_in_spark): + def hofmoeller_stats(tile_service_factory, metrics_callback, tile_in_spark): (latlon, tile_id, index, min_lat, max_lat, min_lon, max_lon) = tile_in_spark - tile_service = NexusTileService() + tile_service = tile_service_factory() try: # Load the dataset tile tile = tile_service.find_tile_by_id(tile_id, metrics_callback=metrics_callback)[0] @@ -263,7 +262,7 @@ def hof_tuple_to_dict(t, avg_var_name): 'min': t[7]} -def spark_driver(sc, latlon, nexus_tiles_spark, metrics_callback): +def spark_driver(sc, latlon, tile_service_factory, nexus_tiles_spark, metrics_callback): # Parallelize list of tile ids rdd = sc.parallelize(nexus_tiles_spark, determine_parllelism(len(nexus_tiles_spark))) if latlon == 0: @@ -279,7 +278,7 @@ def spark_driver(sc, latlon, nexus_tiles_spark, metrics_callback): # the value is a tuple of intermediate statistics for the specified # coordinate within a single NEXUS tile. metrics_callback(partitions=rdd.getNumPartitions()) - results = rdd.flatMap(partial(HofMoellerCalculator.hofmoeller_stats, metrics_callback)) + results = rdd.flatMap(partial(HofMoellerCalculator.hofmoeller_stats, tile_service_factory, metrics_callback)) # Combine tuples across tiles with input key = (time, lat|lon) # Output a key value pair with key = (time) @@ -349,15 +348,19 @@ class LatitudeTimeHoffMoellerSparkHandlerImpl(BaseHoffMoellerSparkHandlerImpl): nexus_tiles_spark = [(self._latlon, tile.tile_id, x, min_lat, max_lat, min_lon, max_lon) for x, tile in enumerate(self._get_tile_service().find_tiles_in_box(min_lat, max_lat, min_lon, max_lon, - ds, start_time, end_time, - metrics_callback=metrics_record.record_metrics, - fetch_data=False))] + ds, start_time, end_time, + metrics_callback=metrics_record.record_metrics, + fetch_data=False))] print ("Got {} tiles".format(len(nexus_tiles_spark))) if len(nexus_tiles_spark) == 0: raise NoDataException(reason="No data found for selected timeframe") - results = spark_driver(self._sc, self._latlon, nexus_tiles_spark, metrics_record.record_metrics) + results = spark_driver(self._sc, + self._latlon, + self._tile_service_factory, + nexus_tiles_spark, + metrics_record.record_metrics) results = filter(None, results) results = sorted(results, key=lambda entry: entry['time']) for i in range(len(results)): @@ -400,15 +403,19 @@ class LongitudeTimeHoffMoellerSparkHandlerImpl(BaseHoffMoellerSparkHandlerImpl): nexus_tiles_spark = [(self._latlon, tile.tile_id, x, min_lat, max_lat, min_lon, max_lon) for x, tile in enumerate(self._get_tile_service().find_tiles_in_box(min_lat, max_lat, min_lon, max_lon, - ds, start_time, end_time, - metrics_callback=metrics_record.record_metrics, - fetch_data=False))] + ds, start_time, end_time, + metrics_callback=metrics_record.record_metrics, + fetch_data=False))] print ("Got {} tiles".format(len(nexus_tiles_spark))) if len(nexus_tiles_spark) == 0: raise NoDataException(reason="No data found for selected timeframe") - results = spark_driver(self._sc, self._latlon, nexus_tiles_spark, metrics_record.record_metrics) + results = spark_driver(self._sc, + self._latlon, + nexus_tiles_spark, + self._tile_service_factory, + metrics_record.record_metrics) results = filter(None, results) results = sorted(results, key=lambda entry: entry["time"]) diff --git a/analysis/webservice/algorithms_spark/MaximaMinimaSpark.py b/analysis/webservice/algorithms_spark/MaximaMinimaSpark.py index 3bd9698..5b4bd83 100644 --- a/analysis/webservice/algorithms_spark/MaximaMinimaSpark.py +++ b/analysis/webservice/algorithms_spark/MaximaMinimaSpark.py @@ -14,13 +14,11 @@ # limitations under the License. -import math -import logging from datetime import datetime +from functools import partial import numpy as np import shapely.geometry -from nexustiles.nexustiles import NexusTileService from pytz import timezone from webservice.NexusHandler import nexus_handler @@ -207,7 +205,7 @@ class MaximaMinimaSparkHandlerImpl(NexusCalcSparkHandler): self.log.info('Using {} partitions'.format(spark_nparts)) rdd = self._sc.parallelize(nexus_tiles_spark, spark_nparts) - max_min_part = rdd.map(self._map) + max_min_part = rdd.map(partial(self._map, self._tile_service_factory)) max_min_count = \ max_min_part.combineByKey(lambda val: val, lambda x, val: (np.maximum(x[0], val[0]), # Max @@ -283,7 +281,7 @@ class MaximaMinimaSparkHandlerImpl(NexusCalcSparkHandler): # this operates on only one nexus tile bound over time. Can assume all nexus_tiles are the same shape @staticmethod - def _map(tile_in_spark): + def _map(tile_service_factory, tile_in_spark): # tile_in_spark is a spatial tile that corresponds to nexus tiles of the same area tile_bounds = tile_in_spark[0] (min_lat, max_lat, min_lon, max_lon, @@ -291,7 +289,7 @@ class MaximaMinimaSparkHandlerImpl(NexusCalcSparkHandler): startTime = tile_in_spark[1] endTime = tile_in_spark[2] ds = tile_in_spark[3] - tile_service = NexusTileService() + tile_service = tile_service_factory() tile_inbounds_shape = (max_y - min_y + 1, max_x - min_x + 1) diff --git a/analysis/webservice/algorithms_spark/NexusCalcSparkHandler.py b/analysis/webservice/algorithms_spark/NexusCalcSparkHandler.py index 9e77887..fe3541a 100644 --- a/analysis/webservice/algorithms_spark/NexusCalcSparkHandler.py +++ b/analysis/webservice/algorithms_spark/NexusCalcSparkHandler.py @@ -8,6 +8,7 @@ from webservice.webmodel import NexusProcessingException logger = logging.getLogger(__name__) + class NexusCalcSparkHandler(NexusCalcHandler): class SparkJobContext(object): @@ -33,14 +34,15 @@ class NexusCalcSparkHandler(NexusCalcHandler): self.log.debug("Returning %s" % self.job_name) self.spark_job_stack.append(self.job_name) - def __init__(self, algorithm_config=None, sc=None, **kwargs): + def __init__(self, tile_service_factory, sc=None, **kwargs): import inspect - NexusCalcHandler.__init__(self, algorithm_config=algorithm_config, **kwargs) + NexusCalcHandler.__init__(self, tile_service_factory=tile_service_factory, **kwargs) self.spark_job_stack = [] self._sc = sc - max_concurrent_jobs = algorithm_config.getint("spark", "maxconcurrentjobs") if algorithm_config.has_section( - "spark") and algorithm_config.has_option("spark", "maxconcurrentjobs") else 10 + # max_concurrent_jobs = algorithm_config.getint("spark", "maxconcurrentjobs") if algorithm_config.has_section( + # "spark") and algorithm_config.has_option("spark", "maxconcurrentjobs") else 10 + max_concurrent_jobs = 10 self.spark_job_stack = list(["Job %s" % x for x in xrange(1, max_concurrent_jobs + 1)]) self.log = logging.getLogger(__name__) @@ -350,4 +352,4 @@ class NexusCalcSparkHandler(NexusCalcHandler): accumulator=self._sc.accumulator(0)), NumberMetricsField(key='reduce', description='Actual time to reduce results'), NumberMetricsField(key="actual_time", description="Total (actual) time") - ]) \ No newline at end of file + ]) diff --git a/analysis/webservice/algorithms_spark/TimeAvgMapSpark.py b/analysis/webservice/algorithms_spark/TimeAvgMapSpark.py index c668130..6231873 100644 --- a/analysis/webservice/algorithms_spark/TimeAvgMapSpark.py +++ b/analysis/webservice/algorithms_spark/TimeAvgMapSpark.py @@ -13,14 +13,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -import logging from datetime import datetime from functools import partial import numpy as np import shapely.geometry -from nexustiles.nexustiles import NexusTileService from pytz import timezone + from webservice.NexusHandler import nexus_handler from webservice.algorithms_spark.NexusCalcSparkHandler import NexusCalcSparkHandler from webservice.webmodel import NexusResults, NexusProcessingException, NoDataException @@ -198,7 +197,7 @@ class TimeAvgMapNexusSparkHandlerImpl(NexusCalcSparkHandler): rdd = self._sc.parallelize(nexus_tiles_spark, spark_nparts) metrics_record.record_metrics(partitions=rdd.getNumPartitions()) - sum_count_part = rdd.map(partial(self._map, metrics_record.record_metrics)) + sum_count_part = rdd.map(partial(self._map, self._tile_service_factory, metrics_record.record_metrics)) reduce_duration = 0 reduce_start = datetime.now() sum_count = sum_count_part.combineByKey(lambda val: val, @@ -264,14 +263,14 @@ class TimeAvgMapNexusSparkHandlerImpl(NexusCalcSparkHandler): endTime=end_time) @staticmethod - def _map(metrics_callback, tile_in_spark): + def _map(tile_service_factory, metrics_callback, tile_in_spark): tile_bounds = tile_in_spark[0] (min_lat, max_lat, min_lon, max_lon, min_y, max_y, min_x, max_x) = tile_bounds startTime = tile_in_spark[1] endTime = tile_in_spark[2] ds = tile_in_spark[3] - tile_service = NexusTileService() + tile_service = tile_service_factory() tile_inbounds_shape = (max_y - min_y + 1, max_x - min_x + 1) diff --git a/analysis/webservice/algorithms_spark/TimeSeriesSpark.py b/analysis/webservice/algorithms_spark/TimeSeriesSpark.py index bf5963e..43f7f6d 100644 --- a/analysis/webservice/algorithms_spark/TimeSeriesSpark.py +++ b/analysis/webservice/algorithms_spark/TimeSeriesSpark.py @@ -195,7 +195,10 @@ class TimeSeriesSparkHandlerImpl(NexusCalcSparkHandler): spark_nparts = self._spark_nparts(nparts_requested) self.log.info('Using {} partitions'.format(spark_nparts)) results, meta = spark_driver(daysinrange, bounding_polygon, - shortName, metrics_record.record_metrics, spark_nparts=spark_nparts, + shortName, + self._tile_service_factory, + metrics_record.record_metrics, + spark_nparts=spark_nparts, sc=self._sc) if apply_seasonal_cycle_filter: @@ -487,7 +490,7 @@ class TimeSeriesResults(NexusResults): return sio.getvalue() -def spark_driver(daysinrange, bounding_polygon, ds, metrics_callback, fill=-9999., +def spark_driver(daysinrange, bounding_polygon, ds, tile_service_factory, metrics_callback, fill=-9999., spark_nparts=1, sc=None): nexus_tiles_spark = [(bounding_polygon.wkt, ds, list(daysinrange_part), fill) @@ -497,14 +500,14 @@ def spark_driver(daysinrange, bounding_polygon, ds, metrics_callback, fill=-9999 # Launch Spark computations rdd = sc.parallelize(nexus_tiles_spark, spark_nparts) metrics_callback(partitions=rdd.getNumPartitions()) - results = rdd.flatMap(partial(calc_average_on_day, metrics_callback)).collect() + results = rdd.flatMap(partial(calc_average_on_day, tile_service_factory, metrics_callback)).collect() results = list(itertools.chain.from_iterable(results)) results = sorted(results, key=lambda entry: entry["time"]) return results, {} -def calc_average_on_day(metrics_callback, tile_in_spark): +def calc_average_on_day(tile_service_factory, metrics_callback, tile_in_spark): import shapely.wkt from datetime import datetime from pytz import timezone @@ -513,7 +516,7 @@ def calc_average_on_day(metrics_callback, tile_in_spark): (bounding_wkt, dataset, timestamps, fill) = tile_in_spark if len(timestamps) == 0: return [] - tile_service = NexusTileService() + tile_service = tile_service_factory() ds1_nexus_tiles = \ tile_service.get_tiles_bounded_by_polygon(shapely.wkt.loads(bounding_wkt), dataset, diff --git a/analysis/webservice/algorithms_spark/VarianceSpark.py b/analysis/webservice/algorithms_spark/VarianceSpark.py index 698385d..24ffbf0 100644 --- a/analysis/webservice/algorithms_spark/VarianceSpark.py +++ b/analysis/webservice/algorithms_spark/VarianceSpark.py @@ -14,13 +14,11 @@ # limitations under the License. -import math -import logging from datetime import datetime +from functools import partial import numpy as np import shapely.geometry -from nexustiles.nexustiles import NexusTileService from pytz import timezone from webservice.NexusHandler import nexus_handler @@ -207,7 +205,7 @@ class VarianceNexusSparkHandlerImpl(NexusCalcSparkHandler): self.log.info('Using {} partitions'.format(spark_nparts)) rdd = self._sc.parallelize(nexus_tiles_spark, spark_nparts) - sum_count_part = rdd.map(self._map) + sum_count_part = rdd.map(partial(self._map, self._tile_service_factory)) sum_count = \ sum_count_part.combineByKey(lambda val: val, lambda x, val: (x[0] + val[0], @@ -235,7 +233,7 @@ class VarianceNexusSparkHandlerImpl(NexusCalcSparkHandler): self.log.info('Using {} partitions'.format(spark_nparts)) rdd = self._sc.parallelize(nexus_tiles_spark, spark_nparts) - anomaly_squared_part = rdd.map(self._calc_variance) + anomaly_squared_part = rdd.map(partial(self._calc_variance, self._tile_service_factory)) anomaly_squared = \ anomaly_squared_part.combineByKey(lambda val: val, lambda x, val: (x[0] + val[0], @@ -303,7 +301,7 @@ class VarianceNexusSparkHandlerImpl(NexusCalcSparkHandler): endTime=end_time) @staticmethod - def _map(tile_in_spark): + def _map(tile_service_factory, tile_in_spark): # tile_in_spark is a spatial tile that corresponds to nexus tiles of the same area tile_bounds = tile_in_spark[0] (min_lat, max_lat, min_lon, max_lon, @@ -311,7 +309,7 @@ class VarianceNexusSparkHandlerImpl(NexusCalcSparkHandler): startTime = tile_in_spark[1] endTime = tile_in_spark[2] ds = tile_in_spark[3] - tile_service = NexusTileService() + tile_service = tile_service_factory() tile_inbounds_shape = (max_y - min_y + 1, max_x - min_x + 1) @@ -345,7 +343,7 @@ class VarianceNexusSparkHandlerImpl(NexusCalcSparkHandler): return tile_bounds, (sum_tile, cnt_tile) @staticmethod - def _calc_variance(tile_in_spark): + def _calc_variance(tile_service_factory, tile_in_spark): # tile_in_spark is a spatial tile that corresponds to nexus tiles of the same area tile_bounds = tile_in_spark[0] (min_lat, max_lat, min_lon, max_lon, @@ -354,7 +352,7 @@ class VarianceNexusSparkHandlerImpl(NexusCalcSparkHandler): endTime = tile_in_spark[2] ds = tile_in_spark[3] x_bar = tile_in_spark[4] - tile_service = NexusTileService() + tile_service = tile_service_factory() tile_inbounds_shape = (max_y - min_y + 1, max_x - min_x + 1) diff --git a/analysis/webservice/webapp.py b/analysis/webservice/webapp.py index bc232b4..3ee2f09 100644 --- a/analysis/webservice/webapp.py +++ b/analysis/webservice/webapp.py @@ -13,16 +13,18 @@ # See the License for the specific language governing permissions and # limitations under the License. -import os import ConfigParser import importlib import logging import sys +from functools import partial + import pkg_resources import tornado.web -import webservice.algorithms_spark.NexusCalcSparkHandler from tornado.options import define, options, parse_command_line +import webservice.algorithms_spark.NexusCalcSparkHandler +from nexustiles.nexustiles import NexusTileService from webservice import NexusHandler from webservice.nexus_tornado.request.handlers import NexusRequestHandler @@ -101,6 +103,7 @@ if __name__ == "__main__": log.info("Initializing request ThreadPool to %s" % max_request_threads) request_thread_pool = tornado.concurrent.futures.ThreadPoolExecutor(max_request_threads) + tile_service_factory = partial(NexusTileService, False, False, algorithm_config) spark_context = None for clazzWrapper in NexusHandler.AVAILABLE_HANDLERS: if issubclass(clazzWrapper, webservice.algorithms_spark.NexusCalcSparkHandler.NexusCalcSparkHandler): @@ -110,14 +113,18 @@ if __name__ == "__main__": spark = SparkSession.builder.appName("nexus-analysis").getOrCreate() spark_context = spark.sparkContext - handlers.append( - (clazzWrapper.path, NexusRequestHandler, - dict(clazz=clazzWrapper, algorithm_config=algorithm_config, sc=spark_context, - thread_pool=request_thread_pool))) + handlers.append((clazzWrapper.path, + NexusRequestHandler, + dict(clazz=clazzWrapper, + tile_service_factory=tile_service_factory, + sc=spark_context, + thread_pool=request_thread_pool))) else: - handlers.append( - (clazzWrapper.path, NexusRequestHandler, - dict(clazz=clazzWrapper, algorithm_config=algorithm_config, thread_pool=request_thread_pool))) + handlers.append((clazzWrapper.path, + NexusRequestHandler, + dict(clazz=clazzWrapper, + tile_service_factory=tile_service_factory, + thread_pool=request_thread_pool))) class VersionHandler(tornado.web.RequestHandler):
