diff --git a/docs/api/index.md b/docs/api/index.md index 6a1c662..3f2d665 100644 --- a/docs/api/index.md +++ b/docs/api/index.md @@ -4,11 +4,12 @@ :maxdepth: 1 xdas +atoms +io fft -signal -processing parallel +processing +signal synthetics virtual -atoms ``` \ No newline at end of file diff --git a/docs/api/io.md b/docs/api/io.md new file mode 100644 index 0000000..4e2836c --- /dev/null +++ b/docs/api/io.md @@ -0,0 +1,27 @@ +```{eval-rst} +.. currentmodule:: xdas.io +``` + +# xdas.io + +```{eval-rst} +.. autosummary:: + :toctree: ../_autosummary + + get_free_port +``` + +```{eval-rst} +.. currentmodule:: xdas.io.asn +``` + + +## ASN + +```{eval-rst} +.. autosummary:: + :toctree: ../_autosummary + + ZMQPublisher + ZMQSubscriber +``` diff --git a/docs/api/processing.md b/docs/api/processing.md index fc6d244..11ed302 100644 --- a/docs/api/processing.md +++ b/docs/api/processing.md @@ -12,4 +12,6 @@ DataArrayLoader RealTimeLoader DataArrayWriter + ZMQPublisher + ZMQSubscriber ``` \ No newline at end of file diff --git a/docs/user-guide/index.md b/docs/user-guide/index.md index a5bdbd7..c8aef43 100644 --- a/docs/user-guide/index.md +++ b/docs/user-guide/index.md @@ -10,4 +10,5 @@ interpolated-coordinates convert-displacement atoms processing +streaming ``` \ No newline at end of file diff --git a/docs/user-guide/streaming.md b/docs/user-guide/streaming.md new file mode 100644 index 0000000..b582eb6 --- /dev/null +++ b/docs/user-guide/streaming.md @@ -0,0 +1,91 @@ +--- +file_format: mystnb +kernelspec: + name: python3 +--- + +# Streaming data + +Xdas allows to stream data over any network using [ZeroMQ](https://zeromq.org). Xdas use the Publisher and Subscriber patterns meaning that on one node the data is published and that any number of subscribers can receive the data stream. + +Streaming data with Xdas is done by simply dumping each chunk to NetCDF binaries and to send those as packets. This ensure that each packet is self described and that feature such as compression are available (which can be very helpful to minimize the used bandwidth). + +Xdas implements the {py:class}`~xdas.processing.ZMQPublisher` and {py:class}`~xdas.processing.ZMQSubscriber`.Those object can respectively be used as a Writer and a Loader as described in the [](processing) section. Both are initialized by giving an network address. The publisher use the `submit` method to send packets while the subscriber is an infinite iterator that yields packets. + +In this section, we will mimic the use of several machine by using multithreading, where each thread is supposed to be a different machine. In real-life application, the publisher and subscriber are generally called in different machine or software. + +## Simple use case + +```{code-cell} +import threading +import time + +import xdas as xd +from xdas.processing import ZMQPublisher, ZMQSubscriber +``` + +First we generate some data and split it into packets + +```{code-cell} +da = xd.synthetics.dummy() +packets = xd.split(da, 5) +``` + +We then publish the packets on machine 1. + +```{code-cell} +address = f"tcp://localhost:{xd.io.get_free_port()}" +publisher = ZMQPublisher(address) + +def publish(): + for packet in packets: + publisher.submit(packet) + # give a chance to the subscriber to connect in time and to get the last packet + time.sleep(0.1) + +machine1 = threading.Thread(target=publish) +machine1.start() +``` + +Let's receive the packets on machine 2. + +```{code-cell} +subscriber = ZMQSubscriber(address) + +packets = [] + +def subscribe(): + for packet in subscriber: + packets.append(packet) + +machine2 = threading.Thread(target=subscribe) +machine2.start() +``` + +Now we wait for machine 1 to finish sending its packet and see if everything went well. + +```{code-cell} +machine1.join() +print(f"We received {len(packets)} packets!") +assert xd.concatenate(packets).equals(da) +``` + +## Using encoding + +To reduce the volume of the transmitted data, compression is often useful. Xdas enable the use of the ZFP algorithm when storing data but also when streaming it. Encoding is declared the same way. + +```{code-cell} +:tags: [remove-output] + +import hdf5plugin + +address = f"tcp://localhost:{xd.io.get_free_port()}" +encoding = {"chunks": (10, 10), **hdf5plugin.Zfp(accuracy=1e-6)} +publisher = ZMQPublisher(address, encoding) # Add encoding here, the rest is the same +``` + +{py:class}`~xdas.io.asn.ZMQSubscriber` + +```{note} +Xdas also implements the ZeroMQ protocol used by the OptoDAS interrogators by ASN. Equivalent {py:class}`~xdas.io.asn.ZMQPublisher` and {py:class}`~xdas.io.asn.ZMQSubscriber` can be found in {py:mod}`xdas.io.asn`. This can be useful get data in real-time from one instrument of that kind. Note that compression is not available with that protocol yet. +``` diff --git a/pyproject.toml b/pyproject.toml index 1130f72..c794588 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,6 +23,7 @@ dependencies = [ "watchdog", "xarray", "xinterp", + "pyzmq", ] [project.optional-dependencies] diff --git a/tests/io/test_asn.py b/tests/io/test_asn.py index 4640904..05f1ea5 100644 --- a/tests/io/test_asn.py +++ b/tests/io/test_asn.py @@ -1 +1,240 @@ -# TODO +import json +import socket +import threading +import time + +import numpy as np +import zmq + +import xdas as xd +from xdas.io.asn import ZMQPublisher, ZMQSubscriber + + +def get_free_local_address(): + port = xd.io.get_free_port() + return f"tcp://localhost:{port}" + + +coords = { + "time": { + "tie_indices": [0, 99], + "tie_values": [ + np.datetime64("2020-01-01T00:00:00.000"), + np.datetime64("2020-01-01T00:00:09.900"), + ], + }, + "distance": {"tie_indices": [0, 9], "tie_values": [0.0, 90.0]}, +} + +da_float32 = xd.DataArray( + np.random.randn(100, 10).astype("float32"), + coords, +) + +da_int16 = xd.DataArray( + np.random.randn(100, 10).astype("int16"), + coords, +) + + +class TestZMQPublisher: + def test_get_header(self): + header = ZMQPublisher._get_header(da_float32) + assert header["bytesPerPackage"] == 40 + assert header["nPackagesPerMessage"] == 100 + assert header["nChannels"] == 10 + assert header["dataType"] == "float" + assert header["dx"] == 10.0 + assert header["dt"] == 0.1 + assert header["dtUnit"] == "s" + assert header["dxUnit"] == "m" + assert header["roiTable"] == [{"roiStart": 0, "roiEnd": 9, "roiDec": 1}] + header = ZMQPublisher._get_header(da_int16) + assert header["dataType"] == "short" + + def test_init_conect_set_header(self): + address = get_free_local_address() + pub = ZMQPublisher(address) + pub.submit(da_float32) + assert pub.header == ZMQPublisher._get_header(da_float32) + + def test_send_header(self): + address = get_free_local_address() + pub = ZMQPublisher(address) + pub.submit(da_float32) + socket = self.get_socket(address) + pub.submit(da_float32) # a packet must be sent once subscriber is connected + assert socket.recv() == json.dumps(pub.header).encode("utf-8") + + def test_send_data(self): + address = get_free_local_address() + pub = ZMQPublisher(address) + pub.submit(da_float32) + socket = self.get_socket(address) + pub.submit(da_float32) # a packet must be sent once subscriber is connected + socket.recv() # header + message = socket.recv() + assert message[:8] == da_float32["time"][0].values.astype("M8[ns]").tobytes() + assert message[8:] == da_float32.data.tobytes() + pub.submit(da_int16) + socket.recv() # header + message = socket.recv() + assert message[:8] == da_int16["time"][0].values.astype("M8[ns]").tobytes() + assert message[8:] == da_int16.data.tobytes() + + def test_send_chunks(self): + address = get_free_local_address() + pub = ZMQPublisher(address) + chunks = xd.split(da_float32, 10) + pub.submit(chunks[0]) + time.sleep(0.001) + socket = self.get_socket(address) + for chunk in chunks[1:]: + pub.submit(chunk) + time.sleep(0.001) + assert socket.recv() == json.dumps(pub.header).encode("utf-8") + for chunk in chunks[1:]: # first was sent before subscriber connected + message = socket.recv() + assert message[:8] == chunk["time"][0].values.astype("M8[ns]").tobytes() + assert message[8:] == chunk.data.tobytes() + + def test_several_subscribers(self): + address = get_free_local_address() + pub = ZMQPublisher(address) + chunks = xd.split(da_float32, 10) + pub.submit(chunks[0]) + time.sleep(0.001) + socket1 = self.get_socket(address) + for chunk in chunks[1:5]: + pub.submit(chunk) + time.sleep(0.001) + socket2 = self.get_socket(address) + for chunk in chunks[5:]: + pub.submit(chunk) + time.sleep(0.001) + assert socket1.recv() == json.dumps(pub.header).encode("utf-8") + for chunk in chunks[1:]: # first was sent before subscriber connected + message = socket1.recv() + assert message[:8] == chunk["time"][0].values.astype("M8[ns]").tobytes() + assert message[8:] == chunk.data.tobytes() + assert socket2.recv() == json.dumps(pub.header).encode("utf-8") + for chunk in chunks[5:]: # first was sent before subscriber connected + message = socket2.recv() + assert message[:8] == chunk["time"][0].values.astype("M8[ns]").tobytes() + assert message[8:] == chunk.data.tobytes() + + def test_change_header(self): + address = get_free_local_address() + pub = ZMQPublisher(address) + chunks = xd.split(da_float32, 10) + pub.submit(chunks[0]) + time.sleep(0.001) + socket = self.get_socket(address) + for chunk in chunks[1:5]: + pub.submit(chunk) + header1 = pub.header + time.sleep(0.001) + for chunk in chunks[5:]: + pub.submit(chunk.isel(distance=slice(0, 5))) + header2 = pub.header + time.sleep(0.001) + assert socket.recv() == json.dumps(header1).encode("utf-8") + for chunk in chunks[1:5]: # first was sent before subscriber connected + message = socket.recv() + assert message[:8] == chunk["time"][0].values.astype("M8[ns]").tobytes() + assert message[8:] == chunk.data.tobytes() + assert socket.recv() == json.dumps(header2).encode("utf-8") + for chunk in chunks[5:]: # first was sent before subscriber connected + message = socket.recv() + assert message[:8] == chunk["time"][0].values.astype("M8[ns]").tobytes() + assert message[8:] == chunk.isel(distance=slice(0, 5)).data.tobytes() + + def get_socket(self, address): + socket = zmq.Context().socket(zmq.SUB) + socket.connect(address) + socket.setsockopt(zmq.SUBSCRIBE, b"") + time.sleep(0.001) + return socket + + +class TestZMQSubscriber: + def test_one_chunk(self): + address = get_free_local_address() + pub = ZMQPublisher(address) + chunks = [da_float32] + threading.Thread(target=self.publish, args=(pub, chunks)).start() + sub = ZMQSubscriber(address) + assert sub.address == address + assert sub.packet_size == 4008 + assert sub.shape == (100, 10) + assert sub.dtype == np.float32 + assert sub.distance == {"tie_indices": [0, 9], "tie_values": [0.0, 90.0]} + assert sub.delta == np.timedelta64(100, "ms") + result = next(sub) + assert result.equals(da_float32) + chunks = [da_int16] + threading.Thread(target=self.publish, args=(pub, chunks)).start() + result = next(sub) + assert sub.packet_size == 2008 + assert sub.dtype == np.int16 + assert result.equals(da_int16) + + def test_several_chunks(self): + address = get_free_local_address() + pub = ZMQPublisher(address) + chunks = xd.split(da_float32, 5) + threading.Thread(target=self.publish, args=(pub, chunks)).start() + sub = ZMQSubscriber(address) + assert sub.packet_size == 808 + assert sub.shape == (20, 10) + assert sub.dtype == np.float32 + assert sub.distance == {"tie_indices": [0, 9], "tie_values": [0.0, 90.0]} + assert sub.delta == np.timedelta64(100, "ms") + for chunk in chunks: + result = next(sub) + assert result.equals(chunk) + + def test_several_subscribers(self): + address = get_free_local_address() + pub = ZMQPublisher(address) + chunks = xd.split(da_float32, 5) + thread = threading.Thread(target=self.publish, args=(pub, chunks[:2])) + thread.start() + sub1 = ZMQSubscriber(address) + thread.join() + thread = threading.Thread(target=self.publish, args=(pub, chunks[2:])) + thread.start() + sub2 = ZMQSubscriber(address) + + for chunk in chunks: + result = next(sub1) + assert result.equals(chunk) + for chunk in chunks[2:]: + result = next(sub2) + assert result.equals(chunk) + + def test_change_header(self): + address = get_free_local_address() + pub = ZMQPublisher(address) + chunks = xd.split(da_float32, 5) + chunks = [chunk.isel(distance=slice(0, 5)) for chunk in chunks[:2]] + chunks[2:] + threading.Thread(target=self.publish, args=(pub, chunks)).start() + sub = ZMQSubscriber(address) + for chunk in chunks: + result = next(sub) + assert result.equals(chunk) + + def test_iter(self): + address = get_free_local_address() + pub = ZMQPublisher(address) + chunks = xd.split(da_float32, 5) + threading.Thread(target=self.publish, args=(pub, chunks)).start() + sub = ZMQSubscriber(address) + sub = (chunk for _, chunk in zip(range(5), sub)) + result = xd.concatenate([chunk for chunk in sub]) + assert result.equals(da_float32) + + def publish(self, pub, chunks): + for chunk in chunks: + time.sleep(0.001) + pub.submit(chunk) diff --git a/tests/test_processing.py b/tests/test_processing.py index 3434d5d..820a908 100644 --- a/tests/test_processing.py +++ b/tests/test_processing.py @@ -1,6 +1,10 @@ import os import tempfile +import threading +import time +import hdf5plugin +import numpy as np import pandas as pd import scipy.signal as sp @@ -10,6 +14,8 @@ DataArrayLoader, DataArrayWriter, DataFrameWriter, + ZMQPublisher, + ZMQSubscriber, process, ) from xdas.signal import sosfilt @@ -160,3 +166,42 @@ def test_write_and_result_with_existing_file(self): # Check if the output file contains the correct data output_df = pd.read_csv(writer.path) assert output_df.equals(expected_result) + + +class TestZMQ: + def _publish_and_subscribe(self, packets, address, encoding=None): + publisher = ZMQPublisher(address, encoding) + + def publish(): + for packet in packets: + time.sleep(0.001) + publisher.submit(packet) + + threading.Thread(target=publish).start() + + subscriber = ZMQSubscriber(address) + result = [] + for n, packet in enumerate(subscriber, start=1): + result.append(packet) + if n == len(packets): + break + return xdas.concatenate(result) + + def test_publish_and_subscribe(self): + expected = xdas.synthetics.dummy() + packets = xdas.split(expected, 10) + address = f"tcp://localhost:{xdas.io.get_free_port()}" + + result = self._publish_and_subscribe(packets, address) + assert result.equals(expected) + + def test_encoding(self): + expected = xdas.synthetics.dummy() + packets = xdas.split(expected, 10) + address = f"tcp://localhost:{xdas.io.get_free_port()}" + encoding = {"chunks": (10, 10), **hdf5plugin.Zfp(accuracy=1e-6)} + + result = self._publish_and_subscribe(packets, address, encoding=encoding) + assert np.allclose(result.values, expected.values, atol=1e-6) + result.data = expected.data + assert result.equals(expected) diff --git a/xdas/io/__init__.py b/xdas/io/__init__.py index b6fb426..008cd21 100644 --- a/xdas/io/__init__.py +++ b/xdas/io/__init__.py @@ -1 +1,2 @@ from . import asn, febus, optasense, sintela, terra15 +from .core import get_free_port diff --git a/xdas/io/asn.py b/xdas/io/asn.py index dc6f8a2..14efe06 100644 --- a/xdas/io/asn.py +++ b/xdas/io/asn.py @@ -1,5 +1,10 @@ +import json + import h5py import numpy as np +import zmq + +from xdas.core.coordinates import get_sampling_interval from ..core.dataarray import DataArray from ..virtual import VirtualSource @@ -16,3 +21,237 @@ def read(fname): time = {"tie_indices": [0, nt - 1], "tie_values": [t0, t0 + (nt - 1) * dt]} distance = {"tie_indices": [0, nx - 1], "tie_values": [0.0, (nx - 1) * dx]} return DataArray(data, {"time": time, "distance": distance}) + + +type_map = { + "short": np.int16, + "int": np.int32, + "long": np.int64, + "float": np.float32, + "double": np.float64, +} + + +class ZMQSubscriber: + def __init__(self, address): + """ + Initializes a ZMQStream object. + + Parameters + ---------- + address : str + The address to connect to. + + Examples + -------- + >>> import time + >>> import threading + + >>> import xdas as xd + >>> from xdas.io.asn import ZMQSubscriber + + >>> port = xd.io.get_free_port() + >>> address = f"tcp://localhost:{port}" + >>> publisher = ZMQPublisher(address) + + >>> da = xd.synthetics.dummy() + >>> chunks = xd.split(da, 10) + + >>> def publish(): + ... for chunk in chunks: + ... time.sleep(0.001) # so that the subscriber can connect in time + ... publisher.submit(chunk) + >>> threading.Thread(target=publish).start() + + >>> subscriber = ZMQSubscriber(address) + >>> for nchunk in range(10): + ... chunk = next(subscriber) + ... # do something with the chunk + + """ + self.address = address + self._connect(self.address) + message = self._get_message() + self._update_header(message) + + def __iter__(self): + return self + + def __next__(self): + message = self._get_message() + if not self._is_packet(message): + self._update_header(message) + return self.__next__() + else: + return self._unpack(message) + + def _connect(self, address): + context = zmq.Context() + socket = context.socket(zmq.SUB) + socket.connect(address) + socket.setsockopt_string(zmq.SUBSCRIBE, "") + self._socket = socket + + def _get_message(self): + return self._socket.recv() + + def _is_packet(self, message): + return len(message) == self.packet_size + + def _update_header(self, message): + header = json.loads(message.decode("utf-8")) + self.packet_size = 8 + header["bytesPerPackage"] * header["nPackagesPerMessage"] + self.shape = (header["nPackagesPerMessage"], header["nChannels"]) + self.dtype = type_map[header["dataType"]] + roiTable = header["roiTable"][0] + di = roiTable["roiStart"] * header["dx"] + de = roiTable["roiEnd"] * header["dx"] + self.distance = { + "tie_indices": [0, header["nChannels"] - 1], + "tie_values": [di, de], + } + self.delta = float_to_timedelta(header["dt"], header["dtUnit"]) + + def _unpack(self, message): + t0 = np.frombuffer(message[:8], "datetime64[ns]").reshape(()) + data = np.frombuffer(message[8:], self.dtype).reshape(self.shape) + time = { + "tie_indices": [0, self.shape[0] - 1], + "tie_values": [t0, t0 + (self.shape[0] - 1) * self.delta], + } + return DataArray(data, {"time": time, "distance": self.distance}) + + +class ZMQPublisher: + """ + A class to stream data using ZeroMQ. + + Parameters + ---------- + address : str + The address to bind the ZeroMQ socket. + + Attributes + ---------- + address : str + The address where the ZeroMQ is bound to. + + Methods + ------- + submit(da) + Submits the data array for publishing. + + Examples + -------- + >>> import xdas as xd + >>> from xdas.io.asn import ZMQPublisher + + >>> da = xd.synthetics.dummy() + + >>> port = xd.io.get_free_port() + >>> address = f"tcp://localhost:{port}" + >>> publisher = ZMQPublisher(address) + >>> chunks = xd.split(da, 10) + >>> for chunk in chunks: + ... publisher.submit(chunk) + + """ + + def __init__(self, address): + self.address = address + self._connect(address) + self._header = None + + @property + def header(self): + return self._header + + @header.setter + def header(self, header): + self._header = header + self.socket.setsockopt(zmq.XPUB_WELCOME_MSG, json.dumps(header).encode("utf-8")) + + def submit(self, da): + self._send(da) + + def write(self, da): + self._send(da) + + def _connect(self, address): + context = zmq.Context() + socket = context.socket(zmq.XPUB) + socket.setsockopt(zmq.XPUB_VERBOSE, True) + socket.bind(address) + self.socket = socket + + @staticmethod + def _get_header(da): + da = da.transpose("time", "distance") + header = { + "bytesPerPackage": da.dtype.itemsize * da.shape[1], + "nPackagesPerMessage": da.shape[0], + "nChannels": da.shape[1], + "dataType": next((k for k, v in type_map.items() if v == da.dtype), None), + "dx": get_sampling_interval(da, "distance"), + "dt": get_sampling_interval(da, "time"), + "dtUnit": "s", + "dxUnit": "m", + "roiTable": [{"roiStart": 0, "roiEnd": da.shape[1] - 1, "roiDec": 1}], + } + return header + + def _send(self, da): + da = da.transpose("time", "distance") + header = self._get_header(da) + if self.header is None: + self.header = header + if not header == self.header: + self.header = header + self._send_header() + self._send_data(da) + + def _send_header(self): + message = json.dumps(self.header).encode("utf-8") + self._send_message(message) + + def _send_data(self, da): + da = da.transpose("time", "distance") + t0 = da["time"][0].values.astype("datetime64[ns]") + data = da.values + message = t0.tobytes() + data.tobytes() + self._send_message(message) + + def _send_message(self, message): + self.socket.send(message) + + +def float_to_timedelta(value, unit): + """ + Converts a floating-point value to a timedelta object. + + Parameters + ---------- + value : float + The value to be converted. + unit : str + The unit of the value. Valid units are 'ns' (nanoseconds), 'us' (microseconds), + 'ms' (milliseconds), and 's' (seconds). + + Returns + ------- + timedelta + The converted timedelta object. + + Example + ------- + >>> float_to_timedelta(1.5, 'ms') + numpy.timedelta64(1500000,'ns') + """ + conversion_factors = { + "ns": 1e0, + "us": 1e3, + "ms": 1e6, + "s": 1e9, + } + conversion_factor = conversion_factors[unit] + return np.timedelta64(round(value * conversion_factor), "ns") diff --git a/xdas/io/core.py b/xdas/io/core.py new file mode 100644 index 0000000..a4171a7 --- /dev/null +++ b/xdas/io/core.py @@ -0,0 +1,20 @@ +import socket + + +def get_free_port(): + """ + Find and return a free port on the host machine. + + This function creates a temporary socket, binds it to an available port + provided by the host, retrieves the port number, and then closes the socket. + This is useful for finding an available port for network communication. + + Returns + ------- + int: + A free port number on the host machine. + + """ + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("", 0)) + return s.getsockname()[1] diff --git a/xdas/processing/__init__.py b/xdas/processing/__init__.py index f427119..50e16f3 100644 --- a/xdas/processing/__init__.py +++ b/xdas/processing/__init__.py @@ -4,5 +4,7 @@ DataArrayWriter, DataFrameWriter, RealTimeLoader, + ZMQPublisher, + ZMQSubscriber, process, ) diff --git a/xdas/processing/core.py b/xdas/processing/core.py index a8123db..c3c42a5 100644 --- a/xdas/processing/core.py +++ b/xdas/processing/core.py @@ -1,9 +1,11 @@ import os from concurrent.futures import ThreadPoolExecutor from queue import Queue +from tempfile import TemporaryDirectory import numpy as np import pandas as pd +import zmq from watchdog.events import FileSystemEventHandler from watchdog.observers import Observer @@ -168,6 +170,8 @@ class DataArrayWriter: dirpath : str or path The directory to store the output of a processing pipeline. The directory needs to exist and be empty. + encoding : dict + The encoding to use when dumping the DataArrays to bytes. Examples -------- @@ -297,3 +301,151 @@ def result(self): except pd.errors.EmptyDataError: out = pd.DataFrame() return out + + +class ZMQPublisher: + """ + A class for publishing DataArray chunks over ZeroMQ. + + Parameters + ---------- + address : str + The address to bind the publisher to. + encoding : dict + The encoding to use when dumping the DataArrays to bytes. + + Examples + -------- + >>> import xdas as xd + >>> from xdas.processing import ZMQPublisher, ZMQSubscriber + + First we generate some data and split it into packets + + >>> packets = xd.split(xd.synthetics.dummy(), 10) + + We initialize the publisher at a given address + + >>> address = f"tcp://localhost:{xd.io.get_free_port()}" + >>> publisher = ZMQPublisher(address) + + We can then publish the packets + + >>> for da in packets: + ... publisher.submit(da) + + To reduce the size of the packets, we can also specify an encoding + + >>> import hdf5plugin + + >>> address = f"tcp://localhost:{xd.io.get_free_port()}" + >>> encoding = {"chunks": (10, 10), **hdf5plugin.Zfp(accuracy=1e-6)} + >>> publisher = ZMQPublisher(address, encoding) + >>> for da in packets: + ... publisher.submit(da) + + """ + + def __init__(self, address, encoding=None): + self.address = address + self.encoding = encoding + self._context = zmq.Context() + self._socket = self._context.socket(zmq.PUB) + self._socket.bind(self.address) + + def submit(self, da): + """ + Send a DataArray over ZeroMQ. + + Parameters + ---------- + da : DataArray + The DataArray to be sent. + + """ + self._socket.send(tobytes(da, self.encoding)) + + def write(self, da): + self.submit(da) + + def result(): + return None + + +class ZMQSubscriber: + """ + A class for subscribing to DataArray chunks over ZeroMQ. + + Parameters + ---------- + address : str + The address to connect the subscriber to. + + Methods + ------- + submit(da) + Send a DataArray over ZeroMQ. + + Examples + -------- + >>> import threading + + >>> import xdas as xd + >>> from xdas.processing import ZMQSubscriber + + First we generate some data and split it into packets + + >>> da = xd.synthetics.dummy() + >>> packets = xd.split(da, 10) + + We then publish the packets asynchronously + + >>> address = f"tcp://localhost:{xd.io.get_free_port()}" + >>> publisher = ZMQPublisher(address) + + >>> def publish(): + ... for packet in packets: + ... publisher.submit(packet) + + >>> threading.Thread(target=publish).start() + + Now let's receive the packets + + >>> subscriber = ZMQSubscriber(address) + >>> packets = [] + >>> for n, da in enumerate(subscriber, start=1): + ... packets.append(da) + ... if n == 10: + ... break + >>> da = xd.concatenate(packets) + >>> assert da.equals(da) + """ + + def __init__(self, address): + self.address = address + self._context = zmq.Context() + self._socket = self._context.socket(zmq.SUB) + self._socket.connect(address) + self._socket.setsockopt_string(zmq.SUBSCRIBE, "") + + def __iter__(self): + return self + + def __next__(self): + message = self._socket.recv() + return frombuffer(message) + + +def tobytes(da, encoding=None): + with TemporaryDirectory() as tmpdir: + path = os.path.join(tmpdir, "tmp.nc") + da.to_netcdf(path, virtual=False, encoding=encoding) + with open(path, "rb") as file: + return file.read() + + +def frombuffer(da): + with TemporaryDirectory() as tmpdir: + path = os.path.join(tmpdir, "tmp.nc") + with open(path, "wb") as file: + file.write(da) + return open_dataarray(path).load() diff --git a/xdas/synthetics.py b/xdas/synthetics.py index b5a4c65..aa12087 100644 --- a/xdas/synthetics.py +++ b/xdas/synthetics.py @@ -45,6 +45,9 @@ def wavelet_wavefronts( True """ + # ensure reporducibility + np.random.seed(42) + # sampling starttime = np.datetime64(starttime).astype("datetime64[ns]") span = (np.timedelta64(6, "s"), 10000.0) # (6 s, 10 km) @@ -68,7 +71,6 @@ def wavelet_wavefronts( data[:, k] += sp.gausspulse(t - ttp[k] - t0, fc) / 2 # P is twice weaker data[:, k] += sp.gausspulse(t - tts[k] - t0, fc) data /= np.max(np.abs(data), axis=0, keepdims=True) # normalize - np.random.seed(42) data += np.random.randn(*shape) / snr # add noise # strain rate like response @@ -96,6 +98,9 @@ def wavelet_wavefronts( def randn_wavefronts(): + # ensure reporducibility + np.random.seed(42) + # sampling resolution = (np.timedelta64(10, "ms"), 100.0) starttime = np.datetime64("2024-01-01T00:00:00").astype("datetime64[ns]") @@ -119,7 +124,6 @@ def randn_wavefronts(): data[:, k] += (t > (t0 + ttp[k])) * np.random.randn(shape[0]) / 2 data[:, k] += (t > (t0 + tts[k])) * np.random.randn(shape[0]) data /= np.max(np.abs(data), axis=0, keepdims=True) # normalize - np.random.seed(42) data += np.random.randn(*shape) / snr # add noise # pack data and coordinates as Database or DataCollection if chunking. @@ -137,3 +141,17 @@ def randn_wavefronts(): }, ) return da + + +def dummy(shape=(1000, 100)): + starttime = np.datetime64("2024-01-01T00:00:00") + endtime = starttime + (shape[0] - 1) * np.timedelta64(100, "ms") + time = {"tie_indices": [0, shape[0] - 1], "tie_values": [starttime, endtime]} + distance = {"tie_indices": [0, shape[1] - 1], "tie_values": [0.0, 1000.0]} + return DataArray( + data=np.random.randn(*shape), + coords={ + "time": time, + "distance": distance, + }, + )