add parquet support (#192)
JoOkuma authored Feb 12, 2025
1 parent 2233c39 commit 1e8d502
Showing 5 changed files with 69 additions and 30 deletions.
2 changes: 2 additions & 0 deletions pyproject.toml
@@ -32,6 +32,7 @@ dependencies = [
"pillow >=10.0.0",
"psycopg2-binary >=2.9.6",
"psygnal >=0.9.0",
"pyarrow >=16.1.0,<20",
"pydantic >= 2",
"pydot >=2.0.0",
"qtawesome >=1.3.1",
@@ -118,6 +119,7 @@ uvicorn = ">=0.27.0.post1"
websocket = ">=0.2.1"
websockets = ">=12.0"
zarr = ">=2.15.0,<3.0.0"
pyarrow = ">=16.1.0,<20"

[tool.pixi.feature.cuda]
channels = ["conda-forge", "rapidsai"]
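
The pyarrow pin ">=16.1.0,<20" is declared twice: once in the PyPI dependency table and once for pixi. A minimal runtime sanity check that an environment satisfies the pin (a sketch, assuming the packaging library is available; not part of the commit):

from packaging.version import Version

import pyarrow

# Compare the installed version against the ">=16.1.0,<20" pin above.
v = Version(pyarrow.__version__)
assert Version("16.1.0") <= v < Version("20"), f"unsupported pyarrow {v}"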
7 changes: 6 additions & 1 deletion ultrack/core/export/exporter.py
@@ -22,6 +22,7 @@ def export_tracks_by_extension(
Supported file extensions are .xml, .csv, .zarr, .dot, and .json.
- `.xml` exports to a TrackMate compatible XML file.
- `.csv` exports to a CSV file.
- `.parquet` exports to a Parquet file.
- `.zarr` exports the tracks to dense segments in a `zarr` array format.
- `.dot` exports to a Graphviz DOT file.
- `.json` exports to a networkx JSON file.
@@ -60,6 +61,9 @@ def export_tracks_by_extension(
elif file_ext.lower() == ".zarr":
df, _ = to_tracks_layer(config)
tracks_to_zarr(config, df, filename, overwrite=True)
elif file_ext.lower() == ".parquet":
df, _ = to_tracks_layer(config)
df.to_parquet(filename)
elif file_ext.lower() == ".dot":
G = to_networkx(config)
nx.drawing.nx_pydot.write_dot(G, filename)
@@ -70,5 +74,6 @@
json.dump(json_data, f)
else:
raise ValueError(
f"Unknown file extension: {file_ext}. Supported extensions are .xml, .csv, .zarr, .dot, and .json."
f"Unknown file extension: {file_ext}. "
"Supported extensions are .xml, .csv, .zarr, .parquet, .dot, and .json."
)
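
The new branch makes Parquet export a one-liner on top of to_tracks_layer. A usage sketch (import paths assumed from the repository layout; MainConfig() stands in for a real, populated tracking config):

import pandas as pd

from ultrack import MainConfig
from ultrack.core.export.exporter import export_tracks_by_extension

config = MainConfig()  # placeholder; a config with existing tracking results is required
export_tracks_by_extension(config, "tracks.parquet")

# The exported file round-trips through pandas.
df = pd.read_parquet("tracks.parquet")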
2 changes: 1 addition & 1 deletion ultrack/napari.yaml
@@ -41,7 +41,7 @@ contributions:
readers:
- command: ultrack.get_reader
accepts_directories: false
filename_patterns: ['*.csv']
filename_patterns: ['*.csv', '*.parquet']

# writers:
# - command: ultrack.write_multiple
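
With '*.parquet' added to filename_patterns, napari routes Parquet track files to this reader. A usage sketch mirroring the viewer test below (assumes napari and the ultrack plugin are installed):

import napari

viewer = napari.Viewer()
viewer.open("tracks.parquet", plugin="ultrack")  # loads the file as a Tracks layer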
57 changes: 40 additions & 17 deletions ultrack/reader/_test/test_napari_reader.py
@@ -25,14 +25,18 @@ def tracks_df(n_nodes: int = 10) -> pd.DataFrame:
return pd.DataFrame(tracks_data, columns=["track_id", "t", "z", "y", "x"])


def test_reader(tracks_df: pd.DataFrame, tmp_path: Path):
reader = napari_get_reader("tracks.csv")
@pytest.mark.parametrize("file_ext", ["csv", "parquet"])
def test_reader(tracks_df: pd.DataFrame, tmp_path: Path, file_ext: str):
reader = napari_get_reader(f"tracks.{file_ext}")
assert reader is None

path = tmp_path / "good_tracks.csv"
path = tmp_path / f"good_tracks.{file_ext}"
tracks_df["node_id"] = np.arange(len(tracks_df)) + 1
tracks_df["labels"] = np.random.randint(2, size=len(tracks_df))
tracks_df.to_csv(path, index=False)
if file_ext == "csv":
tracks_df.to_csv(path, index=False)
else:
tracks_df.to_parquet(path)

reader = napari_get_reader(path)
assert callable(reader)
@@ -47,13 +51,17 @@ def test_reader(tracks_df: pd.DataFrame, tmp_path: Path):
assert np.allclose(data, tracks_df[["track_id", "t", "z", "y", "x"]])


def test_reader_2d(tracks_df: pd.DataFrame, tmp_path: Path):
reader = napari_get_reader("tracks.csv")
@pytest.mark.parametrize("file_ext", ["csv", "parquet"])
def test_reader_2d(tracks_df: pd.DataFrame, tmp_path: Path, file_ext: str):
reader = napari_get_reader(f"tracks.{file_ext}")
assert reader is None

path = tmp_path / "good_tracks.csv"
path = tmp_path / f"good_tracks.{file_ext}"
tracks_df = tracks_df.drop(columns=["z"])
tracks_df.to_csv(path, index=False)
if file_ext == "csv":
tracks_df.to_csv(path, index=False)
else:
tracks_df.to_parquet(path)

reader = napari_get_reader(path)
assert callable(reader)
@@ -64,7 +72,8 @@ def test_reader_2d(tracks_df: pd.DataFrame, tmp_path: Path):
assert np.allclose(data, tracks_df[["track_id", "t", "y", "x"]])


def test_reader_with_lineage(tmp_path: Path):
@pytest.mark.parametrize("file_ext", ["csv", "parquet"])
def test_reader_with_lineage(tmp_path: Path, file_ext: str):
tracks_df = pd.DataFrame(
{
"track_id": [1, 1, 2, 3],
@@ -76,8 +85,11 @@ def test_reader_with_lineage(tmp_path: Path):
}
)

path = tmp_path / "tracks.csv"
tracks_df.to_csv(path, index=False)
path = tmp_path / f"tracks.{file_ext}"
if file_ext == "csv":
tracks_df.to_csv(path, index=False)
else:
tracks_df.to_parquet(path)

reader = napari_get_reader(path)
assert callable(reader)
@@ -95,26 +107,37 @@ def test_non_existing_track():
assert reader is None


def test_wrong_columns_track(tracks_df: pd.DataFrame, tmp_path: Path):
reader = napari_get_reader("tracks.csv")
@pytest.mark.parametrize("file_ext", ["csv", "parquet"])
def test_wrong_columns_track(tracks_df: pd.DataFrame, tmp_path: Path, file_ext: str):
reader = napari_get_reader(f"tracks.{file_ext}")
assert reader is None

path = tmp_path / "bad_tracks.csv"
path = tmp_path / f"bad_tracks.{file_ext}"
tracks_df = tracks_df.rename(columns={"track_id": "id"})
tracks_df.to_csv(path, index=False)
if file_ext == "csv":
tracks_df.to_csv(path, index=False)
else:
tracks_df.to_parquet(path)

reader = napari_get_reader(path)
assert reader is None


@pytest.mark.parametrize("file_ext", ["csv", "parquet"])
def test_napari_viewer_open_tracks(
make_napari_viewer: Callable[[], ViewerModel],
tracks_df: pd.DataFrame,
tmp_path: Path,
file_ext: str,
) -> None:

_initialize_plugins()

tracks_df.to_csv(tmp_path / "tracks.csv", index=False)
path = tmp_path / f"tracks.{file_ext}"
if file_ext == "csv":
tracks_df.to_csv(path, index=False)
else:
tracks_df.to_parquet(path)

viewer = make_napari_viewer()
viewer.open(tmp_path / "tracks.csv", plugin="ultrack")
viewer.open(path, plugin="ultrack")
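
The csv-vs-parquet write branch now repeats in every parametrized test; a small helper (hypothetical, not part of this commit) could centralize the dispatch:

from pathlib import Path

import pandas as pd

def write_tracks(df: pd.DataFrame, path: Path) -> None:
    # Dispatch on the suffix so each test writes its fixture
    # without repeating the if/else branch.
    if path.suffix == ".csv":
        df.to_csv(path, index=False)
    else:
        df.to_parquet(path)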
31 changes: 20 additions & 11 deletions ultrack/reader/napari_reader.py
@@ -3,6 +3,7 @@
from typing import Callable, List, Union

import pandas as pd
import pyarrow.parquet as pq
from napari.types import LayerDataTuple

from ultrack.tracks.graph import inv_tracks_df_forest
@@ -46,15 +47,21 @@ def napari_get_reader(

LOG.info(f"Reading tracks from {path}")

if not path.name.endswith(".csv"):
LOG.info(f"{path} must end with `.csv`.")
file_name = path.name.lower()

if not file_name.endswith(".csv") and not file_name.endswith(".parquet"):
LOG.info(f"{path} must end with `.csv` or `.parquet`.")
return None

if not path.exists():
LOG.info(f"{path} does not exist.")
return None

header = pd.read_csv(path, nrows=0).columns.tolist()
if file_name.endswith(".csv"):
header = pd.read_csv(path, nrows=0).columns.tolist()
else:
header = pq.read_table(path).schema.names

LOG.info(f"Tracks file header: {header}")

for colname in TRACKS_HEADER:
@@ -68,14 +75,14 @@ def napari_get_reader(
return reader_function


def read_csv(path: Union[Path, str]) -> LayerDataTuple:
def read_dataframe(path: Union[Path, str]) -> LayerDataTuple:
"""
Read track data from a CSV file.
Read track data from a CSV or Parquet file.
Parameters
----------
path : Union[Path, str]
Path to the CSV file.
Path to the CSV or Parquet file.
Returns
-------
@@ -90,10 +97,12 @@ def read_csv(path: Union[Path, str]) -> LayerDataTuple:
If the CSV file contains a 'parent_track_id' column, a track lineage graph
is constructed.
"""
if isinstance(path, str):
path = Path(path)

df = pd.read_csv(path)
path = Path(path)
file_name = path.name.lower()
if file_name.endswith(".csv"):
df = pd.read_csv(path)
elif file_name.endswith(".parquet"):
df = pd.read_parquet(path)

LOG.info(f"Read {len(df)} tracks from {path}")
LOG.info(df.head())
@@ -132,4 +141,4 @@ def reader_function(path: Union[List[str], str]) -> List:
List of track data tuples.
"""
paths = [path] if isinstance(path, (str, Path)) else path
return [read_csv(p) for p in paths]
return [read_dataframe(p) for p in paths]
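
One possible refinement (not in this commit): pq.read_table(path).schema.names materializes the whole file just to list its columns. Parquet stores the schema in the file footer, so a metadata-only probe is cheaper for large track tables:

import pyarrow.parquet as pq

# Reads only the footer metadata; no data pages are loaded.
header = pq.read_schema(path).names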
