This is an automated email from the ASF dual-hosted git repository.

rkk pushed a commit to branch tmp-stv
in repository https://gitbox.apache.org/repos/asf/sdap-nexus.git

commit b25ab4e6bfd1c7f641c7c09b20b586622efd63c4
Author: rileykk <[email protected]>
AuthorDate: Wed Sep 27 10:04:47 2023 -0700

    Image rendering
---
 analysis/webservice/algorithms/Tomogram.py | 157 +++++++++++++++++++++++------
 analysis/webservice/algorithms/__init__.py |   1 +
 data-access/nexustiles/nexustiles.py       |   2 +
 3 files changed, 128 insertions(+), 32 deletions(-)

diff --git a/analysis/webservice/algorithms/Tomogram.py b/analysis/webservice/algorithms/Tomogram.py
index d43a3a8..0507e42 100644
--- a/analysis/webservice/algorithms/Tomogram.py
+++ b/analysis/webservice/algorithms/Tomogram.py
@@ -14,17 +14,16 @@
 # limitations under the License.
 
 import logging
+from io import BytesIO
 from typing import Dict, Literal, Union
-from numbers import Number
+
+import matplotlib.pyplot as plt
 import numpy as np
 import xarray as xr
-import matplotlib.pyplot as plt
 from webservice.NexusHandler import nexus_handler
 from webservice.algorithms.NexusCalcHandler import NexusCalcHandler
-from webservice.algorithms.DataInBoundsSearch import DataInBoundsSearchCalcHandlerImpl as Subset
 from webservice.webmodel import NexusResults, NexusProcessingException
 
-
 logger = logging.getLogger(__name__)
 
 
@@ -32,16 +31,19 @@ class TomogramBaseClass(NexusCalcHandler):
     def __init__(
             self,
             tile_service_factory,
-            slice_bounds: Dict[Literal['lat', 'lon', 'elevation'], Union[slice, Number]],
             **kwargs
     ):
         NexusCalcHandler.__init__(self, tile_service_factory)
-        self.__slice_bounds = slice_bounds
-        self.__margin = kwargs.get('margin', 0.001)
+        self.__slice_bounds = None
+        self.__margin = None
 
-        # slice bounds: dict relating dimension names to desired slicing
-        # When dealing with multi-var tiles, optional parameter to pick variable to plot,
-        # otherwise issue warning and pick the first one
+    def _set_params(
+            self,
+            slice_bounds: Dict[Literal['lat', 'lon', 'elevation'], Union[slice, float]],
+            margin: float
+    ):
+        self.__slice_bounds = slice_bounds
+        self.__margin = margin
 
     def parse_args(self, compute_options):
         try:
@@ -53,13 +55,10 @@ class TomogramBaseClass(NexusCalcHandler):
 
         return ds, parameter_s
 
-    def do_subset(self, compute_options):
+    def do_subset(self, ds, parameter):
         tile_service = self._get_tile_service()
 
-        ds, parameter = self.parse_args(compute_options)
-
         bounds = self.__slice_bounds
-        sel = {}
 
         if isinstance(bounds['lat'], slice):
             min_lat = bounds['lat'].start
@@ -68,8 +67,6 @@ class TomogramBaseClass(NexusCalcHandler):
             min_lat = bounds['lat'] - self.__margin
             max_lat = bounds['lat'] + self.__margin
 
-            sel = dict(lat=bounds['lat'])
-
         if isinstance(bounds['lon'], slice):
             min_lon = bounds['lon'].start
             max_lon = bounds['lon'].stop
@@ -77,8 +74,6 @@ class TomogramBaseClass(NexusCalcHandler):
             min_lon = bounds['lon'] - self.__margin
             max_lon = bounds['lon'] + self.__margin
 
-            sel = dict(lon=bounds['lon'])
-
         if isinstance(bounds['elevation'], slice):
             min_elevation = bounds['elevation'].start
             max_elevation = bounds['elevation'].stop
@@ -86,8 +81,6 @@ class TomogramBaseClass(NexusCalcHandler):
             min_elevation = bounds['elevation'] - self.__margin
             max_elevation = bounds['elevation'] + self.__margin
 
-            sel = dict(elevation=bounds['elevation'])
-
         tiles = tile_service.find_tiles_in_box(
             min_lat, max_lat, min_lon, max_lon,
             ds=ds,
@@ -100,7 +93,7 @@ class TomogramBaseClass(NexusCalcHandler):
 
         data = []
 
-        for i in range(len(tiles)-1, -1, -1): # tile in tiles:
+        for i in range(len(tiles)-1, -1, -1):
             tile = tiles.pop(i)
 
             tile_id = tile.tile_id
@@ -144,14 +137,14 @@ class TomogramBaseClass(NexusCalcHandler):
         return data
 
 
-# @nexus_handler
-class LatitudeTomogramImpl(TomogramBaseClass):
-    pass
-
-
-# @nexus_handler
-class LongitudeTomogramImpl(TomogramBaseClass):
-    pass
+# # @nexus_handler
+# class LatitudeTomogramImpl(TomogramBaseClass):
+#     pass
+#
+#
+# # @nexus_handler
+# class LongitudeTomogramImpl(TomogramBaseClass):
+#     pass
 
 
 @nexus_handler
@@ -174,12 +167,17 @@ class ElevationTomogramImpl(TomogramBaseClass):
             "name": "Bounding box",
             "type": "comma-delimited float",
             "description": "Minimum (Western) Longitude, Minimum (Southern) Latitude, "
-                           "Maximum (Eastern) Longitude, Maximum (Northern) Latitude. Required if 'metadataFilter' not provided"
+                           "Maximum (Eastern) Longitude, Maximum (Northern) Latitude. Required."
         },
         "elevation": {
             "name": "Slice elevation",
             "type": "float",
-            "description": "The "
+            "description": "The desired elevation of the tomogram slice"
+        },
+        "margin": {
+            "name": "Margin",
+            "type": "float",
+            "description": "Margin +/- desired elevation to include in output. Default: 0.5m"
         }
     }
     singleton = True
@@ -188,7 +186,102 @@ class ElevationTomogramImpl(TomogramBaseClass):
         TomogramBaseClass.__init__(self, tile_service_factory, **kwargs)
 
     def parse_args(self, compute_options):
-        pass
+        try:
+            bounding_poly = compute_options.get_bounding_polygon()
+        except:
+            raise NexusProcessingException(reason='Missing required parameter: b', code=400)
+
+        elevation = compute_options.get_float_arg('elevation', None)
+
+        if elevation is None:
+            raise NexusProcessingException(reason='Missing required parameter: elevation', code=400)
+
+        margin = compute_options.get_float_arg('margin', 0.5)
+
+        ds, parameter = super().parse_args(compute_options)
+
+        return ds, parameter, bounding_poly, elevation, margin
 
     def calc(self, computeOptions, **args):
+        ds, parameter, bounding_poly, elevation, margin = self.parse_args(computeOptions)
+
+        min_lat = bounding_poly.bounds[1]
+        max_lat = bounding_poly.bounds[3]
+        min_lon = bounding_poly.bounds[0]
+        max_lon = bounding_poly.bounds[2]
+
+        slices = dict(
+            lat=slice(min_lat, max_lat),
+            lon=slice(min_lon, max_lon),
+            elevation=float(elevation)
+        )
+
+        self._set_params(slices, margin)
+        data_in_bounds = self.do_subset(ds, parameter)
+
+        lats = np.unique([d['latitude'] for d in data_in_bounds])
+        lons = np.unique([d['longitude'] for d in data_in_bounds])
+
+        vals = np.empty((len(lats), len(lons)))
+
+        data_dict = {(d['latitude'], d['longitude']): d['data'] for d in data_in_bounds}
+
+        for i, lat in enumerate(lats):
+            for j, lon in enumerate(lons):
+                vals[i, j] = data_dict.get((lat, lon), np.nan)
+
+        ds = xr.Dataset(
+            data_vars=dict(
+                tomo=(('latitude', 'longitude'), vals)
+            ),
+            coords=dict(
+                latitude=(['latitude'], lats),
+                longitude=(['longitude'], lons)
+            ),
+            attrs=dict(
+                ds=ds,
+                elevation=elevation,
+                margin=margin
+            )
+        )
+
+        result = ElevationTomoResults(
+            (vals, lats, lons, dict(ds=ds, elevation=elevation, margin=margin)),
+        )
+
+        return result
+
+class ElevationTomoResults(NexusResults):
+    def __init__(self, results=None, meta=None, stats=None, computeOptions=None, status_code=200, **args):
+        NexusResults.__init__(self, results, meta, stats, computeOptions, status_code, **args)
+
+    def toImage(self):
+        data, lats, lons, attrs = self.results()
+
+        lats = lats.tolist()
+        lons = lons.tolist()
+
+        plt.figure(figsize=(15,11))
+        plt.imshow(np.squeeze(10*np.log10(data)), vmax=-10, vmin=-30)
+        plt.colorbar(label=f'Tomogram, z={attrs["elevation"]} ± {attrs["margin"]} m (dB)')
+        plt.title(f'Tomogram @ {attrs["elevation"]} ± {attrs["margin"]} m elevation')
+        plt.ylabel('Latitude')
+        plt.xlabel('Longitude')
+
+        xticks, xlabels = plt.xticks()
+        xlabels = [f'{lons[int(t)]:.4f}' if int(t) in range(len(lons)) else '' for t in xticks]
+        plt.xticks(xticks, xlabels, rotation=-90)
+
+        yticks, ylabels = plt.yticks()
+        ylabels = [f'{lats[int(t)]:.4f}' if int(t) in range(len(lats)) else '' for t in yticks]
+        plt.yticks(yticks, ylabels, )
+
+        buffer = BytesIO()
+
+        plt.savefig(buffer, format='png', facecolor='white')
+
+        buffer.seek(0)
+        return buffer.read()
+
+    def toJson(self):
         pass
diff --git a/analysis/webservice/algorithms/__init__.py b/analysis/webservice/algorithms/__init__.py
index 6063009..7ac3ed8 100644
--- a/analysis/webservice/algorithms/__init__.py
+++ b/analysis/webservice/algorithms/__init__.py
@@ -30,3 +30,4 @@ from . import TileSearch
 from . import TimeAvgMap
 from . import TimeSeries
 from . import TimeSeriesSolr
+from . import Tomogram
diff --git a/data-access/nexustiles/nexustiles.py b/data-access/nexustiles/nexustiles.py
index 85a9e0f..ffa9b26 100644
--- a/data-access/nexustiles/nexustiles.py
+++ b/data-access/nexustiles/nexustiles.py
@@ -462,8 +462,10 @@ class NexusTileService(object):
 
                 num_vars = len(tile.data)
                 multi_data_mask = np.repeat(data_mask[np.newaxis, ...], num_vars, axis=0)
+                multi_data_mask = np.broadcast_arrays(multi_data_mask, tile.data)[0]
                 tile.data = ma.masked_where(multi_data_mask, tile.data)
             else:
+                data_mask = np.broadcast_arrays(data_mask, tile.data)[0]
                 tile.data = ma.masked_where(data_mask, tile.data)
 
         tiles[:] = [tile for tile in tiles if not tile.data.mask.all()]

Reply via email to