diff --git a/dascore/data_registry.txt b/dascore/data_registry.txt
index ba59297e..0ca73fb7 100644
--- a/dascore/data_registry.txt
+++ b/dascore/data_registry.txt
@@ -28,3 +28,4 @@ neubrex_dts_forge.h5 940f7bea6dd4c8a1340b4936b8eb7f9edc577cbcaf77c1f5ac295890f88
 decimated_optodas.hdf5 48ce9c2ab4916d5536faeef0bd789f326ec4afc232729d32014d4d835a9fb74e https://github.com/dasdae/test_data/raw/master/das/decimated_optodas.hdf5
 neubrex_das_1.h5 48a97e27c56e66cc2954ba4eaaadd2169919bb8f897d78e95ef6ab50abb5027b https://github.com/dasdae/test_data/raw/master/das/neubrex_das_1.h5
 UoU_lf_urban.hdf5 d3f8fa6ff3d8ae993484b3fbf9b39505e2cf15cb3b39925a3519b27d5fbe7b5b https://github.com/dasdae/test_data/raw/master/das/UoU_lf_urban.hdf5
+small_channel_patch.sgy 31e551aadb361189c1c9325d504c883114ba9a7bb75fe4791e5089fabccef704 https://github.com/dasdae/test_data/raw/master/das/small_channel_patch.sgy
diff --git a/dascore/io/core.py b/dascore/io/core.py
index 09a3cfba..1839c55b 100644
--- a/dascore/io/core.py
+++ b/dascore/io/core.py
@@ -30,10 +30,14 @@
     timeable_types,
 )
 from dascore.core.attrs import str_validator
-from dascore.exceptions import InvalidFiberIOError, UnknownFiberFormatError
+from dascore.exceptions import (
+    InvalidFiberIOError,
+    MissingOptionalDependencyError,
+    UnknownFiberFormatError,
+)
 from dascore.utils.io import IOResourceManager, get_handle_from_resource
 from dascore.utils.mapping import FrozenDict
-from dascore.utils.misc import _iter_filesystem, cached_method, iterate
+from dascore.utils.misc import _iter_filesystem, cached_method, iterate, warn_or_raise
 from dascore.utils.models import (
     CommaSeparatedStr,
     DascoreBaseModel,
@@ -748,6 +752,27 @@ def _count_generator(generator):
     return entity_count
 
 
+def _handle_missing_optionals(outputs, optional_dep_dict):
+    """
+    Inform the user that some files could be read if additional
+    dependencies were installed.
+
+    If other readable files were found, issue a warning. Otherwise, raise
+    a MissingOptionalDependencyError.
+    """
+    msg = (
+        f"DASCore found files that can be read if additional packages are "
+        f"installed. The needed packages and the found number of files are: "
+        f"{dict(optional_dep_dict)}"
+    )
+    warn_or_raise(
+        msg,
+        exception=MissingOptionalDependencyError,
+        warning=UserWarning,
+        behavior="warn" if len(outputs) else "raise",
+    )
+
+
 def scan(
     path: Path | str | PatchType | SpoolType | IOResourceManager,
     file_format: str | None = None,
@@ -796,6 +821,8 @@
     """
     out = []
    fiber_io_hint: dict[str, FiberIO] = {}
+    # A dict for keeping track of missing optional dependencies.
+    missing_optional_deps = defaultdict(int)
     # Unfortunately, we have to iterate the scan candidates twice to get
     # an estimate for the progress bar length. Maybe there is a better way...
     _generator = _iterate_scan_inputs(
@@ -826,6 +853,7 @@
         except UnknownFiberFormatError:  # skip bad entities
             continue
         # Cache this fiber io to give preferential treatment next iteration.
+        # This speeds up the common case of many files with the same format.
         fiber_io_hint[fiber_io.input_type] = fiber_io
         # Special handling of directory FiberIOs.
         if fiber_io.input_type == "directory":
@@ -843,8 +871,13 @@
         except OSError:  # This happens if the file is corrupt; see #346.
warnings.warn(f"Failed to scan {resource}", UserWarning) continue + except MissingOptionalDependencyError as ex: + missing_optional_deps[ex.msg.split(" ")[0]] += 1 + continue for attr in source: out.append(dc.PatchAttrs.from_dict(attr)) + if missing_optional_deps: + _handle_missing_optionals(out, missing_optional_deps) return out diff --git a/dascore/io/segy/__init__.py b/dascore/io/segy/__init__.py index 2a421dd7..5cdef29e 100644 --- a/dascore/io/segy/__init__.py +++ b/dascore/io/segy/__init__.py @@ -5,6 +5,13 @@ ----- - Distance information is not found in most SEGY DAS files so returned dimensions are "channel" and "time" rather than "distance" and "time". +- Segy standards found at: https://library.seg.org/pb-assets/technical-standards + +segy v1 spec: seg_y_rev1-1686080991247.pdf + +segy v2 spec: seg_y_rev2_0-mar2017-1686080998003.pdf + +segy v2.1 spec: seg_y_rev2_1-oct2023-1701361639333.pdf Examples -------- @@ -17,4 +24,4 @@ segy_patch = dc.spool(path)[0] """ -from .core import SegyV2 +from .core import SegyV1_0, SegyV2_0, SegyV2_1 diff --git a/dascore/io/segy/core.py b/dascore/io/segy/core.py index 5e82c5bf..a44ceab8 100644 --- a/dascore/io/segy/core.py +++ b/dascore/io/segy/core.py @@ -2,30 +2,35 @@ from __future__ import annotations -import segyio - import dascore as dc from dascore.io.core import FiberIO +from dascore.utils.io import BinaryReader +from dascore.utils.misc import optional_import -from .utils import _get_attrs, _get_coords, _get_filtered_data_and_coords +from .utils import ( + _get_attrs, + _get_coords, + _get_filtered_data_and_coords, + _get_segy_version, + _write_segy, +) -class SegyV2(FiberIO): - """An IO class supporting version 2 of the SEGY format.""" +class SegyV1_0(FiberIO): # noqa + """An IO class supporting version 1.0 of the SEGY format.""" name = "segy" preferred_extensions = ("segy", "sgy") # also specify a version so when version 2 is released you can # just make another class in the same module named JingleV2. - version = "2" + version = "1.0" + # The name of the package to import. This is here so the class can be + # subclassed and this changed for debugging reasons. + _package_name = "segyio" - def get_format(self, path, **kwargs) -> tuple[str, str] | bool: + def get_format(self, fp: BinaryReader, **kwargs) -> tuple[str, str] | bool: """Make sure input is segy.""" - try: - with segyio.open(path, ignore_geometry=True): - return self.name, self.version - except Exception: - return False + return _get_segy_version(fp) def read(self, path, time=None, channel=None, **kwargs): """ @@ -35,6 +40,7 @@ def read(self, path, time=None, channel=None, **kwargs): accept kwargs. If the format supports partial reads, these should be implemented as well. """ + segyio = optional_import(self._package_name) with segyio.open(path, ignore_geometry=True) as fi: coords = _get_coords(fi) attrs = _get_attrs(fi, coords, path, self) @@ -55,7 +61,39 @@ def scan(self, path, **kwargs) -> list[dc.PatchAttrs]: from the [dascore.core.attrs](`dascore.core.attrs`) module, or a format-specific subclass. """ + segyio = optional_import(self._package_name) with segyio.open(path, ignore_geometry=True) as fi: coords = _get_coords(fi) attrs = _get_attrs(fi, coords, path, self) return [attrs] + + def write(self, spool: dc.Patch | dc.BaseSpool, resource, **kwargs): + """ + Create a segy file from length 1 spool or patch. + + Parameters + ---------- + spool + The patch or length 1 spool to write. + resource + The target for writing patch. 
+
+        Notes
+        -----
+        Based on the example from segyio:
+        https://github.com/equinor/segyio/blob/master/python/examples/make-file.py
+        """
+        segyio = optional_import(self._package_name)
+        _write_segy(spool, resource, self.version, segyio)
+
+
+class SegyV2_0(SegyV1_0):  # noqa
+    """An IO class supporting version 2.0 of the SEGY format."""
+
+    version = "2.0"
+
+
+class SegyV2_1(SegyV1_0):  # noqa
+    """An IO class supporting version 2.1 of the SEGY format."""
+
+    version = "2.1"
diff --git a/dascore/io/segy/utils.py b/dascore/io/segy/utils.py
index 97f4e75e..0b44ac48 100644
--- a/dascore/io/segy/utils.py
+++ b/dascore/io/segy/utils.py
@@ -3,14 +3,70 @@
 from __future__ import annotations
 
 import datetime
+import warnings
 
 import numpy as np
-from segyio import TraceField
+import pandas as pd
 
 import dascore as dc
+from dascore import to_float
 from dascore.core import get_coord_manager
+from dascore.exceptions import InvalidSpoolError, PatchError
+from dascore.utils.misc import optional_import
+
+
+def twos_comp(bytes_):
+    """Get the two's complement of a big-endian byte string."""
+    bits = len(bytes_) * 8
+    val = int.from_bytes(bytes_, "big")
+    if (val & (1 << (bits - 1))) != 0:  # if sign bit is set, e.g., 8 bit: 128-255
+        val = val - (1 << bits)  # compute negative value
+    return val  # return positive value as is
+
-# --- Getting format/version
+
+# --- Getting format/version
+def _get_segy_version(fp):
+    """
+    Determine if the file handle contains segy data.
+
+    Returns ("segy", version) if so, else False.
+
+    Based on ObsPy's implementation written by Lion Krischer.
+    https://github.com/obspy/obspy/blob/master/obspy/io/segy/core.py
+    """
+    # Read the 400-byte binary file header into a byte string.
+    fp.seek(3200)
+    header = fp.read(400)
+    data_trace_count = twos_comp(header[12:14])
+    auxiliary_trace_count = twos_comp(header[14:16])
+    sample_interval = twos_comp(header[16:18])
+    samples_per_trace = twos_comp(header[20:22])
+    data_format_code = twos_comp(header[24:26])
+    format_number_major = twos_comp(header[300:301])
+    format_number_minor = twos_comp(header[301:302])
+    fixed_len_flag = twos_comp(header[302:304])
+
+    checks = (
+        # First check that some samples are defined.
+        samples_per_trace > 0,
+        # Then ensure the sample interval is not negative. It can instead be
+        # defined in the trace header, so 0 is ok here.
+        sample_interval >= 0,
+        # Ensure the data sample format code is valid using the range in
+        # the 2.1 standard.
+        1 <= data_format_code <= 16,
+        # Check the major version code.
+        format_number_major in {0, 1, 2, 3},
+        # Sanity checks for the other values.
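+        # (Trace counts can never be negative, and the minor version and
+        # fixed-length flag each have only a few valid values.)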
+        data_trace_count >= 0,
+        auxiliary_trace_count >= 0,
+        format_number_minor in {0, 1, 2, 3},
+        fixed_len_flag in {0, 1},
+    )
+    if all(checks):
+        return "segy", f"{format_number_major}.{format_number_minor}"
+    else:
+        return False
 
 
 def _get_filtered_data_and_coords(segy_fi, coords, time=None, channel=None):
@@ -53,12 +109,14 @@
     If a user knows the dx, change from channel to distance using
     patch.update_coords after reading
     """
+    segyio = optional_import("segyio")
+    trace_field = segyio.TraceField
     header_0 = fi.header[0]
 
-    # get time array from SEGY headers
+    # Get time array from SEGY headers
     starttime = _get_time_from_header(header_0)
-    dt = dc.to_timedelta64(header_0[TraceField.TRACE_SAMPLE_INTERVAL] / 1000)
-    ns = header_0[TraceField.TRACE_SAMPLE_COUNT]
+    dt = dc.to_timedelta64(header_0[trace_field.TRACE_SAMPLE_INTERVAL] / 1_000_000)
+    ns = header_0[trace_field.TRACE_SAMPLE_COUNT]
     time_array = starttime + dt * np.arange(ns)
 
     # Get distance array from SEGY header
@@ -83,13 +141,136 @@ def _get_attrs(fi, coords, path, file_io):
 
 def _get_time_from_header(header):
     """Creates a datetime64 object from SEGY header date information."""
-    year = header[TraceField.YearDataRecorded]
-    julday = header[TraceField.DayOfYear]
-    hour = header[TraceField.HourOfDay]
-    minute = header[TraceField.MinuteOfHour]
-    second = header[TraceField.SecondOfMinute]
+    segyio = optional_import("segyio")
+    trace_field = segyio.TraceField
+
+    year = header[trace_field.YearDataRecorded]
+    julday = header[trace_field.DayOfYear]
+    hour = header[trace_field.HourOfDay]
+    minute = header[trace_field.MinuteOfHour]
+    second = header[trace_field.SecondOfMinute]
     # make those datetime64
     fmt = "%Y.%j.%H.%M.%S"
     s = f"{year}.{julday}.{hour}.{minute}.{second}"
     time = datetime.datetime.strptime(s, fmt)
     return dc.to_datetime64(time)
+
+
+def _get_patch_with_channel_coord(patch):
+    """Ensure the patch has a channel coordinate."""
+    dims = set(patch.dims)
+    non_time = next(iter(dims - {"time"}))
+    msg = (
+        "Currently the segy writer only handles 'channel' as the non-time "
+        f"dimension; this results in a loss of the '{non_time}' dimension."
+    )
+    warnings.warn(msg)
+    coord = patch.get_coord(non_time)
+    array = np.arange(len(coord))
+    patch = patch.update_coords(**{non_time: array}).rename_coords(
+        **{non_time: "channel"}
+    )
+    return patch
+
+
+def _get_segy_compatible_patch(spool, round_error_max=3e-9):
+    """
+    Get a patch that will be writable as a segy file.
+
+    Ensures the coords are ("channel", "time").
+    """
+    # Ensure we have a single patch with coordinates time and distance.
+    spool = [spool] if isinstance(spool, dc.Patch) else spool
+    if len(spool) != 1:
+        msg = "Can only write a spool with a single patch as segy."
+        raise InvalidSpoolError(msg)
+    patch = spool[0]
+    dims = set(patch.dims)
+    has_distance_or_channel = dims & {"distance", "channel"}
+    if len(dims) != 2 or "time" not in dims or not has_distance_or_channel:
+        msg = (
+            "Can only save 2D patches to SEGY with a time dimension and "
+            "either a channel or distance dimension."
+        )
+        raise PatchError(msg)
+    # Currently we only support a channel dimension, not distance.
+    if "channel" not in dims:
+        patch = _get_patch_with_channel_coord(patch)
+    # Ensure there will be no loss in the time sampling.
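+    # (SEGY stores the trace sample interval as an integer number of microseconds.)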
+    # Round to microseconds and check how much precision would be lost.
+    time_step = dc.to_float(patch.get_coord("time").step)
+    new_samp = np.round(time_step, 6)
+    round_error = np.abs(new_samp - time_step).max()
+    if round_error > round_error_max:
+        msg = (
+            "The segy format supports microsecond precision for temporal "
+            f"sampling. The input patch has a time step of {time_step} which "
+            "will result in a loss of precision. Either manually set the "
+            "time step with patch.update_coords or resample the time axis "
+            "with patch.resample."
+        )
+        raise PatchError(msg)
+    return patch.transpose("channel", "time")
+
+
+def _make_time_header_dict(time_coord):
+    """Make the time header dict from a time coordinate."""
+    header = {}
+    timestamp = pd.Timestamp(dc.to_datetime64(time_coord.min()))
+    # The sample interval is stored as a whole number of microseconds.
+    time_step_us = np.round(to_float(time_coord.step) * 1_000_000)
+
+    segyio = optional_import("segyio")
+    trace_field = segyio.TraceField
+
+    header[trace_field.YearDataRecorded] = timestamp.year
+    header[trace_field.DayOfYear] = timestamp.day_of_year
+    header[trace_field.HourOfDay] = timestamp.hour
+    header[trace_field.MinuteOfHour] = timestamp.minute
+    header[trace_field.SecondOfMinute] = timestamp.second
+    header[trace_field.TRACE_SAMPLE_INTERVAL] = int(time_step_us)
+    header[trace_field.TRACE_SAMPLE_COUNT] = len(time_coord)
+
+    return header
+
+
+def _write_segy(spool, resource, version, segyio):
+    """
+    Private function for writing a patch/spool as SEGY.
+    """
+    patch = _get_segy_compatible_patch(spool)
+    time, channel = patch.get_coord("time"), patch.get_coord("channel")
+    channel_step = channel.step
+
+    time_dict = _make_time_header_dict(time)
+    bin_field = segyio.BinField
+    spec = segyio.spec()
+
+    spec.format = 1  # Format code 1 (4-byte floats). TODO: support more formats.
+    spec.samples = np.ones(len(time)) * len(channel)
+    spec.ilines = range(len(channel))
+    spec.xlines = [1]
+
+    # Force 32-bit floats for now.
+    data = patch.data.astype(np.float32)
+
+    with segyio.create(resource, spec) as f:
+        # Update the file header info.
+        f.bin.update(tsort=segyio.TraceSortingFormat.INLINE_SORTING)
+        f.bin.update(
+            {
+                bin_field.Samples: time_dict[segyio.TraceField.TRACE_SAMPLE_COUNT],
+                bin_field.Interval: time_dict[segyio.TraceField.TRACE_SAMPLE_INTERVAL],
+                bin_field.SEGYRevision: int(version.split(".")[0]),
+                bin_field.SEGYRevisionMinor: int(version.split(".")[1]),
+            }
+        )
+        # Then iterate each channel and dump it to the segy file.
+        for num, trace in enumerate(data):
+            header = dict(time_dict)
+            header.update(
+                {
+                    segyio.su.offset: channel_step,
+                    segyio.su.iline: num,
+                    segyio.su.xline: 1,
+                }
+            )
+            f.header[num] = header
+            f.trace[num] = trace
diff --git a/dascore/utils/misc.py b/dascore/utils/misc.py
index db71f0aa..266fc005 100644
--- a/dascore/utils/misc.py
+++ b/dascore/utils/misc.py
@@ -325,7 +325,7 @@ def optional_import(package_name: str) -> ModuleType:
     except ImportError:
         msg = (
             f"{package_name} is not installed but is required for the "
-            f"requested functionality"
+            f"requested functionality."
         )
         raise MissingOptionalDependencyError(msg)
     return mod
diff --git a/dascore/utils/patch.py b/dascore/utils/patch.py
index 59b6173c..12a21769 100644
--- a/dascore/utils/patch.py
+++ b/dascore/utils/patch.py
@@ -269,7 +269,7 @@ def _func(patch, *args, **kwargs):
         out = out.update_attrs(history=hist)
         return out
 
-    # attach original function. Although we want to encourage raw_function
+    # Attach original function. Although we want to encourage raw_function
     # for consistency with pydantic, we leave this to not break old code.
     _func.func = getattr(func, "raw_function", func)  # matches pydantic naming.
diff --git a/pyproject.toml b/pyproject.toml
index 8edafcf7..33ce54be 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -54,7 +54,6 @@ dependencies = [
     "tables>=3.7",
     "typing_extensions",
     "pint",
-    "segyio",
 ]
 
 [project.optional-dependencies]
@@ -64,6 +63,7 @@ extras = [
     "findiff",
     "obspy",
     "numba",
+    "segyio",
 ]
 
 docs = [
@@ -119,7 +119,9 @@
 TERRA15__V4 = "dascore.io.terra15.core:Terra15FormatterV4"
 TERRA15__V5 = "dascore.io.terra15.core:Terra15FormatterV5"
 TERRA15__V6 = "dascore.io.terra15.core:Terra15FormatterV6"
 SILIXA_H5__V1 = "dascore.io.silixah5:SilixaH5V1"
-SEGY__V2 = "dascore.io.segy.core:SegyV2"
+SEGY__V1_0 = "dascore.io.segy.core:SegyV1_0"
+SEGY__V2_0 = "dascore.io.segy.core:SegyV2_0"
+SEGY__V2_1 = "dascore.io.segy.core:SegyV2_1"
 RSF__V1 = "dascore.io.rsf.core:RSFV1"
 WAV = "dascore.io.wav.core:WavIO"
 XMLBINARY__V1 = "dascore.io.xml_binary.core:XMLBinaryV1"
diff --git a/tests/test_core/test_spool.py b/tests/test_core/test_spool.py
index d5764f2e..d4b85585 100644
--- a/tests/test_core/test_spool.py
+++ b/tests/test_core/test_spool.py
@@ -1,8 +1,9 @@
-"""Test for spool functions."""
+"""Tests for spool functions."""
 
 from __future__ import annotations
 
 import copy
+import shutil
 from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
 
 import numpy as np
@@ -14,8 +15,10 @@
 from dascore.core.spool import BaseSpool, MemorySpool
 from dascore.exceptions import (
     InvalidSpoolError,
+    MissingOptionalDependencyError,
     ParameterError,
 )
+from dascore.utils.downloader import fetch
 from dascore.utils.time import to_datetime64, to_timedelta64
 
 
@@ -497,6 +500,64 @@ def test_file_spool(self, random_spool, tmp_path_factory):
         assert isinstance(pickle_spool, MemorySpool)
 
 
+class TestSpoolBehaviorOptionalImports:
+    """
+    Tests for spool behavior when handling formats which require optional
+    dependencies.
+
+    Essentially, if the spool points directly at such a file (e.g.,
+    spool("file")) it should raise. If it is applied to a directory
+    containing such files (e.g., spool("directory/with/bad/files")) it
+    should issue a warning.
+    """
+
+    # The string to match against the warning/error.
+    _msg = "found files that can be read if additional"
+
+    @pytest.fixture(scope="function", autouse=True)
+    def monkey_patch_segy(self, monkeypatch):
+        """Monkey patch the name of the imported library for segy."""
+        # TODO we should find a cleaner way to do this in the future.
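+        # Pointing _package_name at a non-existent package makes
+        # optional_import raise MissingOptionalDependencyError, just as
+        # if segyio were not installed.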
+        from dascore.io.segy import SegyV1_0
+
+        monkeypatch.setattr(SegyV1_0, "_package_name", "not_segyio_clearly")
+
+    @pytest.fixture(scope="class")
+    def segy_file_path(self, tmp_path_factory):
+        """
+        Create a directory structure like this:
+
+        optional_import_test
+        - h5_simple_1.h5
+        - segy_only
+          - small_channel_patch.sgy
+        """
+        dir_path = tmp_path_factory.mktemp("optional_import_test")
+        simple_path = fetch("h5_simple_1.h5")
+        shutil.copy(simple_path, dir_path)
+
+        segy_only_path = dir_path / "segy_only"
+        segy_only_path.mkdir(exist_ok=True, parents=True)
+        segy_path = fetch("small_channel_patch.sgy")
+        shutil.copy(segy_path, segy_only_path)
+        return segy_only_path / segy_path.name
+
+    def test_spool_on_directory_no_other_files(self, segy_file_path):
+        """Ensure a directory with no other readable files raises."""
+        with pytest.raises(MissingOptionalDependencyError, match=self._msg):
+            dc.spool(segy_file_path.parent).update()
+
+    def test_spool_on_single_file(self, segy_file_path):
+        """Ensure a single file also raises."""
+        with pytest.raises(MissingOptionalDependencyError, match=self._msg):
+            dc.spool(segy_file_path).update()
+
+    def test_spool_on_multiple_files(self, segy_file_path):
+        """Ensure the warning is issued when other readable files exist."""
+        top_level = segy_file_path.parent.parent
+        with pytest.warns(UserWarning, match=self._msg):
+            dc.spool(top_level).update()
+
+
 class TestMisc:
     """Tests for misc. spool cases."""
diff --git a/tests/test_io/test_common_io.py b/tests/test_io/test_common_io.py
index edac00b1..52bc2b21 100644
--- a/tests/test_io/test_common_io.py
+++ b/tests/test_io/test_common_io.py
@@ -10,7 +10,7 @@
 
 from __future__ import annotations
 
-from contextlib import suppress
+from contextlib import contextmanager, suppress
 from functools import cache
 from io import BytesIO
 from operator import eq, ge, le
@@ -21,6 +21,7 @@
 import pytest
 
 import dascore as dc
+from dascore.exceptions import MissingOptionalDependencyError
 from dascore.io import BinaryReader
 from dascore.io.ap_sensing import APSensingV10
 from dascore.io.dasdae import DASDAEV1
@@ -31,7 +32,7 @@
 from dascore.io.optodas import OptoDASV8
 from dascore.io.pickle import PickleIO
 from dascore.io.prodml import ProdMLV2_0, ProdMLV2_1
-from dascore.io.segy import SegyV2
+from dascore.io.segy import SegyV1_0
 from dascore.io.sentek import SentekV5
 from dascore.io.silixah5 import SilixaH5V1
 from dascore.io.tdms import TDMSFormatterV4713
@@ -75,7 +76,7 @@
     Terra15FormatterV5(): ("terra15_v5_test_file.hdf5",),
     Terra15FormatterV6(): ("terra15_v6_test_file.hdf5",),
     Terra15FormatterV6(): ("terra15_v6_test_file.hdf5",),
-    SegyV2(): ("conoco_segy_1.sgy",),
+    SegyV1_0(): ("conoco_segy_1.sgy",),
     DASHDF5(): ("PoroTomo_iDAS_1.h5",),
     SentekV5(): ("DASDMSShot00_20230328155653619.das",),
 }
@@ -83,7 +84,10 @@
 # This tuple is for fiber io which support a write method and can write
 # generic patches. If the patch has to be in some special form, for example
 # only flat patches can be written to WAV, don't put it here.
-COMMON_IO_WRITE_TESTS = (PickleIO(), DASDAEV1())
+COMMON_IO_WRITE_TESTS = (
+    PickleIO(),
+    DASDAEV1(),
+)
 
 # Specifies data registry entries which should not be tested.
 SKIP_DATA_FILES = {"whale_1.hdf5", "brady_hs_DAS_DTS_coords.csv"}
@@ -96,8 +100,12 @@ def _cached_read(path, io=None):
 
     This ensures each file is read at most twice.
""" if io is None: - return dc.read(path) - return io.read(path) + read = dc.read + else: + read = io.read + with skip_missing_dependency(): + out = read(path) + return out def _get_flat_io_test(): @@ -109,6 +117,15 @@ def _get_flat_io_test(): return flat_io +@contextmanager +def skip_missing_dependency(): + """Skip if missing dependencies found.""" + try: + yield + except MissingOptionalDependencyError: + pytest.skip("Missing optional dep to read file.") + + @pytest.fixture(scope="session", params=list(COMMON_IO_READ_TESTS)) def io_instance(request): """Fixture for returning fiber io instances.""" @@ -138,14 +155,16 @@ def data_file_path(request): @pytest.fixture(scope="session") def read_spool(data_file_path): """Read each file into a spool.""" - out = dc.read(data_file_path) + with skip_missing_dependency(): + out = dc.read(data_file_path) return out @pytest.fixture(scope="session") def scanned_attrs(data_file_path): """Read each file into a spool.""" - out = dc.scan(data_file_path) + with skip_missing_dependency(): + out = dc.scan(data_file_path) return out @@ -270,7 +289,8 @@ def test_slice_single_dim_both_ends(self, io_path_tuple): a patch containing the requested data is returned. """ io, path = io_path_tuple - attrs_from_file = dc.scan(path) + with skip_missing_dependency(): + attrs_from_file = dc.scan(path) assert len(attrs_from_file) # skip files that have more than one patch for now # TODO just write better test logic to handle this case. @@ -317,7 +337,8 @@ class TestScan: def test_scan_basics(self, data_file_path): """Ensure each file can be scanned.""" - attrs_list = dc.scan(data_file_path) + with skip_missing_dependency(): + attrs_list = dc.scan(data_file_path) assert len(attrs_list) for attrs in attrs_list: @@ -327,7 +348,8 @@ def test_scan_basics(self, data_file_path): def test_scan_has_version_and_format(self, io_path_tuple): """Scan output should contain version and format.""" io, path = io_path_tuple - attr_list = io.scan(path) + with skip_missing_dependency(): + attr_list = io.scan(path) for attrs in attr_list: assert attrs.file_format == io.name assert attrs.file_version == io.version @@ -396,7 +418,8 @@ def test_scan_attrs_match_patch_attrs(self, data_file_path): "tag", "network", ) - scan_attrs_list = dc.scan(data_file_path) + with skip_missing_dependency(): + scan_attrs_list = dc.scan(data_file_path) patch_attrs_list = [x.attrs for x in _cached_read(data_file_path)] assert len(scan_attrs_list) == len(patch_attrs_list) for pat_attrs1, scan_attrs2 in zip(patch_attrs_list, scan_attrs_list): diff --git a/tests/test_io/test_segy/test_segy.py b/tests/test_io/test_segy/test_segy.py index 42513f48..ed06d0f8 100644 --- a/tests/test_io/test_segy/test_segy.py +++ b/tests/test_io/test_segy/test_segy.py @@ -1 +1,112 @@ """Tests for SEGY format.""" + +import numpy as np +import pytest + +import dascore as dc +from dascore.exceptions import ( + InvalidSpoolError, + PatchError, +) +from dascore.io.segy.core import SegyV1_0 +from dascore.utils.misc import suppress_warnings + + +class TestSegyGetFormat: + """Tests for getting format codes of SEGY files.""" + + @pytest.fixture(scope="class") + def small_file(self, tmp_path_factory): + """Creates a small file with only a few bytes.""" + parent = tmp_path_factory.mktemp("small_file") + path = parent / "test_file.segy" + with path.open("wb") as f: + f.write(b"abd") + return path + + def test_get_formate_small_file(self, small_file): + """ + Ensure a file that is too small to contain segy header doesn't throw + an error. 
+ """ + segy = SegyV1_0() + out = segy.get_format(small_file) + assert out is False # we actually want to make sure its False. + + +class TestSegyWrite: + """Tests for writing segy files.""" + + @pytest.fixture(scope="class") + def test_segy_directory(self, tmp_path_factory): + """Make a tmp directory for saving files.""" + path = tmp_path_factory.mktemp("test_segy_write_directory") + return path + + @pytest.fixture(scope="class") + def channel_patch(self, random_patch): + """Get a patch with channels rather than distance.""" + distance = random_patch.get_coord("distance") + new = random_patch.rename_coords(distance="channel").update_coords( + **{"channel": np.arange(len(distance))} + ) + return new + + @pytest.fixture(scope="class") + def channel_patch_path(self, channel_patch, test_segy_directory): + """Write the channel patch to disk.""" + pytest.importorskip("segyio") + path = test_segy_directory / "patch_with_channel_coord.segy" + channel_patch.io.write(path, "segy") + return path + + def test_can_get_format(self, channel_patch_path): + """Ensure we can get the correct format/version.""" + segy = SegyV1_0() + out = segy.get_format(channel_patch_path) + assert out, "Failed to detect written segy file." + assert out[0] == segy.name + + def test_channel_patch_round_trip(self, channel_patch_path, channel_patch): + """The channel patch should round trip.""" + patch1 = channel_patch + patch2 = dc.spool(channel_patch_path)[0].transpose(*patch1.dims) + # We really don't have a way to transport attributes yet, so we + # just check that data and coords are equal. + assert np.allclose(patch1.data, patch2.data) + assert patch1.coords == patch2.coords + + def test_write_non_channel_path(self, random_patch, tmp_path_factory): + """Ensure a 'normal' patch can be written.""" + pytest.importorskip("segyio") + path = tmp_path_factory.mktemp("test_write_segy") / "temppath.segy" + match = "non-time dimension" + with pytest.warns(match=match): + random_patch.io.write(path, "segy") + assert path.exists() + patch2 = dc.spool(path)[0] + assert set(random_patch.shape) == set(patch2.shape) + + def test_loss_of_precision_raises(self, random_patch, tmp_path_factory): + """Ensure that loss of precision raises a PatchError.""" + pytest.importorskip("segyio") + path = tmp_path_factory.mktemp("test_loss_of_precision") / "temppath.segy" + patch = random_patch.update_coords(time_step=np.timedelta64(10, "ns")) + match = "will result in a loss of precision" + with pytest.raises(PatchError, match=match): + with suppress_warnings(): + patch.io.write(path, "segy") + + def test_bad_dims_raises(self, random_patch, tmp_path): + """Ensure a bad dimension name raises.""" + pytest.importorskip("segyio") + patch = random_patch.rename_coords(distance="bad_dim") + with pytest.raises(PatchError, match="Can only save 2D patches"): + patch.io.write(tmp_path, "segy") + + def test_multi_patch_spool_raises(self, random_spool, tmp_path): + """Spools with more than one patch cant be written.""" + pytest.importorskip("segyio") + segy = SegyV1_0() + with pytest.raises(InvalidSpoolError): + segy.write(random_spool, tmp_path)