From 8de4da393daef76d4d5285e1fca95751466c6dd4 Mon Sep 17 00:00:00 2001 From: BRAUN REMI Date: Fri, 28 Feb 2025 10:06:09 +0100 Subject: [PATCH] FIX: Better manage default options in `rasters.write`, allowing to write easily with other drivers than `GTiff` or `COG` (such as `Zarr`) --- CHANGES.md | 1 + ci/test_rasters.py | 8 ++++++++ sertit/rasters.py | 27 +++++++++++++++++++-------- 3 files changed, 28 insertions(+), 8 deletions(-) diff --git a/CHANGES.md b/CHANGES.md index f67d14a..1bc8dc4 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -5,6 +5,7 @@ - **ENH: Add two functions for converting degrees to and from meters: `rasters.from_deg_to_meters` and `rasters.from_meters_to_deg`** - **ENH: Add the ability to return a `dask.delayed` object in `rasters.write`** - **ENH: Save the path of the opened files in attributes in `rasters.read`** +- FIX: Better manage default options in `rasters.write`, allowing to write easily with other drivers than `GTiff` or `COG` (such as `Zarr`) - FIX: Don't take nodata value into account in `ci.assert_raster_almost_equal_magnitude` ## 1.45.2 (2025-02-17) diff --git a/ci/test_rasters.py b/ci/test_rasters.py index 2119805..6acc4b0 100644 --- a/ci/test_rasters.py +++ b/ci/test_rasters.py @@ -744,6 +744,14 @@ def test_write(dtype, nodata_val, tmp_path, xda): rasters.write(xda, path=test_deprecated_path, dtype=dtype) +def test_write_zarr(tmp_path, xda): + # test zarr + zarr_path = os.path.join(tmp_path, "z.zarr") + rasters.write(xda, path=zarr_path, driver="Zarr") + # Just test to read the zarr array + np.testing.assert_array_equal(xda.data, rasters.read(zarr_path).data) + + def test_dim(): """Test on BEAM-DIMAP function""" dim_path = rasters_path().joinpath("DIM.dim") diff --git a/sertit/rasters.py b/sertit/rasters.py index 3199d88..2b49fbd 100644 --- a/sertit/rasters.py +++ b/sertit/rasters.py @@ -1162,11 +1162,14 @@ def write( # Bigtiff if needed bigtiff = rasters_rio.bigtiff_value(xds) - # Force GTiff + # Force GTiff by default kwargs["driver"] = kwargs.get("driver", "GTiff") # Manage COGs or other drivers attributes is_cog = kwargs["driver"] == "COG" + is_gtiff = kwargs["driver"] == "GTiff" + is_zarr = kwargs["driver"] == "Zarr" + if is_cog: kwargs.pop("tiled", None) @@ -1179,7 +1182,10 @@ def write( "Your data will be converted to uint8. " "In case of casting issues (i.e. negative values), please save it to int16." ) - + elif is_zarr: + # Get default client's lock + kwargs["lock"] = kwargs.get("lock", dask.get_dask_lock("rio")) + kwargs["compress"] = kwargs.get("compress", "zstd") else: # Get default client's lock kwargs["lock"] = kwargs.get("lock", dask.get_dask_lock("rio")) @@ -1188,11 +1194,14 @@ def write( kwargs["tiled"] = kwargs.get("tiled", True) # Default compression to LZW - kwargs["compress"] = kwargs.get("compress", "lzw") + if is_gtiff: + kwargs["compress"] = kwargs.get("compress", "lzw") + # Else, don't set any default compression # Manage predictors according to dtype and compression if ( - kwargs["compress"].lower() in ["lzw", "deflate", "zstd"] + not is_zarr + and kwargs["compress"].lower() in ["lzw", "deflate", "zstd"] and "predictor" not in kwargs # noqa: W503 ): if xds.encoding["dtype"] in [np.float16, np.float32, np.float64, float]: @@ -1208,7 +1217,7 @@ def write( if write_cogs_with_dask: try: - LOGGER.debug("Writing your COG with Dask!") + LOGGER.debug("Writing your COG with Dask.") # Filter out and convert kwargs to avoid any error da_kwargs = { @@ -1269,16 +1278,18 @@ def write( # Default write on disk if not is_written: - LOGGER.debug(f"Writing your file '{path.get_filename(output_path)}' to disk.") + LOGGER.debug(f"Writing '{path.get_filename(output_path)}' to disk.") # WORKAROUND: Pop _FillValue attribute (if existing) if "_FillValue" in xds.attrs: xds.attrs.pop("_FillValue") + if not is_zarr: + kwargs["BIGTIFF"] = bigtiff + kwargs["NUM_THREADS"] = MAX_CORES + delayed = xds.rio.to_raster( str(output_path), - BIGTIFF=bigtiff, - NUM_THREADS=MAX_CORES, tags=tags, compute=compute, **misc.remove_empty_values(kwargs),