diff --git a/README.md b/README.md index 30f7795..c140d21 100644 --- a/README.md +++ b/README.md @@ -10,7 +10,7 @@ The system has been designed using a cloud-native architecture, based on contain It uses [geokube](https://github.com/CMCC-Foundation/geokube) as an Analytics Engine to perform geospatial operations. -## Developer Team +## Developers Team - [Valentina Scardigno](https://github.com/vale95-eng) - [Gabriele Tramonte](https://github.com/gtramonte) diff --git a/api/app/endpoint_handlers/dataset.py b/api/app/endpoint_handlers/dataset.py index a3f8ca5..109eb0c 100644 --- a/api/app/endpoint_handlers/dataset.py +++ b/api/app/endpoint_handlers/dataset.py @@ -3,7 +3,9 @@ import pika from typing import Optional -from dbmanager.dbmanager import DBManager +from fastapi.responses import FileResponse + +from dbmanager.dbmanager import DBManager, RequestStatus from geoquery.geoquery import GeoQuery from geoquery.task import TaskList from datastore.datastore import Datastore, DEFAULT_MAX_REQUEST_SIZE_GB @@ -18,12 +20,18 @@ from api_utils import make_bytes_readable_dict from validation import assert_product_exists +from . import request log = get_dds_logger(__name__) data_store = Datastore() MESSAGE_SEPARATOR = os.environ["MESSAGE_SEPARATOR"] +def _is_etimate_enabled(dataset_id, product_id): + if dataset_id in ("sentinel-2",): + return False + return True + @log_execution_time(log) def get_datasets(user_roles_names: list[str]) -> list[dict]: @@ -213,7 +221,7 @@ def estimate( @log_execution_time(log) @assert_product_exists -def query( +def async_query( user_id: str, dataset_id: str, product_id: str, @@ -250,21 +258,22 @@ def query( """ log.debug("geoquery: %s", query) - estimated_size = estimate(dataset_id, product_id, query, "GB").get("value") - allowed_size = data_store.product_metadata(dataset_id, product_id).get( - "maximum_query_size_gb", DEFAULT_MAX_REQUEST_SIZE_GB - ) - if estimated_size > allowed_size: - raise exc.MaximumAllowedSizeExceededError( - dataset_id=dataset_id, - product_id=product_id, - estimated_size_gb=estimated_size, - allowed_size_gb=allowed_size, - ) - if estimated_size == 0.0: - raise exc.EmptyDatasetError( - dataset_id=dataset_id, product_id=product_id + if _is_etimate_enabled(dataset_id, product_id): + estimated_size = estimate(dataset_id, product_id, query, "GB").get("value") + allowed_size = data_store.product_metadata(dataset_id, product_id).get( + "maximum_query_size_gb", DEFAULT_MAX_REQUEST_SIZE_GB ) + if estimated_size > allowed_size: + raise exc.MaximumAllowedSizeExceededError( + dataset_id=dataset_id, + product_id=product_id, + estimated_size_gb=estimated_size, + allowed_size_gb=allowed_size, + ) + if estimated_size == 0.0: + raise exc.EmptyDatasetError( + dataset_id=dataset_id, product_id=product_id + ) broker_conn = pika.BlockingConnection( pika.ConnectionParameters( host=os.getenv("BROKER_SERVICE_HOST", "broker") @@ -295,6 +304,68 @@ def query( broker_conn.close() return request_id +@log_execution_time(log) +@assert_product_exists +def sync_query( + user_id: str, + dataset_id: str, + product_id: str, + query: GeoQuery, +): + """Realize the logic for the endpoint: + + `POST /datasets/{dataset_id}/{product_id}/execute` + + Query the data and return the result of the request. 
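The renamed `async_query` handler above now gates the size check behind `_is_etimate_enabled`, so datasets such as `sentinel-2` skip both the maximum-size and empty-result validation. A minimal, self-contained sketch of that gating flow, assuming plain `ValueError`s in place of the project's `MaximumAllowedSizeExceededError`/`EmptyDatasetError` classes; `estimated_gb` and `allowed_gb` stand in for the `estimate(...)` and `product_metadata(...)["maximum_query_size_gb"]` lookups in the diff:

```python
# Sketch of the size-gating flow in async_query (illustrative names, not part of the PR).
ESTIMATE_DISABLED_DATASETS = ("sentinel-2",)


def check_request_size(dataset_id: str, estimated_gb: float, allowed_gb: float) -> None:
    """Raise if the estimated request size is empty or exceeds the allowed limit."""
    if dataset_id in ESTIMATE_DISABLED_DATASETS:
        # Estimation is skipped entirely for these datasets.
        return
    if estimated_gb > allowed_gb:
        raise ValueError(
            f"estimated size {estimated_gb} GB exceeds the allowed {allowed_gb} GB"
        )
    if estimated_gb == 0.0:
        raise ValueError("the query would produce an empty dataset")
```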
+ + Parameters + ---------- + user_id : str + ID of the user executing the query + dataset_id : str + ID of the dataset + product_id : str + ID of the product + query : GeoQuery + Query to perform + + Returns + ------- + request_id : int + ID of the request + + Raises + ------- + MaximumAllowedSizeExceededError + if the allowed size is below the estimated one + EmptyDatasetError + if estimated size is zero + + """ + + import time + request_id = async_query(user_id, dataset_id, product_id, query) + status, _ = DBManager().get_request_status_and_reason(request_id) + log.debug("sync query: status: %s", status) + while status in (RequestStatus.RUNNING, RequestStatus.QUEUED, + RequestStatus.PENDING): + time.sleep(1) + status, _ = DBManager().get_request_status_and_reason(request_id) + log.debug("sync query: status: %s", status) + + if status is RequestStatus.DONE: + download_details = DBManager().get_download_details_for_request_id( + request_id + ) + return FileResponse( + path=download_details.location_path, + filename=download_details.location_path.split(os.sep)[-1], + ) + raise exc.ProductRetrievingError( + dataset_id=dataset_id, + product_id=product_id, + status=status.name) + @log_execution_time(log) def run_workflow( diff --git a/api/app/endpoint_handlers/request.py b/api/app/endpoint_handlers/request.py index 320bceb..93a0636 100644 --- a/api/app/endpoint_handlers/request.py +++ b/api/app/endpoint_handlers/request.py @@ -86,7 +86,11 @@ def get_request_resulting_size(request_id: int): If the request was not found """ if request := DBManager().get_request_details(request_id): - return request.download.size_bytes + size = request.download.size_bytes + if not size or size == 0: + raise exc.EmptyDatasetError(dataset_id=request.dataset, + product_id=request.product) + return size log.info( "request with id '%s' could not be found", request_id, diff --git a/api/app/exceptions.py b/api/app/exceptions.py index 01de71c..af4d072 100644 --- a/api/app/exceptions.py +++ b/api/app/exceptions.py @@ -180,3 +180,16 @@ def __init__(self, dataset_id, product_id): product_id=product_id, ) super().__init__(self.msg) + +class ProductRetrievingError(BaseDDSException): + """Retrieving of the product failed.""" + + msg: str = "Retrieving of the product '{dataset_id}.{product_id}' failed with the status {status}" + + def __init__(self, dataset_id, product_id, status): + self.msg = self.msg.format( + dataset_id=dataset_id, + product_id=product_id, + status=status + ) + super().__init__(self.msg) \ No newline at end of file diff --git a/api/app/main.py b/api/app/main.py index b7703f8..322f0de 100644 --- a/api/app/main.py +++ b/api/app/main.py @@ -1,9 +1,11 @@ """Main module with geolake API endpoints defined""" __version__ = "0.1.0" import os -from typing import Optional +import re +from typing import Optional, Dict +from datetime import datetime -from fastapi import FastAPI, HTTPException, Request, status +from fastapi import FastAPI, HTTPException, Request, status, Query, Response from fastapi.middleware.cors import CORSMiddleware from starlette.middleware.authentication import AuthenticationMiddleware from starlette.authentication import requires @@ -16,8 +18,8 @@ ) from aioprometheus.asgi.starlette import metrics -from geoquery.geoquery import GeoQuery from geoquery.task import TaskList +from geoquery.geoquery import GeoQuery from utils.api_logging import get_dds_logger import exceptions as exc @@ -32,6 +34,32 @@ from const import venv, tags from auth import scopes +def map_to_geoquery( + variables: list[str], 
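The `sync_query` handler added earlier in this hunk reuses `async_query` and then polls `DBManager` once per second until the request leaves the PENDING/QUEUED/RUNNING states, returning a `FileResponse` on success. A generic sketch of that polling pattern, assuming a `get_status` callable in place of `DBManager().get_request_status_and_reason` and adding a timeout that the PR itself does not have:

```python
import time
from typing import Callable

PENDING_STATES = ("PENDING", "QUEUED", "RUNNING")


def wait_for_terminal_status(
    get_status: Callable[[], str],
    poll_seconds: float = 1.0,
    timeout_seconds: float | None = None,
) -> str:
    """Poll `get_status` until it reports a state outside PENDING_STATES."""
    started = time.monotonic()
    status = get_status()
    while status in PENDING_STATES:
        if timeout_seconds is not None and time.monotonic() - started > timeout_seconds:
            raise TimeoutError(f"request still {status} after {timeout_seconds} s")
        time.sleep(poll_seconds)
        status = get_status()
    return status
```

Because the loop in the diff has no upper bound, a request stuck in QUEUED keeps the HTTP connection open indefinitely; the optional timeout above is an assumption, not part of the PR.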
+ format: str, + bbox: str | None = None, # minx, miny, maxx, maxy (minlon, minlat, maxlon, maxlat) + time: datetime | None = None, + filters: Optional[Dict] = None, + **format_kwargs +) -> GeoQuery: + + if bbox: + bbox_ = [float(x) for x in bbox.split(',')] + area = { 'west': bbox_[0], 'south': bbox_[1], 'east': bbox_[2], 'north': bbox_[3], } + else: + area = None + if time: + time_ = { 'year': time.year, 'month': time.month, 'day': time.day, 'hour': time.hour} + else: + time_ = None + if filters: + query = GeoQuery(variable=variables, time=time_, area=area, filters=filters, + format_args=format_kwargs, format=format) + else: + query = GeoQuery(variable=variables, time=time_, area=area, + format_args=format_kwargs, format=format) + return query + logger = get_dds_logger(__name__) # ======== JSON encoders extension ========= # @@ -155,7 +183,251 @@ async def get_product_details( except exc.BaseDDSException as err: raise err.wrap_around_http_exception() from err +@app.get("/datasets/{dataset_id}/{product_id}/map", tags=[tags.DATASET]) +@timer( + app.state.api_request_duration_seconds, + labels={"route": "GET /datasets/{dataset_id}/{product_id}"}, +) +async def get_map( + request: Request, + dataset_id: str, + product_id: str, +# OGC WMS parameters + width: int, + height: int, + dpi: int | None = 100, + layers: str | None = None, + format: str | None = 'png', + time: datetime | None = None, + transparent: bool | None = 'true', + bgcolor: str | None = 'FFFFFF', + cmap: str | None = 'RdBu_r', + bbox: str | None = None, # minx, miny, maxx, maxy (minlon, minlat, maxlon, maxlat) + crs: str | None = None, + vmin: float | None = None, + vmax: float | None = None +# OGC map parameters + # subset: str | None = None, + # subset_crs: str | None = Query(..., alias="subset-crs"), + # bbox_crs: str | None = Query(..., alias="bbox-crs"), +): + + app.state.api_http_requests_total.inc( + {"route": "GET /datasets/{dataset_id}/{product_id}/map"} + ) + # query should be the OGC query + # map OGC parameters to GeoQuery + # variable: Optional[Union[str, List[str]]] + # time: Optional[Union[Dict[str, str], Dict[str, List[str]]]] + # area: Optional[Dict[str, float]] + # location: Optional[Dict[str, Union[float, List[float]]]] + # vertical: Optional[Union[float, List[float], Dict[str, float]]] + # filters: Optional[Dict] + # format: Optional[str] + + query = map_to_geoquery(variables=layers, bbox=bbox, time=time, + format="png", width=width, height=height, + transparent=transparent, bgcolor=bgcolor, + dpi=dpi, cmap=cmap, projection=crs, + vmin=vmin, vmax=vmax) + try: + return dataset_handler.sync_query( + user_id=request.user.id, + dataset_id=dataset_id, + product_id=product_id, + query=query + ) + except exc.BaseDDSException as err: + raise err.wrap_around_http_exception() from err + +@app.get("/datasets/{dataset_id}/{product_id}/{filters:path}/map", tags=[tags.DATASET]) +@timer( + app.state.api_request_duration_seconds, + labels={"route": "GET /datasets/{dataset_id}/{product_id}"}, +) +async def get_map_with_filters( + request: Request, + dataset_id: str, + product_id: str, + filters: str, +# OGC WMS parameters + width: int, + height: int, + dpi: int | None = 100, + layers: str | None = None, + format: str | None = 'png', + time: datetime | None = None, + transparent: bool | None = 'true', + bgcolor: str | None = 'FFFFFF', + cmap: str | None = 'RdBu_r', + bbox: str | None = None, # minx, miny, maxx, maxy (minlon, minlat, maxlon, maxlat) + crs: str | None = None, + vmin: float | None = None, + vmax: float | None = 
None +# OGC map parameters + # subset: str | None = None, + # subset_crs: str | None = Query(..., alias="subset-crs"), + # bbox_crs: str | None = Query(..., alias="bbox-crs"), +): + filters_vals = filters.split("/") + + if dataset_id in ['rs-indices', 'pasture']: + filters_dict = {'pasture': filters_vals[0]} + + else: + try: + product_info = dataset_handler.get_product_details( + user_roles_names=request.auth.scopes, + dataset_id=dataset_id, + product_id=product_id, + ) + except exc.BaseDDSException as err: + raise err.wrap_around_http_exception() from err + + filters_keys = product_info['metadata']['filters'] + filters_dict = {} + for i in range(0, len(filters_vals)): + filters_dict[filters_keys[i]['name']] = filters_vals[i] + + app.state.api_http_requests_total.inc( + {"route": "GET /datasets/{dataset_id}/{product_id}/map"} + ) + # query should be the OGC query + # map OGC parameters to GeoQuery + # variable: Optional[Union[str, List[str]]] + # time: Optional[Union[Dict[str, str], Dict[str, List[str]]]] + # area: Optional[Dict[str, float]] + # location: Optional[Dict[str, Union[float, List[float]]]] + # vertical: Optional[Union[float, List[float], Dict[str, float]]] + # filters: Optional[Dict] + # format: Optional[str] + + query = map_to_geoquery(variables=layers, bbox=bbox, time=time, filters=filters_dict, + format="png", width=width, height=height, + transparent=transparent, bgcolor=bgcolor, + dpi=dpi, cmap=cmap, projection=crs, vmin=vmin, vmax=vmax) + try: + return dataset_handler.sync_query( + user_id=request.user.id, + dataset_id=dataset_id, + product_id=product_id, + query=query + ) + except exc.BaseDDSException as err: + raise err.wrap_around_http_exception() from err + +@app.get("/datasets/{dataset_id}/{product_id}/items/{feature_id}", tags=[tags.DATASET]) +@timer( + app.state.api_request_duration_seconds, + labels={"route": "GET /datasets/{dataset_id}/{product_id}/items/{feature_id}"}, +) +async def get_feature( + request: Request, + dataset_id: str, + product_id: str, + feature_id: str, +# OGC feature parameters + time: datetime | None = None, + bbox: str | None = None, # minx, miny, maxx, maxy (minlon, minlat, maxlon, maxlat) + crs: str | None = None, +# OGC map parameters + # subset: str | None = None, + # subset_crs: str | None = Query(..., alias="subset-crs"), + # bbox_crs: str | None = Query(..., alias="bbox-crs"), +): + + app.state.api_http_requests_total.inc( + {"route": "GET /datasets/{dataset_id}/{product_id}/items/{feature_id}"} + ) + # query should be the OGC query + # feature OGC parameters to GeoQuery + # variable: Optional[Union[str, List[str]]] + # time: Optional[Union[Dict[str, str], Dict[str, List[str]]]] + # area: Optional[Dict[str, float]] + # location: Optional[Dict[str, Union[float, List[float]]]] + # vertical: Optional[Union[float, List[float], Dict[str, float]]] + # filters: Optional[Dict] + # format: Optional[str] + + query = map_to_geoquery(variables=[feature_id], bbox=bbox, time=time, + format="geojson") + try: + return dataset_handler.sync_query( + user_id=request.user.id, + dataset_id=dataset_id, + product_id=product_id, + query=query + ) + except exc.BaseDDSException as err: + raise err.wrap_around_http_exception() from err + +@app.get("/datasets/{dataset_id}/{product_id}/{filters:path}/items/{feature_id}", tags=[tags.DATASET]) +@timer( + app.state.api_request_duration_seconds, + labels={"route": "GET /datasets/{dataset_id}/{product_id}/items/{feature_id}"}, +) +async def get_feature_with_filters( + request: Request, + dataset_id: str, + product_id: 
str, + feature_id: str, + filters: str, +# OGC feature parameters + time: datetime | None = None, + bbox: str | None = None, # minx, miny, maxx, maxy (minlon, minlat, maxlon, maxlat) + crs: str | None = None, +# OGC map parameters + # subset: str | None = None, + # subset_crs: str | None = Query(..., alias="subset-crs"), + # bbox_crs: str | None = Query(..., alias="bbox-crs"), +): + filters_vals = filters.split("/") + + if dataset_id in ['rs-indices', 'pasture']: + filters_dict = {'pasture': filters_vals[0]} + + else: + try: + product_info = dataset_handler.get_product_details( + user_roles_names=request.auth.scopes, + dataset_id=dataset_id, + product_id=product_id, + ) + except exc.BaseDDSException as err: + raise err.wrap_around_http_exception() from err + + filters_keys = product_info['metadata']['filters'] + filters_dict = {} + for i in range(0, len(filters_vals)): + filters_dict[filters_keys[i]['name']] = filters_vals[i] + + app.state.api_http_requests_total.inc( + {"route": "GET /datasets/{dataset_id}/{product_id}/items/{feature_id}"} + ) + # query should be the OGC query + # feature OGC parameters to GeoQuery + # variable: Optional[Union[str, List[str]]] + # time: Optional[Union[Dict[str, str], Dict[str, List[str]]]] + # area: Optional[Dict[str, float]] + # location: Optional[Dict[str, Union[float, List[float]]]] + # vertical: Optional[Union[float, List[float], Dict[str, float]]] + # filters: Optional[Dict] + # format: Optional[str] + + query = map_to_geoquery(variables=[feature_id], bbox=bbox, time=time, filters=filters_dict, + format="geojson") + try: + return dataset_handler.sync_query( + user_id=request.user.id, + dataset_id=dataset_id, + product_id=product_id, + query=query + ) + except exc.BaseDDSException as err: + raise err.wrap_around_http_exception() from err + + @app.get("/datasets/{dataset_id}/{product_id}/metadata", tags=[tags.DATASET]) @timer( app.state.api_request_duration_seconds, @@ -222,7 +494,7 @@ async def query( {"route": "POST /datasets/{dataset_id}/{product_id}/execute"} ) try: - return dataset_handler.query( + return dataset_handler.async_query( user_id=request.user.id, dataset_id=dataset_id, product_id=product_id, diff --git a/api/requirements.txt b/api/requirements.txt index 97fcaf3..6066865 100644 --- a/api/requirements.txt +++ b/api/requirements.txt @@ -3,3 +3,5 @@ uvicorn pika sqlalchemy aioprometheus +pyoai==2.5.0 +rdflib==7.0.0 \ No newline at end of file diff --git a/datastore/datastore/datastore.py b/datastore/datastore/datastore.py index 3fb2bfc..d595595 100644 --- a/datastore/datastore/datastore.py +++ b/datastore/datastore/datastore.py @@ -85,15 +85,17 @@ def get_cached_product_or_read( return self.cache[dataset_id][product_id] @log_execution_time(_LOG) - def _load_cache(self): - if self.cache is None: + def _load_cache(self, datasets: list[str] | None = None): + if self.cache is None or datasets is None: self.cache = {} - for i, dataset_id in enumerate(self.dataset_list()): + datasets = self.dataset_list() + + for i, dataset_id in enumerate(datasets): self._LOG.info( "loading cache for `%s` (%d/%d)", dataset_id, i + 1, - len(self.dataset_list()), + len(datasets), ) self.cache[dataset_id] = {} for product_id in self.product_list(dataset_id): @@ -117,7 +119,7 @@ def _load_cache(self): dataset_id, product_id, exc_info=True, - ) + ) @log_execution_time(_LOG) def dataset_list(self) -> list: @@ -419,7 +421,10 @@ def is_product_valid_for_role( def _process_query(kube, query: GeoQuery, compute: None | bool = False): if isinstance(kube, Dataset): 
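Both `*_with_filters` endpoints added to `api/app/main.py` above turn the free-form `{filters:path}` segment into a dict by pairing each path component with the filter names declared in `product_info['metadata']['filters']` (with a hard-coded `pasture` filter for the `rs-indices` and `pasture` datasets). A standalone sketch of that mapping, with hypothetical filter names in the usage example:

```python
def filters_path_to_dict(
    filters_path: str,
    filters_metadata: list[dict],
    dataset_id: str = "",
) -> dict:
    """Map positional path segments onto the product's declared filter names."""
    values = filters_path.split("/")
    if dataset_id in ("rs-indices", "pasture"):
        # These datasets expose a single hard-coded filter in the PR.
        return {"pasture": values[0]}
    names = [entry["name"] for entry in filters_metadata]
    # Positional mapping: the i-th path segment fills the i-th declared filter.
    return dict(zip(names, values))


# Hypothetical example: a request path .../{product}/2020/cropland/map with
# filters metadata [{"name": "year"}, {"name": "landcover"}] yields
# {"year": "2020", "landcover": "cropland"}.
```

As in the diff, nothing checks that the number of path segments matches the number of declared filters; `zip` silently drops surplus segments here, whereas the index-based loop in the PR would raise `IndexError` for them.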
Datastore._LOG.debug("filtering with: %s", query.filters) - kube = kube.filter(**query.filters) + try: + kube = kube.filter(**query.filters) + except ValueError as err: + Datastore._LOG.warning("could not filter by one of the key: %s", err) Datastore._LOG.debug("resulting kube len: %s", len(kube)) if isinstance(kube, Delayed) and compute: kube = kube.compute() @@ -444,10 +449,10 @@ def _process_query(kube, query: GeoQuery, compute: None | bool = False): if query.vertical: Datastore._LOG.debug("subsetting by vertical...") if isinstance( - vertical := Datastore._maybe_convert_dict_slice_to_slice( - query.vertical - ), - slice, + vertical := Datastore._maybe_convert_dict_slice_to_slice( + query.vertical + ), + slice, ): method = None else: @@ -456,6 +461,9 @@ def _process_query(kube, query: GeoQuery, compute: None | bool = False): if query.resample: Datastore._LOG.debug("Applying resample...") kube = kube.resample(**query.resample) + if query.regrid: + if query.regrid == 'regular': + kube = kube.to_regular() return kube.compute() if compute else kube @staticmethod @@ -466,4 +474,4 @@ def _maybe_convert_dict_slice_to_slice(dict_vals): dict_vals.get("stop"), dict_vals.get("step"), ) - return dict_vals + return dict_vals \ No newline at end of file diff --git a/datastore/dbmanager/dbmanager.py b/datastore/dbmanager/dbmanager.py index b11c46c..d4ff293 100644 --- a/datastore/dbmanager/dbmanager.py +++ b/datastore/dbmanager/dbmanager.py @@ -43,6 +43,7 @@ class RequestStatus(Enum_): """Status of the Request""" PENDING = auto() + QUEUED = auto() RUNNING = auto() DONE = auto() FAILED = auto() @@ -85,7 +86,7 @@ class User(Base): String(255), nullable=False, unique=True, default=generate_key ) contact_name = Column(String(255)) - requests = relationship("Request") + requests = relationship("Request", lazy="dynamic") roles = relationship("Role", secondary=association_table, lazy="selectin") @@ -96,7 +97,7 @@ class Worker(Base): host = Column(String(255)) dask_scheduler_port = Column(Integer) dask_dashboard_address = Column(String(10)) - created_on = Column(DateTime, nullable=False) + created_on = Column(DateTime, default=datetime.now) class Request(Base): @@ -112,8 +113,8 @@ class Request(Base): product = Column(String(255)) query = Column(JSON()) estimate_size_bytes = Column(Integer) - created_on = Column(DateTime, nullable=False) - last_update = Column(DateTime) + created_on = Column(DateTime, default=datetime.now) + last_update = Column(DateTime, default=datetime.now, onupdate=datetime.now) fail_reason = Column(String(1000)) download = relationship("Download", uselist=False, lazy="selectin") @@ -128,7 +129,7 @@ class Download(Base): storage_id = Column(Integer, ForeignKey("storages.storage_id")) location_path = Column(String(255)) size_bytes = Column(Integer) - created_on = Column(DateTime, nullable=False) + created_on = Column(DateTime, default=datetime.now) class Storage(Base): @@ -267,16 +268,18 @@ def create_request( def update_request( self, request_id: int, - worker_id: int, - status: RequestStatus, + worker_id: int | None = None, + status: RequestStatus | None = None, location_path: str = None, size_bytes: int = None, fail_reason: str = None, ) -> int: with self.__session_maker() as session: request = session.query(Request).get(request_id) - request.status = status - request.worker_id = worker_id + if status: + request.status = status + if worker_id: + request.worker_id = worker_id request.last_update = datetime.utcnow() request.fail_reason = fail_reason session.commit() @@ -305,7 +308,17 @@ 
def get_request_status_and_reason( def get_requests_for_user_id(self, user_id) -> list[Request]: with self.__session_maker() as session: - return session.query(User).get(user_id).requests + return session.query(User).get(user_id).requests.all() + + def get_requests_for_user_id_and_status( + self, user_id, status: RequestStatus | tuple[RequestStatus] + ) -> list[Request]: + if isinstance(status, RequestStatus): + status = (status,) + with self.__session_maker() as session: + return session.get(User, user_id).requests.filter( + Request.status.in_(status) + ) def get_download_details_for_request_id(self, request_id) -> Download: with self.__session_maker() as session: diff --git a/datastore/geoquery/geoquery.py b/datastore/geoquery/geoquery.py index 540ea6f..e61288d 100644 --- a/datastore/geoquery/geoquery.py +++ b/datastore/geoquery/geoquery.py @@ -16,6 +16,8 @@ class GeoQuery(BaseModel, extra="allow"): vertical: Optional[Union[float, List[float], Dict[str, float]]] filters: Optional[Dict] format: Optional[str] + format_args: Optional[Dict] + regrid: Optional[str] # TODO: Check if we are going to allow the vertical coordinates inside both # `area`/`location` nad `vertical` diff --git a/datastore/workflow/workflow.py b/datastore/workflow/workflow.py index e609a77..04720ee 100644 --- a/datastore/workflow/workflow.py +++ b/datastore/workflow/workflow.py @@ -108,9 +108,9 @@ def _subset(kube: DataCube | None = None) -> DataCube: return Datastore().query( dataset_id=dataset_id, product_id=product_id, - query=query - if isinstance(query, GeoQuery) - else GeoQuery(**query), + query=( + query if isinstance(query, GeoQuery) else GeoQuery(**query) + ), compute=False, ) @@ -153,18 +153,21 @@ def _average(kube: DataCube | None = None) -> DataCube: ) self._add_computational_node(task) return self - + def to_regular( self, id: Hashable, *, dependencies: list[Hashable] ) -> "Workflow": def _to_regular(kube: DataCube | None = None) -> DataCube: - assert kube is not None, "`kube` cannot be `None` for `to_regular``" + assert ( + kube is not None + ), "`kube` cannot be `None` for `to_regular``" return kube.to_regular() + task = _WorkflowTask( id=id, operator=_to_regular, dependencies=dependencies ) self._add_computational_node(task) - return self + return self def add_task( self, diff --git a/drivers/Dockerfile b/drivers/Dockerfile index 212826b..e45aeab 100644 --- a/drivers/Dockerfile +++ b/drivers/Dockerfile @@ -1,7 +1,7 @@ ARG REGISTRY=rg.fr-par.scw.cloud/geokube #ARG TAG=v0.2.6b2 #ARG TAG=2024.05.03.10.36 -ARG TAG=v0.2.7.1 +ARG TAG=v0.2.7.2 FROM $REGISTRY/geokube:$TAG ADD . /opt/intake_geokube diff --git a/drivers/intake_geokube/__init__.py b/drivers/intake_geokube/__init__.py index dc60a1d..c4be038 100644 --- a/drivers/intake_geokube/__init__.py +++ b/drivers/intake_geokube/__init__.py @@ -2,4 +2,4 @@ # This avoids a circilar dependency pitfall by ensuring that the # driver-discovery code runs first, see: -# https://intake.readthedocs.io/en/latest/making-plugins.html#entrypoints +# https://intake.readthedocs.io/en/latest/making-plugins.html#entrypoints \ No newline at end of file diff --git a/drivers/intake_geokube/base.py b/drivers/intake_geokube/base.py index e3c689b..90f3ef8 100644 --- a/drivers/intake_geokube/base.py +++ b/drivers/intake_geokube/base.py @@ -1,9 +1,11 @@ # from . 
import __version__ +from dask.delayed import Delayed from intake.source.base import DataSource, Schema from geokube.core.datacube import DataCube from geokube.core.dataset import Dataset + class GeokubeSource(DataSource): """Common behaviours for plugins in this repo""" @@ -56,9 +58,8 @@ def read(self): def read_chunked(self): """Return a lazy geokube object""" - self._load_metadata() - return self._kube - + return self.read() + def read_partition(self, i): """Fetch one chunk of data at tuple index i""" raise NotImplementedError @@ -75,4 +76,4 @@ def to_pyarrow(self): def close(self): """Delete open file from memory""" self._kube = None - self._schema = None + self._schema = None \ No newline at end of file diff --git a/drivers/intake_geokube/netcdf.py b/drivers/intake_geokube/netcdf.py index 7247891..e8b0754 100644 --- a/drivers/intake_geokube/netcdf.py +++ b/drivers/intake_geokube/netcdf.py @@ -21,6 +21,7 @@ def __init__( metadata=None, mapping: Optional[Mapping[str, Mapping[str, str]]] = None, load_files_on_persistance: Optional[bool] = True, + **kwargs ): self._kube = None self.path = path @@ -34,7 +35,7 @@ def __init__( self.xarray_kwargs = {} if xarray_kwargs is None else xarray_kwargs self.load_files_on_persistance = load_files_on_persistance # self.xarray_kwargs.update({'engine' : 'netcdf'}) - super(NetCDFSource, self).__init__(metadata=metadata) + super(NetCDFSource, self).__init__(metadata=metadata, **kwargs) def _open_dataset(self): if self.pattern is None: diff --git a/drivers/intake_geokube/sentinel.py b/drivers/intake_geokube/sentinel.py new file mode 100644 index 0000000..4c6b612 --- /dev/null +++ b/drivers/intake_geokube/sentinel.py @@ -0,0 +1,205 @@ +"""Geokube driver for sentinel data.""" + +from collections import defaultdict +from multiprocessing.util import get_temp_dir +import os +import dask +import zipfile +import glob +from functools import partial +from typing import Generator, Iterable, Mapping, Optional, List + +import numpy as np +import pandas as pd +import xarray as xr +from pyproj import Transformer +from pyproj.crs import CRS, GeographicCRS +from intake.source.utils import reverse_format + +from geokube import open_datacube +from geokube.core.dataset import Dataset + +from .base import GeokubeSource +from .geoquery import GeoQuery + +SENSING_TIME_ATTR: str = "sensing_time" +FILE: str = "files" +DATACUBE: str = "datacube" + + +def get_field_name_from_path(path: str): + res, file = path.split(os.sep)[-2:] + band = file.split("_")[-2] + return f"{res}_{band}" + + +def preprocess_sentinel(dset: xr.Dataset, pattern: str, **kw) -> xr.Dataset: + crs = CRS.from_cf(dset["spatial_ref"].attrs) + transformer = Transformer.from_crs( + crs_from=crs, crs_to=GeographicCRS(), always_xy=True + ) + x_vals, y_vals = dset["x"].to_numpy(), dset["y"].to_numpy() + lon_vals, lat_vals = transformer.transform(*np.meshgrid(x_vals, y_vals)) + source_path = dset.encoding["source"] + sensing_time = os.path.splitext(source_path.split(os.sep)[-6])[0].split( + "_" + )[-1] + time = pd.to_datetime([sensing_time]).to_numpy() + dset = dset.assign_coords( + { + "time": time, + "latitude": (("x", "y"), lat_vals), + "longitude": (("x", "y"), lon_vals), + } + ).rename({"band_data": get_field_name_from_path(source_path)}) + return dset + + +def get_zip_files_from_path(path: str) -> Generator: + assert path and isinstance(path, str), "`path` must be a string" + assert path.lower().endswith("zip"), "`path` must point to a ZIP archive" + if "*" in path: + yield from glob.iglob(path) + return + yield 
path + + +def unzip_data(files: Iterable[str], target: str) -> List[str]: + """Unzip ZIP archive to the `target` directory.""" + target_files = [] + for file in files: + prod_id = os.path.splitext(os.path.basename(file))[0] + target_prod = os.path.join(target, prod_id) + os.makedirs(target_prod, exist_ok=True) + with zipfile.ZipFile(file) as archive: + archive.extractall(path=target_prod) + target_files.append(os.listdir(target_prod)) + return target_files + + +def _prepare_df_from_files(files: Iterable[str], pattern: str) -> pd.DataFrame: + data = [] + for f in files: + attr = reverse_format(pattern, f) + attr[FILE] = f + data.append(attr) + return pd.DataFrame(data) + + +class CMCCSentinelSource(GeokubeSource): + name = "cmcc_sentinel_geokube" + version = "0.0.1" + + def __init__( + self, + path: str, + pattern: str = None, + zippath: str = None, + zippattern: str = None, + metadata=None, + xarray_kwargs: dict = None, + mapping: Optional[Mapping[str, Mapping[str, str]]] = None, + **kwargs, + ): + super().__init__(metadata=metadata, **kwargs) + self._kube = None + self.path = path + self.pattern = pattern + self.zippath = zippath + self.zippattern = zippattern + self.mapping = mapping + self.metadata_caching = False + self.xarray_kwargs = {} if xarray_kwargs is None else xarray_kwargs + self._unzip_dir = get_temp_dir() + self._zipdf = None + self._jp2df = None + assert ( + SENSING_TIME_ATTR in self.pattern + ), f"{SENSING_TIME_ATTR} is missing in the pattern" + self.preprocess = partial( + preprocess_sentinel, + pattern=self.pattern, + ) + if self.geoquery: + self.filters = self.geoquery.filters + else: + self.filters = {} + + def __post_init__(self) -> None: + assert ( + SENSING_TIME_ATTR in self.pattern + ), f"{SENSING_TIME_ATTR} is missing in the pattern" + self.preprocess = partial( + preprocess_sentinel, + pattern=self.pattern, + ) + + def _compute_res_df(self) -> List[str]: + self._zipdf = self._get_files_attr() + self._maybe_select_by_zip_attrs() + _ = unzip_data(self._zipdf[FILE].values, target=self._unzip_dir) + self._create_jp2_df() + self._maybe_select_by_jp2_attrs() + + def _get_files_attr(self) -> pd.DataFrame: + df = _prepare_df_from_files( + get_zip_files_from_path(self.path), self.pattern + ) + assert ( + SENSING_TIME_ATTR in df + ), f"{SENSING_TIME_ATTR} column is missing" + return df.set_index(SENSING_TIME_ATTR).sort_index() + + def _maybe_select_by_zip_attrs(self) -> Optional[pd.DataFrame]: + filters_to_pop = [] + for flt in self.filters: + if flt in self._zipdf.columns: + self._zipdf = self._zipdf.set_index(flt) + if flt == self._zipdf.index.name: + self._zipdf = self._zipdf.loc[self.filters[flt]] + filters_to_pop.append(flt) + for f in filters_to_pop: + self.filters.pop(f) + self._zipdf = self._zipdf.reset_index() + + + def _create_jp2_df(self) -> None: + self._jp2df = _prepare_df_from_files( + glob.iglob(os.path.join(self._unzip_dir, self.zippath)), + os.path.join(self._unzip_dir, self.zippattern), + ) + + def _maybe_select_by_jp2_attrs(self): + filters_to_pop = [] + for key, value in self.filters.items(): + if key not in self._jp2df: + continue + if isinstance(value, str): + self._jp2df = self._jp2df[self._jp2df[key] == value] + elif isinstance(value, Iterable): + self._jp2df = self._jp2df[self._jp2df[key].isin(value)] + else: + raise TypeError(f"type `{type(value)}` is not supported!") + filters_to_pop.append(key) + for f in filters_to_pop: + self.filters.pop(f) + + def _open_dataset(self): + self._compute_res_df() + self._jp2df + cubes = [] + for i, row in 
self._jp2df.iterrows(): + cubes.append( + dask.delayed(open_datacube)( + path=row[FILE], + id_pattern=None, + mapping=self.mapping, + metadata_caching=self.metadata_caching, + **self.xarray_kwargs, + preprocess=self.preprocess, + ) + ) + self._jp2df[DATACUBE] = cubes + self._kube = Dataset(self._jp2df.reset_index(drop=True)) + self.geoquery.filters = self.filters + return self._kube diff --git a/drivers/intake_geokube/wrf.py b/drivers/intake_geokube/wrf.py index 1968e40..196ff68 100644 --- a/drivers/intake_geokube/wrf.py +++ b/drivers/intake_geokube/wrf.py @@ -124,6 +124,7 @@ def __init__( load_files_on_persistance: Optional[bool] = True, variables_to_keep: Optional[Union[str, list[str]]] = None, variables_to_skip: Optional[Union[str, list[str]]] = None, + **kwargs ): self._kube = None self.path = path @@ -142,7 +143,7 @@ def __init__( variables_to_skip=variables_to_skip, ) # self.xarray_kwargs.update({'engine' : 'netcdf'}) - super(CMCCWRFSource, self).__init__(metadata=metadata) + super(CMCCWRFSource, self).__init__(metadata=metadata, **kwargs) def _open_dataset(self): if self.pattern is None: diff --git a/drivers/setup.py b/drivers/setup.py index b3a3032..7723f8b 100644 --- a/drivers/setup.py +++ b/drivers/setup.py @@ -13,12 +13,12 @@ long_description_content_type="text/markdown", url="https://github.com/geokube/intake-geokube", packages=setuptools.find_packages(), - install_requires=["intake", "pytest"], + install_requires=["intake", "pytest", "pydantic<2.0.0"], entry_points={ "intake.drivers": [ "geokube_netcdf = intake_geokube.netcdf:NetCDFSource", "cmcc_wrf_geokube = intake_geokube.wrf:CMCCWRFSource", - "cmcc_afm_geokube = intake_geokube.afm:CMCCAFMSource", + "cmcc_sentinel_geokube = intake_geokube.sentinel:CMCCSentinelSource" ] }, classifiers=[ diff --git a/executor/app/main.py b/executor/app/main.py index b6b5fa5..1c27148 100644 --- a/executor/app/main.py +++ b/executor/app/main.py @@ -111,6 +111,7 @@ def persist_datacube( kube._properties["history"] = get_history_message() if isinstance(message.content, GeoQuery): format = message.content.format + format_args = message.content.format_args else: format = "netcdf" match format: @@ -120,6 +121,15 @@ def persist_datacube( case "geojson": full_path = os.path.join(base_path, f"{path}.json") kube.to_geojson(full_path) + case "png": + full_path = os.path.join(base_path, f"{path}.png") + kube.to_image(full_path, **format_args) + case "jpeg": + full_path = os.path.join(base_path, f"{path}.jpg") + kube.to_image(full_path, **format_args) + case "csv": + full_path = os.path.join(base_path, f"{path}.csv") + kube.to_csv(full_path) case "zarr": full_path = os.path.join(base_path, f"{path}.zarr") kube.to_zarr(full_path, mode='w', consolidated=True) @@ -136,7 +146,9 @@ def persist_dataset( def _get_attr_comb(dataframe_item, attrs): return "_".join([dataframe_item[attr_name] for attr_name in attrs]) - def _persist_single_datacube(dataframe_item, base_path, format): + def _persist_single_datacube(dataframe_item, base_path, format, format_args=None): + if not format_args: + format_args = {} dcube = dataframe_item[dset.DATACUBE_COL] if isinstance(dcube, Delayed): dcube = dcube.compute() @@ -173,14 +185,24 @@ def _persist_single_datacube(dataframe_item, base_path, format): case "geojson": full_path = os.path.join(base_path, f"{path}.json") dcube.to_geojson(full_path) + case "png": + full_path = os.path.join(base_path, f"{path}.png") + dcube.to_image(full_path, **format_args) + case "jpeg": + full_path = os.path.join(base_path, f"{path}.jpg") + 
dcube.to_image(full_path, **format_args)
+            case "csv":
+                full_path = os.path.join(base_path, f"{path}.csv")
+                dcube.to_csv(full_path)
         return full_path
 
     if isinstance(message.content, GeoQuery):
         format = message.content.format
+        format_args = message.content.format_args
     else:
         format = "netcdf"
+        format_args = None
     datacubes_paths = dset.data.apply(
-        _persist_single_datacube, base_path=base_path, format=format, axis=1
+        _persist_single_datacube, base_path=base_path, format=format, format_args=format_args, axis=1
     )
     paths = datacubes_paths[~datacubes_paths.isna()]
     if len(paths) == 0:
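The executor hunks above extend the persistence `match` with `png`, `jpeg` and `csv` targets and thread `format_args` from the `GeoQuery` into geokube's `to_image` writer. A condensed sketch of that dispatch, assuming a datacube-like object exposing the `to_geojson`/`to_image`/`to_csv`/`to_zarr` writers used in the diff plus a `to_netcdf` fallback (the NetCDF branch itself sits outside the visible hunks):

```python
import os

EXTENSIONS = {"netcdf": "nc", "geojson": "json", "png": "png",
              "jpeg": "jpg", "csv": "csv", "zarr": "zarr"}


def persist(kube, base_path: str, name: str, fmt: str = "netcdf",
            format_args: dict | None = None) -> str:
    """Write `kube` in the requested format and return the resulting path (sketch)."""
    format_args = format_args or {}  # mirrors the guard in _persist_single_datacube
    full_path = os.path.join(base_path, f"{name}.{EXTENSIONS.get(fmt, 'nc')}")
    match fmt:
        case "png" | "jpeg":
            kube.to_image(full_path, **format_args)
        case "csv":
            kube.to_csv(full_path)
        case "geojson":
            kube.to_geojson(full_path)
        case "zarr":
            kube.to_zarr(full_path, mode="w", consolidated=True)
        case _:
            # Assumed default: fall back to NetCDF output.
            kube.to_netcdf(full_path)
    return full_path
```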