This is an automated email from the ASF dual-hosted git repository. rkk pushed a commit to branch tmp-stv in repository https://gitbox.apache.org/repos/asf/sdap-nexus.git
commit b45846e1cf16953177735ca5844b4e23e640b4ad Author: rileykk <[email protected]> AuthorDate: Wed Dec 20 09:29:24 2023 -0800 Additional arguments and render types --- analysis/webservice/algorithms/Tomogram3D.py | 382 ++++++++++++++++++++------- 1 file changed, 283 insertions(+), 99 deletions(-) diff --git a/analysis/webservice/algorithms/Tomogram3D.py b/analysis/webservice/algorithms/Tomogram3D.py index 1bdccb5..cc3d97c 100644 --- a/analysis/webservice/algorithms/Tomogram3D.py +++ b/analysis/webservice/algorithms/Tomogram3D.py @@ -24,6 +24,7 @@ import matplotlib.pyplot as plt import numpy as np import pandas as pd import requests +from matplotlib.colors import XKCD_COLORS from mpl_toolkits.basemap import Basemap from PIL import Image from webservice.NexusHandler import nexus_handler @@ -88,6 +89,25 @@ class Tomogram3D(NexusCalcHandler): "description": "If output==GIF, specifies the duration of each frame in the animation in milliseconds. " "Default: 100; Range: >=100" }, + "filter": { + "name": "Voxel filter", + "type": "comma-delimited pair of numbers", + "description": "Pair of numbers min,max. If defined, will filter out all points in the tomogram whose value" + " does not satisfy min <= v <= max. Can be left unbounded, eg: 'filter=-15,' would filter " + "everything but points >= -15" + }, + "renderType": { + "name": "Render Type", + "type": "string", + "description": "Type of render: Must be either \"scatter\" or \"peak\". Scatter will plot every voxel in " + "the tomogram, colored by value, peak will plot a surface of all peaks for each lat/lon " + "grid point, flat colored and shaded. Default: scatter" + }, + "hideBasemap": { + "name": "Hide Basemap", + "type": "boolean", + "description": "If true, do not draw basemap beneath plot. This can be used to speed up render time a bit" + }, } singleton = True @@ -161,11 +181,49 @@ class Tomogram3D(NexusCalcHandler): orbit_elev, orbit_step = None, None - return ds, parameter_s, bounding_poly, min_elevation, max_elevation, (orbit_elev, orbit_step, frame_duration, - view_azim, view_elev) + filter_arg = compute_options.get_argument('filter') + + if filter_arg is not None: + try: + filter_arg = filter_arg.split(',') + assert len(filter_arg) == 2 + vmin, vmax = filter_arg + + if vmin != '': + vmin = float(vmin) + else: + vmin = None + + if vmax != '': + vmax = float(vmax) + else: + vmax = None + + filter_arg = (vmin, vmax) + except: + raise NexusProcessingException( + reason='Invalid filter arg, must be of format [number],[number]', + code=400 + ) + else: + filter_arg = (None, None) + + render_type = compute_options.get_argument('renderType', 'scatter').lower() + + if render_type not in ['scatter', 'peak']: + raise NexusProcessingException( + reason='renderType must be either scatter or peak', + code=400 + ) + + hide_basemap = compute_options.get_boolean_arg('hideBasemap') + + return (ds, parameter_s, bounding_poly, min_elevation, max_elevation, + (orbit_elev, orbit_step, frame_duration, view_azim, view_elev), filter_arg, render_type, hide_basemap) def calc(self, computeOptions, **args): - (ds, parameter, bounding_poly, min_elevation, max_elevation, render_params) = self.parse_args(computeOptions) + (ds, parameter, bounding_poly, min_elevation, max_elevation, render_params, filter_arg, render_type, + hide_basemap) = (self.parse_args(computeOptions)) min_lat = bounding_poly.bounds[1] max_lat = bounding_poly.bounds[3] @@ -254,17 +312,36 @@ class Tomogram3D(NexusCalcHandler): logger.info(f'DataFrame:\n{df}') - return Tomogram3DResults(df, render_params, bounds=bounds) + return Tomogram3DResults( + df, + render_params, + filter_arg, + render_type, + bounds=bounds, + hide_basemap=hide_basemap, + ) class Tomogram3DResults(NexusResults): - def __init__(self, results=None, render_params=None, bounds=None, meta=None, stats=None, computeOptions=None, status_code=200, **args): + def __init__(self, results=None, render_params=None, filter_args=None, render_type='scatter', + bounds=None, meta=None, stats=None, computeOptions=None, status_code=200, **args): NexusResults.__init__(self, results, meta, stats, computeOptions, status_code, **args) self.render_params = render_params self.bounds = bounds + if filter_args is None: + self.filter = (None, None) + else: + self.filter = filter_args + + self.render_type = render_type + + self.hide_basemap = 'hide_basemap' in args and args['hide_basemap'] + def results(self): r: pd.DataFrame = NexusResults.results(self) + r = Tomogram3DResults.min_max_filter(r, *self.filter) + return r def __common(self): @@ -273,76 +350,137 @@ class Tomogram3DResults(NexusResults): fig = plt.figure(figsize=(10,7)) return xyz, (fig, fig.add_subplot(111, projection='3d')) + @staticmethod + def min_max_filter(df: pd.DataFrame, vmin=None, vmax=None): + if vmin is not None or vmax is not None: + n_points = len(df) + + if vmin is not None: + df = df[df['tomo_value'] >= vmin] + + if vmax is not None: + df = df[df['tomo_value'] <= vmax] + + logger.info(f'Filtered data from {n_points:,} to {len(df):,} points') + + return df + def toImage(self): _, _, _, view_azim, view_elev = self.render_params - xyz, (fig, ax) = self.__common() + results = self.results() + xyz = results[['lon', 'lat', 'elevation']].values + + fig = plt.figure(figsize=(10, 7)) + ax = fig.add_subplot(111, projection='3d') ax.view_init(elev=view_elev, azim=view_azim) - min_lat, min_lon, max_lat, max_lon = self.bounds + if not self.hide_basemap: + min_lat, min_lon, max_lat, max_lon = self.bounds - m = Basemap(llcrnrlon=min_lon, llcrnrlat=min_lat, urcrnrlat=max_lat, urcrnrlon=max_lon,) + m = Basemap(llcrnrlon=min_lon, llcrnrlat=min_lat, urcrnrlat=max_lat, urcrnrlon=max_lon,) - basemap_size = 512 + basemap_size = 512 - params = dict( - bbox=f'{min_lon},{min_lat},{max_lon},{max_lat}', - bboxSR=4326, imageSR=4326, - size=f'{basemap_size},{int(m.aspect * basemap_size)}', - dpi=2000, - format='png32', - transparent=True, - f='image' - ) + params = dict( + bbox=f'{min_lon},{min_lat},{max_lon},{max_lat}', + bboxSR=4326, imageSR=4326, + size=f'{basemap_size},{int(m.aspect * basemap_size)}', + dpi=2000, + format='png32', + transparent=True, + f='image' + ) - url = 'https://server.arcgisonline.com/ArcGIS/rest/services/World_Imagery/MapServer/export' + url = 'https://server.arcgisonline.com/ArcGIS/rest/services/World_Imagery/MapServer/export' - logger.info('Pulling basemap') + logger.info('Pulling basemap') - try: - elevations = self.results()[['elevation']].values - z_range = np.nanmax(elevations) - np.nanmin(elevations) - z_coord = np.nanmin(elevations) - (0.1 * z_range) + try: + elevations = results[['elevation']].values + z_range = np.nanmax(elevations) - np.nanmin(elevations) + z_coord = np.nanmin(elevations) - (0.1 * z_range) - r = requests.get(url, params=params) - r.raise_for_status() + r = requests.get(url, params=params) + r.raise_for_status() - buf = BytesIO(r.content) + buf = BytesIO(r.content) - img = Image.open(buf) - img_data = np.array(img) + img = Image.open(buf) + img_data = np.array(img) - lats = np.linspace(min_lat, max_lat, num=img.height) - lons = np.linspace(min_lon, max_lon, num=img.width) + lats = np.linspace(min_lat, max_lat, num=img.height) + lons = np.linspace(min_lon, max_lon, num=img.width) - X, Y = np.meshgrid(lons, lats) - Z = np.full(X.shape, z_coord) + X, Y = np.meshgrid(lons, lats) + Z = np.full(X.shape, z_coord) - ax.plot_surface(X, Y, Z, rstride=1, cstride=1, facecolors=img_data / 255) - except: - logger.error('Failed to pull basemap, will not draw it') - - logger.info('Plotting data') - - s = ax.scatter( - xyz[:, 0], xyz[:, 1], xyz[:, 2], - marker='D', - facecolors=self.results()[['red', 'green', 'blue']].values.astype(np.uint8) / 255, - c=self.results()[['tomo_value']].values, - zdir='z', - depthshade=True, - cmap=mpl.colormaps['viridis'], - vmin=-30, vmax=-10 - ) + logger.info('Plotting basemap') + ax.plot_surface(X, Y, Z, rstride=1, cstride=1, facecolors=img_data / 255) + except: + logger.error('Failed to pull basemap, will not draw it') + + if self.render_type == 'scatter': + logger.info('Plotting tomogram data') + s = ax.scatter( + xyz[:, 0], xyz[:, 1], xyz[:, 2], + marker='D', + facecolors=results[['red', 'green', 'blue']].values.astype(np.uint8) / 255, + c=results[['tomo_value']].values, + zdir='z', + depthshade=True, + cmap=mpl.colormaps['viridis'], + vmin=-30, vmax=-10 + ) + + cbar = fig.colorbar(s, ax=ax) + cbar.set_label('Tomogram (dB)') + else: + logger.info('Collecting data by lat/lon') + lats = np.unique(results['lat'].values) + lons = np.unique(results['lon'].values) + + data_dict = {} + + for r in results.itertuples(index=False): + key = (r.lon, r.lat) + + if key not in data_dict: + data_dict[key] = ([r.elevation], [r.tomo_value]) + else: + data_dict[key][0].append(r.elevation) + data_dict[key][1].append(r.tomo_value) + + logger.info('Determining peaks') + + vals = np.empty((len(lats), len(lons))) + + for i, lat in enumerate(lats): + for j, lon in enumerate(lons): + if (lon, lat) in data_dict: + elevs, tomo_vals = data_dict[(lon, lat)] + + i_max = np.argmax(np.array(tomo_vals)) + vals[i, j] = elevs[i_max] + else: + vals[i, j] = np.nan + + X2, Y2 = np.meshgrid(lons, lats) + + logger.info('Plotting peak surface') + s = ax.plot_surface( + X2, + Y2, + vals, + rstride=1, cstride=1, + color='xkcd:leaf' + ) ax.set_ylabel('Latitude') ax.set_xlabel('Longitude') ax.set_zlabel('Elevation w.r.t. dataset reference (m)') - cbar = fig.colorbar(s, ax=ax) - cbar.set_label('Tomogram (dB)') - plt.tight_layout() buffer = BytesIO() @@ -356,73 +494,119 @@ class Tomogram3DResults(NexusResults): def toGif(self): orbit_elev, orbit_step, frame_duration, _, _ = self.render_params - xyz, (fig, ax) = self.__common() + results = self.results() + xyz = results[['lon', 'lat', 'elevation']].values + + fig = plt.figure(figsize=(10, 7)) + ax = fig.add_subplot(111, projection='3d') ax.view_init(elev=orbit_elev, azim=0) - min_lat, min_lon, max_lat, max_lon = self.bounds + if not self.hide_basemap: + min_lat, min_lon, max_lat, max_lon = self.bounds - m = Basemap(llcrnrlon=min_lon, llcrnrlat=min_lat, urcrnrlat=max_lat, urcrnrlon=max_lon, ) + m = Basemap(llcrnrlon=min_lon, llcrnrlat=min_lat, urcrnrlat=max_lat, urcrnrlon=max_lon, ) - basemap_size = 512 + basemap_size = 512 - params = dict( - bbox=f'{min_lon},{min_lat},{max_lon},{max_lat}', - bboxSR=4326, imageSR=4326, - size=f'{basemap_size},{int(m.aspect * basemap_size)}', - dpi=2000, - format='png32', - transparent=True, - f='image' - ) + params = dict( + bbox=f'{min_lon},{min_lat},{max_lon},{max_lat}', + bboxSR=4326, imageSR=4326, + size=f'{basemap_size},{int(m.aspect * basemap_size)}', + dpi=2000, + format='png32', + transparent=True, + f='image' + ) - url = 'https://server.arcgisonline.com/ArcGIS/rest/services/World_Imagery/MapServer/export' + url = 'https://server.arcgisonline.com/ArcGIS/rest/services/World_Imagery/MapServer/export' - logger.info('Pulling basemap') + logger.info('Pulling basemap') - try: - elevations = self.results()[['elevation']].values - z_range = np.nanmax(elevations) - np.nanmin(elevations) - z_coord = np.nanmin(elevations) - (0.1 * z_range) + try: + elevations = results[['elevation']].values + z_range = np.nanmax(elevations) - np.nanmin(elevations) + z_coord = np.nanmin(elevations) - (0.1 * z_range) - r = requests.get(url, params=params) - r.raise_for_status() + r = requests.get(url, params=params) + r.raise_for_status() - buf = BytesIO(r.content) + buf = BytesIO(r.content) - img = Image.open(buf) - img_data = np.array(img) + img = Image.open(buf) + img_data = np.array(img) - lats = np.linspace(min_lat, max_lat, num=img.height) - lons = np.linspace(min_lon, max_lon, num=img.width) + lats = np.linspace(min_lat, max_lat, num=img.height) + lons = np.linspace(min_lon, max_lon, num=img.width) - X, Y = np.meshgrid(lons, lats) - Z = np.full(X.shape, z_coord) + X, Y = np.meshgrid(lons, lats) + Z = np.full(X.shape, z_coord) - ax.plot_surface(X, Y, Z, rstride=1, cstride=1, facecolors=img_data / 255) - except: - logger.error('Failed to pull basemap, will not draw it') - - logger.info('Plotting data') - - s = ax.scatter( - xyz[:, 0], xyz[:, 1], xyz[:, 2], - marker='D', - facecolors=self.results()[['red', 'green', 'blue']].values.astype(np.uint8) / 255, - c=self.results()[['tomo_value']].values, - zdir='z', - depthshade=True, - cmap=mpl.colormaps['viridis'], - vmin=-30, vmax=-10 - ) + logger.info('Plotting basemap') + ax.plot_surface(X, Y, Z, rstride=1, cstride=1, facecolors=img_data / 255) + except: + logger.error('Failed to pull basemap, will not draw it') + + if self.render_type == 'scatter': + logger.info('Plotting tomogram data') + s = ax.scatter( + xyz[:, 0], xyz[:, 1], xyz[:, 2], + marker='D', + facecolors=results[['red', 'green', 'blue']].values.astype(np.uint8) / 255, + c=results[['tomo_value']].values, + zdir='z', + depthshade=True, + cmap=mpl.colormaps['viridis'], + vmin=-30, vmax=-10 + ) + + cbar = fig.colorbar(s, ax=ax) + cbar.set_label('Tomogram (dB)') + else: + logger.info('Collecting data by lat/lon') + lats = np.unique(results['lat'].values) + lons = np.unique(results['lon'].values) + + data_dict = {} + + for r in results.itertuples(index=False): + key = (r.lon, r.lat) + + if key not in data_dict: + data_dict[key] = ([r.elevation], [r.tomo_value]) + else: + data_dict[key][0].append(r.elevation) + data_dict[key][1].append(r.tomo_value) + + logger.info('Determining peaks') + + vals = np.empty((len(lats), len(lons))) + + for i, lat in enumerate(lats): + for j, lon in enumerate(lons): + if (lon, lat) in data_dict: + elevs, tomo_vals = data_dict[(lon, lat)] + + i_max = np.argmax(np.array(tomo_vals)) + vals[i, j] = elevs[i_max] + else: + vals[i, j] = np.nan + + X2, Y2 = np.meshgrid(lons, lats) + + logger.info('Plotting peak surface') + ax.plot_surface( + X2, + Y2, + vals, + rstride=1, cstride=1, + color='xkcd:leaf' + ) ax.set_ylabel('Latitude') ax.set_xlabel('Longitude') ax.set_zlabel('Elevation w.r.t. dataset reference (m)') - cbar = fig.colorbar(s, ax=ax) - cbar.set_label('Tomogram (dB)') - plt.tight_layout() buffer = BytesIO()
