Skip to content

Commit

Permalink
Merge pull request #32 from nansencenter/syntool_compare_profiles
Browse files Browse the repository at this point in the history
Support new task generating profile comparisons
  • Loading branch information
aperrin66 authored Oct 1, 2024
2 parents 5c8f0b9 + 4fd6d07 commit ae4277b
Show file tree
Hide file tree
Showing 4 changed files with 183 additions and 18 deletions.
11 changes: 6 additions & 5 deletions geospaas_rest_api/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@
geospaas_processing = None

if geospaas_processing:
from .processing_api.models import (Job,
DownloadJob,
ConvertJob,
SyntoolCleanupJob,
HarvestJob)
from geospaas_rest_api.processing_api.models import (Job,
DownloadJob,
ConvertJob,
SyntoolCleanupJob,
SyntoolCompareJob,
HarvestJob)
61 changes: 55 additions & 6 deletions geospaas_rest_api/processing_api/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,11 +86,11 @@ def get_signature(cls, parameters):
tasks_core.crop.signature(
kwargs={'bounding_box': parameters.get('bounding_box', None)}),
])

tasks.extend([
tasks_core.archive.signature(),
tasks_core.publish.signature(),
])
if parameters.get('publish', False):
tasks.extend([
tasks_core.archive.signature(),
tasks_core.publish.signature(),
])
return celery.chain(tasks)

@staticmethod
Expand All @@ -100,7 +100,7 @@ def check_parameters(parameters):
- dataset_id: integer
- bounding_box: 4-elements list
"""
if not set(parameters).issubset(set(('dataset_id', 'bounding_box'))):
if not set(parameters).issubset(set(('dataset_id', 'bounding_box', 'publish'))):
raise ValidationError("The download action accepts only one parameter: 'dataset_id'")
if not isinstance(parameters['dataset_id'], int):
raise ValidationError("'dataset_id' must be an integer")
Expand All @@ -109,6 +109,8 @@ def check_parameters(parameters):
len(parameters['bounding_box']) == 4)):
raise ValidationError("'bounding_box' must be a sequence in the following format: "
"west, north, east, south")
if ('publish' in parameters and not isinstance(parameters['publish'], bool)):
raise ValidationError("'publish' must be a boolean")
return parameters

@staticmethod
Expand Down Expand Up @@ -220,6 +222,53 @@ def make_task_parameters(parameters):
return ((parameters['criteria'],), {})


class SyntoolCompareJob(Job):
    """Job which generates comparison between Argo profiles and a 3D
    product.

    The expected request parameters are:
      - model: a couple (model_id, model_path)
      - profiles: a list of couples (profile_id, profile_path)
    """
    class Meta:
        proxy = True

    @classmethod
    def get_signature(cls, parameters):
        """Return the chain of tasks to run: the `compare_profiles`
        task followed by the `db_insert` task.
        `parameters` is not used to build the chain.
        """
        return celery.chain(
            tasks_syntool.compare_profiles.signature(),
            tasks_syntool.db_insert.signature(),
        )

    @staticmethod
    def check_parameters(parameters):
        """Check that the request parameters are valid and return them
        unchanged.

        Both 'model' and 'profiles' must be present, and no other key
        is accepted.
        Raises a ValidationError if a key is missing or unknown, if
        'model' is not a couple (integer, string) or if 'profiles' is
        not a sequence of such couples.
        """
        def is_id_path_couple(value):
            """Tell whether `value` is a 2-elements sequence made of
            an integer ID and a string path
            """
            return (isinstance(value, Sequence) and
                    len(value) == 2 and
                    isinstance(value[0], int) and
                    # bool is a subclass of int, but True/False are not IDs
                    not isinstance(value[0], bool) and
                    isinstance(value[1], str))

        accepted_keys = ('model', 'profiles')
        if set(parameters) != set(accepted_keys):
            raise ValidationError(
                "The compare_profiles action requires exactly these parameters: "
                f"{', '.join(accepted_keys)}")

        if not is_id_path_couple(parameters['model']):
            raise ValidationError("'model' must be a tuple (model_id, model_path)")

        profiles = parameters['profiles']
        # strings and bytes are sequences too: without this explicit
        # exclusion, an empty string would be accepted as a valid
        # (empty) list of profiles
        if (not isinstance(profiles, Sequence) or
                isinstance(profiles, (str, bytes)) or
                not all(is_id_path_couple(profile) for profile in profiles)):
            raise ValidationError("'profiles' must be a list of tuples (profile_id, profile_path)")
        return parameters

    @staticmethod
    def make_task_parameters(parameters):
        """Build the (args, kwargs) couple given to the first task of
        the chain: a single positional argument (model, profiles) and
        no keyword argument.
        """
        return (((parameters['model'], parameters['profiles']),), {})


class HarvestJob(Job):
"""Job which harvests metadata into the database"""
class Meta:
Expand Down
4 changes: 3 additions & 1 deletion geospaas_rest_api/processing_api/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ class JobSerializer(rest_framework.serializers.Serializer):
'convert': models.ConvertJob,
'harvest': models.HarvestJob,
'syntool_cleanup': models.SyntoolCleanupJob,
'compare_profiles': models.SyntoolCompareJob,
}

# Actual Job fields
Expand All @@ -27,7 +28,8 @@ class JobSerializer(rest_framework.serializers.Serializer):
'download',
'convert',
'harvest',
'syntool_cleanup'
'syntool_cleanup',
'compare_profiles',
],
required=True, write_only=True,
help_text="Action to perform")
Expand Down
125 changes: 119 additions & 6 deletions geospaas_rest_api/tests/test_processing_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,10 +188,16 @@ def test_get_signature_no_cropping(self):
with mock.patch('geospaas_rest_api.processing_api.models.tasks_core') as mock_tasks, \
mock.patch('celery.chain') as mock_chain:
_ = models.DownloadJob.get_signature({})
mock_chain.assert_called_with([
mock_tasks.download.signature.return_value,
mock_tasks.archive.signature.return_value,
mock_tasks.publish.signature.return_value,
_ = models.DownloadJob.get_signature({'publish': True})
mock_chain.assert_has_calls([
mock.call([
mock_tasks.download.signature.return_value,
]),
mock.call([
mock_tasks.download.signature.return_value,
mock_tasks.archive.signature.return_value,
mock_tasks.publish.signature.return_value,
]),
])

def test_get_signature_cropping(self):
Expand All @@ -200,14 +206,21 @@ def test_get_signature_cropping(self):
with mock.patch('geospaas_rest_api.processing_api.models.tasks_core') as mock_tasks, \
mock.patch('celery.chain') as mock_chain:
_ = models.DownloadJob.get_signature({'bounding_box': [0, 20, 20, 0]})
mock_chain.assert_called_with(
[
_ = models.DownloadJob.get_signature({'bounding_box': [0, 20, 20, 0], 'publish': True})
mock_chain.assert_has_calls([
mock.call([
mock_tasks.download.signature.return_value,
mock_tasks.unarchive.signature.return_value,
mock_tasks.crop.signature.return_value,
]),
mock.call([
mock_tasks.download.signature.return_value,
mock_tasks.unarchive.signature.return_value,
mock_tasks.crop.signature.return_value,
mock_tasks.archive.signature.return_value,
mock_tasks.publish.signature.return_value,
])
])
self.assertListEqual(
mock_tasks.crop.signature.call_args[1]['kwargs']['bounding_box'],
[0, 20, 20, 0])
Expand Down Expand Up @@ -256,6 +269,15 @@ def test_check_parameters_wrong_bounding_box_type(self):
with self.assertRaises(ValidationError):
models.DownloadJob.check_parameters({'dataset_id': 1, 'bounding_box': [2]})

def test_check_parameters_wrong_publish_type(self):
    """`check_parameters()` must reject a 'publish' value which is
    not a boolean, even one that looks like a boolean (string or
    integer)
    """
    for not_a_boolean in ('False', 1):
        invalid_parameters = {'dataset_id': 1, 'publish': not_a_boolean}
        with self.assertRaises(ValidationError):
            models.DownloadJob.check_parameters(invalid_parameters)


class ConvertJobTests(unittest.TestCase):
"""Tests for the ConvertJob class"""
Expand Down Expand Up @@ -426,6 +448,97 @@ def test_make_task_parameters(self):
(({'id': 539},), {}))


class SyntoolCompareJobTests(unittest.TestCase):
    """Tests for the SyntoolCompareJob class"""

    def test_get_signature(self):
        """The signature must be a chain made of the `compare_profiles`
        and `db_insert` task signatures
        """
        with mock.patch(
                'geospaas_rest_api.processing_api.models.tasks_syntool') as syntool_mock, \
                mock.patch('celery.chain') as chain_mock:
            _ = models.SyntoolCompareJob.get_signature({})
            expected_signatures = (
                syntool_mock.compare_profiles.signature.return_value,
                syntool_mock.db_insert.signature.return_value,
            )
            chain_mock.assert_called_once_with(*expected_signatures)

    def test_check_parameters_ok(self):
        """Valid parameters must be returned as they are by
        check_parameters()
        """
        model = (123, '/foo')
        profiles = ((456, '/bar'), (789, '/baz'))
        checked = models.SyntoolCompareJob.check_parameters(
            {'model': model, 'profiles': profiles})
        self.assertDictEqual(checked, {'model': model, 'profiles': profiles})

    def test_check_parameters_unknown(self):
        """A parameter which is not known by the job must trigger a
        ValidationError
        """
        unknown_parameters = {'foo': 'bar'}
        with self.assertRaises(ValidationError):
            models.SyntoolCompareJob.check_parameters(unknown_parameters)

    def test_check_parameters_no_criteria(self):
        """Empty parameters (neither 'model' nor 'profiles') must
        trigger a ValidationError
        """
        no_parameters = {}
        with self.assertRaises(ValidationError):
            models.SyntoolCompareJob.check_parameters(no_parameters)

    def test_check_parameters_wrong_type(self):
        """Every parameter whose type is not the expected one must
        trigger a ValidationError
        """
        valid_model = (123, '/foo')
        invalid_parameters = (
            # model not a sequence
            {'model': 123, 'profiles': ((123, '/bar'),)},
            # wrong model ID type
            {'model': ('123', '/foo'), 'profiles': ((123, '/bar'),)},
            # wrong model path type
            {'model': (123, True), 'profiles': ((123, '/bar'),)},
            # profile not a sequence
            {'model': valid_model, 'profiles': 456},
            # profile not a sequence of couples
            {'model': valid_model, 'profiles': (456, 789)},
            # wrong profile ID type
            {'model': valid_model, 'profiles': (('456', '/bar'), (789, '/baz'))},
            # wrong profile path type
            {'model': valid_model, 'profiles': ((456, '/bar'), (789, False))},
        )
        # the cases are checked in order and the test stops at the
        # first one which does not raise
        for parameters in invalid_parameters:
            with self.assertRaises(ValidationError):
                models.SyntoolCompareJob.check_parameters(parameters)

    def test_make_task_parameters(self):
        """The task arguments must be built from the request
        parameters: one positional argument (model, profiles) and no
        keyword argument
        """
        model = (123, '/foo')
        profiles = ((456, '/bar'), (789, '/baz'))
        task_parameters = models.SyntoolCompareJob.make_task_parameters(
            {'model': model, 'profiles': profiles})
        self.assertTupleEqual(task_parameters, (((model, profiles),), {}))


class HarvestJobTests(unittest.TestCase):
"""Tests for the HarvestJob class"""

Expand Down

0 comments on commit ae4277b

Please sign in to comment.