This is an automated email from the ASF dual-hosted git repository.

machristie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/airavata-django-portal-sdk.git

commit 20679297e05b6580ec22ae4860be384a4a5ffddb
Author: Marcus Christie <[email protected]>
AuthorDate: Fri Nov 19 10:54:51 2021 -0500

    AIRAVATA-3542 download experiments endpoint
---
 airavata_django_portal_sdk/serializers.py | 14 +++++++++
 airavata_django_portal_sdk/urls.py        | 10 ++++---
 airavata_django_portal_sdk/views.py       | 49 ++++++++++++++++++++++++++++---
 3 files changed, 65 insertions(+), 8 deletions(-)

diff --git a/airavata_django_portal_sdk/serializers.py b/airavata_django_portal_sdk/serializers.py
new file mode 100644
index 0000000..2fc09b5
--- /dev/null
+++ b/airavata_django_portal_sdk/serializers.py
@@ -0,0 +1,14 @@
+from rest_framework import serializers
+
+# TODO: per-experiment include patterns (a nested DownloadIncludeSerializer).
+
+
+class ExperimentDownloadSerializer(serializers.Serializer):
+    """One experiment to include in a multi-experiment download."""
+    experiment_id = serializers.CharField()
+
+
+class MultiExperimentDownloadSerializer(serializers.Serializer):
+    """Request body for download_experiments; at least one experiment is required."""
+    experiments = ExperimentDownloadSerializer(many=True, allow_empty=False)
+    # TODO: filename parameter?
diff --git a/airavata_django_portal_sdk/urls.py b/airavata_django_portal_sdk/urls.py
index c59f5dd..f07c977 100644
--- a/airavata_django_portal_sdk/urls.py
+++ b/airavata_django_portal_sdk/urls.py
@@ -15,8 +15,10 @@ def get_download_url(data_product_uri):
 
 app_name = 'airavata_django_portal_sdk'
 urlpatterns = [
-    path('download-file', views.download_file, name='download_file'),
-    path('download', views.download, name='download'),
-    path('download-dir', views.download_dir, name='download_dir'),
-    path('download-experiment-dir/<experiment_id>', views.download_experiment_dir, name='download_experiment_dir'),
+    path('download-file/', views.download_file, name='download_file'),
+    path('download/', views.download, name='download'),
+    path('download-dir/', views.download_dir, name='download_dir'),
+    path('download-experiment-dir/<experiment_id>/', views.download_experiment_dir, name='download_experiment_dir'),
+    path('download-experiments/<download_id>/', views.download_experiments, name='download_experiments'),  # GET: fetch the zip
+    path('download-experiments/', views.download_experiments, name='download_experiments'),  # POST; same name, reverse() picks the route by its args
 ]
diff --git a/airavata_django_portal_sdk/views.py b/airavata_django_portal_sdk/views.py
index 9c1f0ba..0592c2d 100644
--- a/airavata_django_portal_sdk/views.py
+++ b/airavata_django_portal_sdk/views.py
@@ -1,16 +1,20 @@
 import logging
 import os
 import tempfile
+import uuid
 import zipfile
 
 from django.core.exceptions import ObjectDoesNotExist
 from django.http import FileResponse, Http404
 from django.shortcuts import redirect
+from django.urls import reverse
 from django.utils.text import get_valid_filename
 from django.views.decorators.gzip import gzip_page
+from rest_framework import status
 from rest_framework.decorators import api_view
+from rest_framework.response import Response
 
-from airavata_django_portal_sdk import user_storage
+from airavata_django_portal_sdk import serializers, user_storage
 
 logger = logging.getLogger(__name__)
 
@@ -96,6 +100,41 @@ def download_experiment_dir(request, experiment_id=None):
     return FileResponse(fp, as_attachment=True, filename=filename)
 
 
+@api_view(['GET', 'POST'])
+def download_experiments(request, download_id=None):
+    """Two-step download of several experiments' data as a single zip file.
+
+    POST: validate the body with MultiExperimentDownloadSerializer, store it in
+    the session under a fresh download id and respond with
+    ``{"download_url": ...}`` pointing back at this view.
+    GET with ``download_id``: respond with a zip containing each requested
+    experiment's files, one top-level folder per experiment.
+    A GET without ``download_id`` gets a 400; an unknown ``download_id`` a 404.
+    """
+    if request.method == 'POST':
+        serializer = serializers.MultiExperimentDownloadSerializer(data=request.data)
+        if serializer.is_valid():
+            download_id = str(uuid.uuid4())
+            request.session["download_experiments:" + download_id] = serializer.validated_data
+            download_url = reverse('airavata_django_portal_sdk:download_experiments',
+                                   args=[download_id])
+            return Response({"download_url": download_url})
+        else:
+            return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
+    elif request.method == 'GET' and download_id is not None:
+        download_key = f"download_experiments:{download_id}"
+        if download_key in request.session:
+            # NOTE(review): nothing here removes the session entry, so the
+            # download URL stays usable for the lifetime of the session.
+            experiments = request.session[download_key]['experiments']
+            fp = tempfile.TemporaryFile()
+            try:
+                with zipfile.ZipFile(fp, 'w', compression=zipfile.ZIP_DEFLATED) as zf:
+                    used_prefixes = set()
+                    for experiment_data in experiments:
+                        experiment_id = experiment_data['experiment_id']
+                        # Load experiment to make sure user has access to experiment
+                        experiment = request.airavata_client.getExperiment(request.authz_token, experiment_id)
+                        # Experiment names need not be unique: give every experiment its
+                        # own top-level folder so zip entries never collide.
+                        base_prefix = get_valid_filename(experiment.experimentName)
+                        zipfile_prefix = base_prefix
+                        suffix = 2
+                        while zipfile_prefix in used_prefixes:
+                            zipfile_prefix = f"{base_prefix}_{suffix}"
+                            suffix += 1
+                        used_prefixes.add(zipfile_prefix)
+                        _add_experiment_directory_to_zipfile(request, zf, experiment_id, path="",
+                                                             zipfile_prefix=zipfile_prefix)
+            except Exception:
+                # No FileResponse will own the temp file on failure, so close it here.
+                fp.close()
+                raise
+
+            filename = "experiments.zip"
+            fp.seek(0)
+            # FileResponse will automatically close the temporary file
+            return FileResponse(fp, as_attachment=True, filename=filename)
+        else:
+            return Response({"detail": "Not found."}, status=status.HTTP_404_NOT_FOUND)
+    else:
+        return Response({"detail": "Bad request"}, status=status.HTTP_400_BAD_REQUEST)
+
+
 def _add_directory_to_zipfile(request, zf, path, directory=""):
     directories, files = user_storage.listdir(request, os.path.join(path, 
directory))
     for file in files:
@@ -107,12 +146,14 @@ def _add_directory_to_zipfile(request, zf, path, directory=""):
         _add_directory_to_zipfile(request, zf, path, os.path.join(directory, 
d['name']))
 
 
-def _add_experiment_directory_to_zipfile(request, zf, experiment_id, path, directory=""):
+def _add_experiment_directory_to_zipfile(request, zf, experiment_id, path, directory="", zipfile_prefix=""):
+    """Recursively add the files under ``os.path.join(path, directory)`` of an
+    experiment (as listed by user_storage.list_experiment_dir) to ``zf``.
+
+    Entries are written as ``zipfile_prefix/directory/<name>``. Raises Exception
+    when the zip file grows past MAX_DOWNLOAD_ZIPFILE_SIZE.
+    """
     directories, files = user_storage.list_experiment_dir(request, experiment_id, os.path.join(path, directory))
     for file in files:
         o = user_storage.open_file(request, data_product_uri=file['data-product-uri'])
-        zf.writestr(os.path.join(directory, file['name']), o.read())
+        try:
+            zf.writestr(os.path.join(zipfile_prefix, directory, file['name']), o.read())
+        finally:
+            # Release the file object returned by open_file once its bytes are copied.
+            o.close()
+        # NOTE(review): zf.filename comes from the underlying file object; for a
+        # TemporaryFile it may be a file descriptor and buffered bytes are not yet
+        # counted, so this limit is approximate — confirm that is acceptable.
         if os.path.getsize(zf.filename) > MAX_DOWNLOAD_ZIPFILE_SIZE:
             raise Exception(f"Zip file size exceeds max of {MAX_DOWNLOAD_ZIPFILE_SIZE} bytes")
     for d in directories:
-        _add_experiment_directory_to_zipfile(request, zf, experiment_id, path, os.path.join(directory, d['name']))
+        _add_experiment_directory_to_zipfile(request, zf, experiment_id, path,
+                                             directory=os.path.join(directory, d['name']),
+                                             zipfile_prefix=zipfile_prefix)

Reply via email to