This is an automated email from the ASF dual-hosted git repository.

rkk pushed a commit to branch tmp-stv
in repository https://gitbox.apache.org/repos/asf/sdap-nexus.git

commit b45846e1cf16953177735ca5844b4e23e640b4ad
Author: rileykk <[email protected]>
AuthorDate: Wed Dec 20 09:29:24 2023 -0800

    Additional arguments and render types
---
 analysis/webservice/algorithms/Tomogram3D.py | 382 ++++++++++++++++++++-------
 1 file changed, 283 insertions(+), 99 deletions(-)

diff --git a/analysis/webservice/algorithms/Tomogram3D.py b/analysis/webservice/algorithms/Tomogram3D.py
index 1bdccb5..cc3d97c 100644
--- a/analysis/webservice/algorithms/Tomogram3D.py
+++ b/analysis/webservice/algorithms/Tomogram3D.py
@@ -24,6 +24,7 @@ import matplotlib.pyplot as plt
 import numpy as np
 import pandas as pd
 import requests
+from matplotlib.colors import XKCD_COLORS
 from mpl_toolkits.basemap import Basemap
 from PIL import Image
 from webservice.NexusHandler import nexus_handler
@@ -88,6 +89,25 @@ class Tomogram3D(NexusCalcHandler):
             "description": "If output==GIF, specifies the duration of each frame in the animation in milliseconds. "
                            "Default: 100; Range: >=100"
         },
+        "filter": {
+            "name": "Voxel filter",
+            "type": "comma-delimited pair of numbers",
+            "description": "Pair of numbers min,max. If defined, will filter out all points in the tomogram whose value"
+                           " does not satisfy min <= v <= max. Can be left unbounded, eg: 'filter=-15,' would filter "
+                           "everything but points >= -15"
+        },
+        "renderType": {
+            "name": "Render Type",
+            "type": "string",
+            "description": "Type of render: Must be either \"scatter\" or \"peak\". Scatter will plot every voxel in "
+                           "the tomogram, colored by value, peak will plot a surface of all peaks for each lat/lon "
+                           "grid point, flat colored and shaded. Default: scatter"
+        },
+        "hideBasemap": {
+            "name": "Hide Basemap",
+            "type": "boolean",
+            "description": "If true, do not draw basemap beneath plot. This can be used to speed up render time a bit"
+        },
     }
     singleton = True
 
@@ -161,11 +181,49 @@ class Tomogram3D(NexusCalcHandler):
 
             orbit_elev, orbit_step = None, None
 
-        return ds, parameter_s, bounding_poly, min_elevation, max_elevation, (orbit_elev, orbit_step, frame_duration,
-                                                                              view_azim, view_elev)
+        filter_arg = compute_options.get_argument('filter')
+
+        if filter_arg is not None:
+            try:
+                filter_arg = filter_arg.split(',')
+                assert len(filter_arg) == 2
+                vmin, vmax = filter_arg
+
+                if vmin != '':
+                    vmin = float(vmin)
+                else:
+                    vmin = None
+
+                if vmax != '':
+                    vmax = float(vmax)
+                else:
+                    vmax = None
+
+                filter_arg = (vmin, vmax)
+            except:
+                raise NexusProcessingException(
+                    reason='Invalid filter arg, must be of format [number],[number]',
+                    code=400
+                )
+        else:
+            filter_arg = (None, None)
+
+        render_type = compute_options.get_argument('renderType', 'scatter').lower()
+
+        if render_type not in ['scatter', 'peak']:
+            raise NexusProcessingException(
+                reason='renderType must be either scatter or peak',
+                code=400
+            )
+
+        hide_basemap = compute_options.get_boolean_arg('hideBasemap')
+
+        return (ds, parameter_s, bounding_poly, min_elevation, max_elevation,
+                (orbit_elev, orbit_step, frame_duration, view_azim, view_elev), filter_arg, render_type, hide_basemap)
 
     def calc(self, computeOptions, **args):
-        (ds, parameter, bounding_poly, min_elevation, max_elevation, render_params) = self.parse_args(computeOptions)
+        (ds, parameter, bounding_poly, min_elevation, max_elevation, render_params, filter_arg, render_type,
+         hide_basemap) = (self.parse_args(computeOptions))
 
         min_lat = bounding_poly.bounds[1]
         max_lat = bounding_poly.bounds[3]
@@ -254,17 +312,36 @@ class Tomogram3D(NexusCalcHandler):
 
         logger.info(f'DataFrame:\n{df}')
 
-        return Tomogram3DResults(df, render_params, bounds=bounds)
+        return Tomogram3DResults(
+            df,
+            render_params,
+            filter_arg,
+            render_type,
+            bounds=bounds,
+            hide_basemap=hide_basemap,
+        )
 
 
 class Tomogram3DResults(NexusResults):
-    def __init__(self, results=None, render_params=None, bounds=None, meta=None, stats=None, computeOptions=None, status_code=200, **args):
+    def __init__(self, results=None, render_params=None, filter_args=None, render_type='scatter',
+                 bounds=None, meta=None, stats=None, computeOptions=None, status_code=200, **args):
         NexusResults.__init__(self, results, meta, stats, computeOptions, status_code, **args)
         self.render_params = render_params
         self.bounds = bounds
 
+        if filter_args is None:
+            self.filter = (None, None)
+        else:
+            self.filter = filter_args
+
+        self.render_type = render_type
+
+        self.hide_basemap = 'hide_basemap' in args and args['hide_basemap']
+
     def results(self):
         r: pd.DataFrame = NexusResults.results(self)
+        r = Tomogram3DResults.min_max_filter(r, *self.filter)
+
         return r
 
     def __common(self):
@@ -273,76 +350,137 @@ class Tomogram3DResults(NexusResults):
         fig = plt.figure(figsize=(10,7))
         return xyz, (fig, fig.add_subplot(111, projection='3d'))
 
+    @staticmethod
+    def min_max_filter(df: pd.DataFrame, vmin=None, vmax=None):
+        if vmin is not None or vmax is not None:
+            n_points = len(df)
+
+            if vmin is not None:
+                df = df[df['tomo_value'] >= vmin]
+
+            if vmax is not None:
+                df = df[df['tomo_value'] <= vmax]
+
+            logger.info(f'Filtered data from {n_points:,} to {len(df):,} points')
+
+        return df
+
     def toImage(self):
         _, _, _, view_azim, view_elev = self.render_params
 
-        xyz, (fig, ax) = self.__common()
+        results = self.results()
+        xyz = results[['lon', 'lat', 'elevation']].values
+
+        fig = plt.figure(figsize=(10, 7))
+        ax = fig.add_subplot(111, projection='3d')
 
         ax.view_init(elev=view_elev, azim=view_azim)
 
-        min_lat, min_lon, max_lat, max_lon = self.bounds
+        if not self.hide_basemap:
+            min_lat, min_lon, max_lat, max_lon = self.bounds
 
-        m = Basemap(llcrnrlon=min_lon, llcrnrlat=min_lat, urcrnrlat=max_lat, urcrnrlon=max_lon,)
+            m = Basemap(llcrnrlon=min_lon, llcrnrlat=min_lat, urcrnrlat=max_lat, urcrnrlon=max_lon,)
 
-        basemap_size = 512
+            basemap_size = 512
 
-        params = dict(
-            bbox=f'{min_lon},{min_lat},{max_lon},{max_lat}',
-            bboxSR=4326, imageSR=4326,
-            size=f'{basemap_size},{int(m.aspect * basemap_size)}',
-            dpi=2000,
-            format='png32',
-            transparent=True,
-            f='image'
-        )
+            params = dict(
+                bbox=f'{min_lon},{min_lat},{max_lon},{max_lat}',
+                bboxSR=4326, imageSR=4326,
+                size=f'{basemap_size},{int(m.aspect * basemap_size)}',
+                dpi=2000,
+                format='png32',
+                transparent=True,
+                f='image'
+            )
 
-        url = 'https://server.arcgisonline.com/ArcGIS/rest/services/World_Imagery/MapServer/export'
+            url = 'https://server.arcgisonline.com/ArcGIS/rest/services/World_Imagery/MapServer/export'
 
-        logger.info('Pulling basemap')
+            logger.info('Pulling basemap')
 
-        try:
-            elevations = self.results()[['elevation']].values
-            z_range = np.nanmax(elevations) - np.nanmin(elevations)
-            z_coord = np.nanmin(elevations) - (0.1 * z_range)
+            try:
+                elevations = results[['elevation']].values
+                z_range = np.nanmax(elevations) - np.nanmin(elevations)
+                z_coord = np.nanmin(elevations) - (0.1 * z_range)
 
-            r = requests.get(url, params=params)
-            r.raise_for_status()
+                r = requests.get(url, params=params)
+                r.raise_for_status()
 
-            buf = BytesIO(r.content)
+                buf = BytesIO(r.content)
 
-            img = Image.open(buf)
-            img_data = np.array(img)
+                img = Image.open(buf)
+                img_data = np.array(img)
 
-            lats = np.linspace(min_lat, max_lat, num=img.height)
-            lons = np.linspace(min_lon, max_lon, num=img.width)
+                lats = np.linspace(min_lat, max_lat, num=img.height)
+                lons = np.linspace(min_lon, max_lon, num=img.width)
 
-            X, Y = np.meshgrid(lons, lats)
-            Z = np.full(X.shape, z_coord)
+                X, Y = np.meshgrid(lons, lats)
+                Z = np.full(X.shape, z_coord)
 
-            ax.plot_surface(X, Y, Z, rstride=1, cstride=1, facecolors=img_data / 255)
-        except:
-            logger.error('Failed to pull basemap, will not draw it')
-
-        logger.info('Plotting data')
-
-        s = ax.scatter(
-            xyz[:, 0], xyz[:, 1], xyz[:, 2],
-            marker='D',
-            facecolors=self.results()[['red', 'green', 'blue']].values.astype(np.uint8) / 255,
-            c=self.results()[['tomo_value']].values,
-            zdir='z',
-            depthshade=True,
-            cmap=mpl.colormaps['viridis'],
-            vmin=-30, vmax=-10
-        )
+                logger.info('Plotting basemap')
+                ax.plot_surface(X, Y, Z, rstride=1, cstride=1, facecolors=img_data / 255)
+            except:
+                logger.error('Failed to pull basemap, will not draw it')
+
+        if self.render_type == 'scatter':
+            logger.info('Plotting tomogram data')
+            s = ax.scatter(
+                xyz[:, 0], xyz[:, 1], xyz[:, 2],
+                marker='D',
+                facecolors=results[['red', 'green', 'blue']].values.astype(np.uint8) / 255,
+                c=results[['tomo_value']].values,
+                zdir='z',
+                depthshade=True,
+                cmap=mpl.colormaps['viridis'],
+                vmin=-30, vmax=-10
+            )
+
+            cbar = fig.colorbar(s, ax=ax)
+            cbar.set_label('Tomogram (dB)')
+        else:
+            logger.info('Collecting data by lat/lon')
+            lats = np.unique(results['lat'].values)
+            lons = np.unique(results['lon'].values)
+
+            data_dict = {}
+
+            for r in results.itertuples(index=False):
+                key = (r.lon, r.lat)
+
+                if key not in data_dict:
+                    data_dict[key] = ([r.elevation], [r.tomo_value])
+                else:
+                    data_dict[key][0].append(r.elevation)
+                    data_dict[key][1].append(r.tomo_value)
+
+            logger.info('Determining peaks')
+
+            vals = np.empty((len(lats), len(lons)))
+
+            for i, lat in enumerate(lats):
+                for j, lon in enumerate(lons):
+                    if (lon, lat) in data_dict:
+                        elevs, tomo_vals = data_dict[(lon, lat)]
+
+                        i_max = np.argmax(np.array(tomo_vals))
+                        vals[i, j] = elevs[i_max]
+                    else:
+                        vals[i, j] = np.nan
+
+            X2, Y2 = np.meshgrid(lons, lats)
+
+            logger.info('Plotting peak surface')
+            s = ax.plot_surface(
+                X2,
+                Y2,
+                vals,
+                rstride=1, cstride=1,
+                color='xkcd:leaf'
+            )
 
         ax.set_ylabel('Latitude')
         ax.set_xlabel('Longitude')
         ax.set_zlabel('Elevation w.r.t. dataset reference (m)')
 
-        cbar = fig.colorbar(s, ax=ax)
-        cbar.set_label('Tomogram (dB)')
-
         plt.tight_layout()
 
         buffer = BytesIO()
@@ -356,73 +494,119 @@ class Tomogram3DResults(NexusResults):
     def toGif(self):
         orbit_elev, orbit_step, frame_duration, _, _ = self.render_params
 
-        xyz, (fig, ax) = self.__common()
+        results = self.results()
+        xyz = results[['lon', 'lat', 'elevation']].values
+
+        fig = plt.figure(figsize=(10, 7))
+        ax = fig.add_subplot(111, projection='3d')
 
         ax.view_init(elev=orbit_elev, azim=0)
 
-        min_lat, min_lon, max_lat, max_lon = self.bounds
+        if not self.hide_basemap:
+            min_lat, min_lon, max_lat, max_lon = self.bounds
 
-        m = Basemap(llcrnrlon=min_lon, llcrnrlat=min_lat, urcrnrlat=max_lat, urcrnrlon=max_lon, )
+            m = Basemap(llcrnrlon=min_lon, llcrnrlat=min_lat, urcrnrlat=max_lat, urcrnrlon=max_lon, )
 
-        basemap_size = 512
+            basemap_size = 512
 
-        params = dict(
-            bbox=f'{min_lon},{min_lat},{max_lon},{max_lat}',
-            bboxSR=4326, imageSR=4326,
-            size=f'{basemap_size},{int(m.aspect * basemap_size)}',
-            dpi=2000,
-            format='png32',
-            transparent=True,
-            f='image'
-        )
+            params = dict(
+                bbox=f'{min_lon},{min_lat},{max_lon},{max_lat}',
+                bboxSR=4326, imageSR=4326,
+                size=f'{basemap_size},{int(m.aspect * basemap_size)}',
+                dpi=2000,
+                format='png32',
+                transparent=True,
+                f='image'
+            )
 
-        url = 'https://server.arcgisonline.com/ArcGIS/rest/services/World_Imagery/MapServer/export'
+            url = 'https://server.arcgisonline.com/ArcGIS/rest/services/World_Imagery/MapServer/export'
 
-        logger.info('Pulling basemap')
+            logger.info('Pulling basemap')
 
-        try:
-            elevations = self.results()[['elevation']].values
-            z_range = np.nanmax(elevations) - np.nanmin(elevations)
-            z_coord = np.nanmin(elevations) - (0.1 * z_range)
+            try:
+                elevations = results[['elevation']].values
+                z_range = np.nanmax(elevations) - np.nanmin(elevations)
+                z_coord = np.nanmin(elevations) - (0.1 * z_range)
 
-            r = requests.get(url, params=params)
-            r.raise_for_status()
+                r = requests.get(url, params=params)
+                r.raise_for_status()
 
-            buf = BytesIO(r.content)
+                buf = BytesIO(r.content)
 
-            img = Image.open(buf)
-            img_data = np.array(img)
+                img = Image.open(buf)
+                img_data = np.array(img)
 
-            lats = np.linspace(min_lat, max_lat, num=img.height)
-            lons = np.linspace(min_lon, max_lon, num=img.width)
+                lats = np.linspace(min_lat, max_lat, num=img.height)
+                lons = np.linspace(min_lon, max_lon, num=img.width)
 
-            X, Y = np.meshgrid(lons, lats)
-            Z = np.full(X.shape, z_coord)
+                X, Y = np.meshgrid(lons, lats)
+                Z = np.full(X.shape, z_coord)
 
-            ax.plot_surface(X, Y, Z, rstride=1, cstride=1, facecolors=img_data / 255)
-        except:
-            logger.error('Failed to pull basemap, will not draw it')
-
-        logger.info('Plotting data')
-
-        s = ax.scatter(
-            xyz[:, 0], xyz[:, 1], xyz[:, 2],
-            marker='D',
-            facecolors=self.results()[['red', 'green', 'blue']].values.astype(np.uint8) / 255,
-            c=self.results()[['tomo_value']].values,
-            zdir='z',
-            depthshade=True,
-            cmap=mpl.colormaps['viridis'],
-            vmin=-30, vmax=-10
-        )
+                logger.info('Plotting basemap')
+                ax.plot_surface(X, Y, Z, rstride=1, cstride=1, facecolors=img_data / 255)
+            except:
+                logger.error('Failed to pull basemap, will not draw it')
+
+        if self.render_type == 'scatter':
+            logger.info('Plotting tomogram data')
+            s = ax.scatter(
+                xyz[:, 0], xyz[:, 1], xyz[:, 2],
+                marker='D',
+                facecolors=results[['red', 'green', 'blue']].values.astype(np.uint8) / 255,
+                c=results[['tomo_value']].values,
+                zdir='z',
+                depthshade=True,
+                cmap=mpl.colormaps['viridis'],
+                vmin=-30, vmax=-10
+            )
+
+            cbar = fig.colorbar(s, ax=ax)
+            cbar.set_label('Tomogram (dB)')
+        else:
+            logger.info('Collecting data by lat/lon')
+            lats = np.unique(results['lat'].values)
+            lons = np.unique(results['lon'].values)
+
+            data_dict = {}
+
+            for r in results.itertuples(index=False):
+                key = (r.lon, r.lat)
+
+                if key not in data_dict:
+                    data_dict[key] = ([r.elevation], [r.tomo_value])
+                else:
+                    data_dict[key][0].append(r.elevation)
+                    data_dict[key][1].append(r.tomo_value)
+
+            logger.info('Determining peaks')
+
+            vals = np.empty((len(lats), len(lons)))
+
+            for i, lat in enumerate(lats):
+                for j, lon in enumerate(lons):
+                    if (lon, lat) in data_dict:
+                        elevs, tomo_vals = data_dict[(lon, lat)]
+
+                        i_max = np.argmax(np.array(tomo_vals))
+                        vals[i, j] = elevs[i_max]
+                    else:
+                        vals[i, j] = np.nan
+
+            X2, Y2 = np.meshgrid(lons, lats)
+
+            logger.info('Plotting peak surface')
+            ax.plot_surface(
+                X2,
+                Y2,
+                vals,
+                rstride=1, cstride=1,
+                color='xkcd:leaf'
+            )
 
         ax.set_ylabel('Latitude')
         ax.set_xlabel('Longitude')
         ax.set_zlabel('Elevation w.r.t. dataset reference (m)')
 
-        cbar = fig.colorbar(s, ax=ax)
-        cbar.set_label('Tomogram (dB)')
-
         plt.tight_layout()
 
         buffer = BytesIO()

Reply via email to