You are viewing a plain-text version of this content. The hyperlink to the canonical version was not preserved in this copy.
Posted to commits@sdap.apache.org by sk...@apache.org on 2022/06/11 00:38:58 UTC

[incubator-sdap-nexus] 01/02: Fixed DOMS subsetter for satellite data

This is an automated email from the ASF dual-hosted git repository.

skperez pushed a commit to branch SDAP-371
in repository https://gitbox.apache.org/repos/asf/incubator-sdap-nexus.git

commit 8611706e633f098b113e33b80e4b008155eb1555
Author: skorper <st...@gmail.com>
AuthorDate: Fri Jun 10 17:36:31 2022 -0700

    Fixed DOMS subsetter for satellite data
---
 analysis/webservice/algorithms/doms/subsetter.py | 104 ++++++++++++++++++++---
 data-access/nexustiles/model/nexusmodel.py       |  20 +++--
 data-access/nexustiles/nexustiles.py             |  12 ++-
 3 files changed, 119 insertions(+), 17 deletions(-)

diff --git a/analysis/webservice/algorithms/doms/subsetter.py b/analysis/webservice/algorithms/doms/subsetter.py
index 4c1ff97..ed1f552 100644
--- a/analysis/webservice/algorithms/doms/subsetter.py
+++ b/analysis/webservice/algorithms/doms/subsetter.py
@@ -15,17 +15,20 @@
 
 import logging
 import os
+import io
 import tempfile
 import zipfile
+from pytz import timezone
 from datetime import datetime
 
 import requests
 
 from . import BaseDomsHandler
 from webservice.NexusHandler import nexus_handler
-from webservice.webmodel import NexusProcessingException
+from webservice.webmodel import NexusProcessingException, NexusResults
 
 ISO_8601 = '%Y-%m-%dT%H:%M:%S%z'
+EPOCH = timezone('UTC').localize(datetime(1970, 1, 1))
 
 
 def is_blank(my_string):
@@ -121,14 +124,14 @@ class DomsResultsRetrievalHandler(BaseDomsHandler.BaseDomsQueryCalcHandler):
 
         try:
             start_time = request.get_start_datetime()
-            start_time = start_time.strftime("%Y-%m-%dT%H:%M:%SZ")
+            start_time = int((start_time - EPOCH).total_seconds())
         except:
             raise NexusProcessingException(
                 reason="'startTime' argument is required. Can be int value seconds from epoch or string format YYYY-MM-DDTHH:mm:ssZ",
                 code=400)
         try:
             end_time = request.get_end_datetime()
-            end_time = end_time.strftime("%Y-%m-%dT%H:%M:%SZ")
+            end_time = int((end_time - EPOCH).total_seconds())
         except:
             raise NexusProcessingException(
                 reason="'endTime' argument is required. Can be int value seconds from epoch or string format YYYY-MM-DDTHH:mm:ssZ",
@@ -167,6 +170,55 @@ class DomsResultsRetrievalHandler(BaseDomsHandler.BaseDomsQueryCalcHandler):
                bounding_polygon, depth_min, depth_max, platforms
 
     def calc(self, request, **args):
+        primary_ds_name, matchup_ds_names, parameter_s, start_time, end_time, \
+        bounding_polygon, depth_min, depth_max, platforms = self.parse_arguments(request)
+
+        min_lat = max_lat = min_lon = max_lon = None
+        if bounding_polygon:
+            min_lat = bounding_polygon.bounds[1]
+            max_lat = bounding_polygon.bounds[3]
+            min_lon = bounding_polygon.bounds[0]
+            max_lon = bounding_polygon.bounds[2]
+
+            tiles = self._get_tile_service().get_tiles_bounded_by_box(min_lat, max_lat, min_lon,
+                                                                      max_lon, primary_ds_name, start_time,
+                                                                      end_time)
+        else:
+            tiles = []  # todo
+            # tiles = self._get_tile_service().get_tiles_by_metadata(metadata_filter, ds, start_time,
+            #                                                        end_time)
+
+        data = []
+        for tile in tiles:
+            for nexus_point in tile.nexus_point_generator():
+                if tile.is_multi:
+                    data_points = {
+                        tile.variables[idx].standard_name: nexus_point.data_vals[idx]
+                        for idx in range(len(tile.variables))
+                    }
+                else:
+                    data_points = {tile.variables[0].standard_name: nexus_point.data_vals}
+                data.append({
+                    'latitude': nexus_point.latitude,
+                    'longitude': nexus_point.longitude,
+                    'time': nexus_point.time,
+                    'data': data_points
+                })
+        if len(tiles) > 0:
+            meta = [tile.get_summary() for tile in tiles]
+        else:
+            meta = None
+
+        result = SubsetResult(
+            results=data,
+            meta=meta
+        )
+
+        result.extendMeta(min_lat, max_lat, min_lon, max_lon, "", start_time, end_time)
+
+        return result
+
+    def calc2(self, request, **args):
 
         primary_ds_name, matchup_ds_names, parameter_s, start_time, end_time, \
         bounding_polygon, depth_min, depth_max, platforms = self.parse_arguments(request)
@@ -235,18 +287,50 @@ class DomsResultsRetrievalHandler(BaseDomsHandler.BaseDomsQueryCalcHandler):
         return SubsetResult(zip_path)
 
 
-class SubsetResult(object):
-    def __init__(self, zip_path):
-        self.zip_path = zip_path
-
+class SubsetResult(NexusResults):
     def toJson(self):
         raise NotImplementedError
 
+    def toCsv(self):
+        """
+        Convert results to CSV
+        """
+        rows = []
+
+        headers = [
+            'longitude',
+            'latitude',
+            'time'
+        ]
+
+        results = self.results()
+
+        data_variables = set([keys for result in results for keys in result['data'].keys()])
+        headers.extend(data_variables)
+        for i, result in enumerate(results):
+            cols = []
+
+            cols.append(result['longitude'])
+            cols.append(result['latitude'])
+            cols.append(datetime.utcfromtimestamp(result['time']).strftime('%Y-%m-%dT%H:%M:%SZ'))
+
+            for var in data_variables:
+                cols.append(result['data'][var])
+            if i == 0:
+                rows.append(','.join(headers))
+            rows.append(','.join(map(str, cols)))
+
+        return "\r\n".join(rows)
+
     def toZip(self):
-        with open(self.zip_path, 'rb') as zip_file:
-            zip_contents = zip_file.read()
+        csv_contents = self.toCsv()
+
+        buffer = io.BytesIO()
+        with zipfile.ZipFile(buffer, 'a', zipfile.ZIP_DEFLATED) as zip_file:
+            zip_file.writestr('result.csv', csv_contents)
 
-        return zip_contents
+        buffer.seek(0)
+        return buffer.read()
 
     def cleanup(self):
         os.remove(self.zip_path)
diff --git a/data-access/nexustiles/model/nexusmodel.py b/data-access/nexustiles/model/nexusmodel.py
index 753d264..f5c9df6 100644
--- a/data-access/nexustiles/model/nexusmodel.py
+++ b/data-access/nexustiles/model/nexusmodel.py
@@ -126,22 +126,30 @@ class Tile(object):
         return summary
 
     def nexus_point_generator(self, include_nan=False):
+        indices = self.get_indices(include_nan)
+
         if include_nan:
-            for index in np.ndindex(self.data.shape):
+            for index in indices:
                 time = self.times[index[0]]
                 lat = self.latitudes[index[1]]
                 lon = self.longitudes[index[2]]
-                data_val = self.data[index]
-                point = NexusPoint(lat, lon, None, time, index, data_val)
+                if self.is_multi:
+                    data_vals = [data[index] for data in self.data]
+                else:
+                    data_vals = self.data[index]
+                point = NexusPoint(lat, lon, None, time, index, data_vals)
                 yield point
         else:
-            for index in np.transpose(np.ma.nonzero(self.data)):
+            for index in indices:
                 index = tuple(index)
                 time = self.times[index[0]]
                 lat = self.latitudes[index[1]]
                 lon = self.longitudes[index[2]]
-                data_val = self.data[index]
-                point = NexusPoint(lat, lon, None, time, index, data_val)
+                if self.is_multi:
+                    data_vals = [data[index] for data in self.data]
+                else:
+                    data_vals = self.data[index]
+                point = NexusPoint(lat, lon, None, time, index, data_vals)
                 yield point
 
     def get_indices(self, include_nan=False):
diff --git a/data-access/nexustiles/nexustiles.py b/data-access/nexustiles/nexustiles.py
index 7483c2b..88a1687 100644
--- a/data-access/nexustiles/nexustiles.py
+++ b/data-access/nexustiles/nexustiles.py
@@ -260,6 +260,7 @@ class NexusTileService(object):
     def get_tiles_bounded_by_box(self, min_lat, max_lat, min_lon, max_lon, ds=None, start_time=0, end_time=-1,
                                  **kwargs):
         tiles = self.find_tiles_in_box(min_lat, max_lat, min_lon, max_lon, ds, start_time, end_time, **kwargs)
+        tiles = tiles[:1] # TODO REMOVE ME!!!
         tiles = self.mask_tiles_to_bbox(min_lat, max_lat, min_lon, max_lon, tiles)
         if 0 <= start_time <= end_time:
             tiles = self.mask_tiles_to_time_range(start_time, end_time, tiles)
@@ -423,7 +424,16 @@ class NexusTileService(object):
                             | ma.getmaskarray(tile.latitudes)[np.newaxis, :, np.newaxis] \
                             | ma.getmaskarray(tile.longitudes)[np.newaxis, np.newaxis, :]
 
-                tile.data = ma.masked_where(data_mask, tile.data)
+                # If this is multi-var, need to mask each variable separately.
+                if tile.is_multi:
+                    # Combine space/time mask with existing mask on data
+                    data_mask = reduce(np.logical_or, [tile.data[0].mask, data_mask])
+
+                    num_vars = len(tile.data)
+                    multi_data_mask = np.repeat(data_mask[np.newaxis, ...], num_vars, axis=0)
+                    tile.data = ma.masked_where(multi_data_mask, tile.data)
+                else:
+                    tile.data = ma.masked_where(data_mask, tile.data)
 
             tiles[:] = [tile for tile in tiles if not tile.data.mask.all()]