Skip to content

Commit

Permalink
Test and Fix combine_by_coords.
Browse files Browse the repository at this point in the history
  • Loading branch information
atrabattoni committed Sep 14, 2024
1 parent 2b4688c commit 0c24614
Show file tree
Hide file tree
Showing 2 changed files with 95 additions and 1 deletion.
95 changes: 94 additions & 1 deletion tests/test_routines.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from xdas.core.coordinates import Coordinates
from xdas.core.dataarray import DataArray
from xdas.core.routines import Bag, CompatibilityError
from xdas.core.routines import Bag, CompatibilityError, combine_by_coords


class TestBag:
Expand Down Expand Up @@ -100,3 +100,96 @@ def test_bag_append_incompatible_sampling_interval(self):
bag.append(da1)
with pytest.raises(CompatibilityError):
bag.append(da2)


class TestCombineByCoords:
def test_basic(self):
# without coords
da1 = DataArray(np.random.rand(10, 5), dims=("time", "space"))
da2 = DataArray(np.random.rand(10, 5), dims=("time", "space"))
combined = combine_by_coords([da1, da2], dim="time", squeeze=True)
assert combined.shape == (20, 5)

# with coords
da1 = DataArray(
np.random.rand(10, 5),
coords={"time": np.arange(10), "space": np.arange(5)},
)
da2 = DataArray(
np.random.rand(10, 5),
coords={"time": np.arange(10, 20), "space": np.arange(5)},
)
combined = combine_by_coords([da1, da2], dim="time", squeeze=True)
assert combined.shape == (20, 5)

def test_incompatible_shape(self):
da1 = DataArray(np.random.rand(10, 5), dims=("time", "space"))
da2 = DataArray(np.random.rand(10, 6), dims=("time", "space"))
dc = combine_by_coords([da1, da2], dim="time")
assert len(dc) == 2
assert dc[0].equals(da1)
assert dc[1].equals(da2)

def test_incompatible_dims(self):
da1 = DataArray(np.random.rand(10, 5), dims=("time", "space"))
da2 = DataArray(np.random.rand(10, 5), dims=("space", "time"))
dc = combine_by_coords([da1, da2], dim="time")
assert len(dc) == 2
assert dc[0].equals(da1)
assert dc[1].equals(da2)

def test_incompatible_dtype(self):
da1 = DataArray(np.random.rand(10, 5), dims=("time", "space"))
da2 = DataArray(np.random.randint(0, 10, size=(10, 5)), dims=("time", "space"))
dc = combine_by_coords([da1, da2], dim="time")
assert len(dc) == 2
assert dc[0].equals(da1)
assert dc[1].equals(da2)

def test_incompatible_coords(self):
da1 = DataArray(
np.random.rand(10, 5),
dims=("time", "space"),
coords={"space": np.arange(5)},
)
da2 = DataArray(
np.random.rand(10, 5),
dims=("time", "space"),
coords={"space": np.arange(5) + 1},
)
dc = combine_by_coords([da1, da2], dim="time")
assert len(dc) == 2
assert dc[0].equals(da1)
assert dc[1].equals(da2)

def test_incompatible_sampling_interval(self):
da1 = DataArray(
np.random.rand(10, 5),
dims=("time", "space"),
coords={"time": np.arange(10)},
)
da2 = DataArray(
np.random.rand(10, 5),
dims=("time", "space"),
coords={"time": np.arange(10) * 2},
)
dc = combine_by_coords([da1, da2], dim="time")
assert len(dc) == 2
assert dc[0].equals(da1)
assert dc[1].equals(da2)

def test_expand_scalar_coordinate(self):
da1 = DataArray(
np.random.rand(10),
dims=("time",),
coords={"time": np.arange(10), "space": 0},
)
da2 = DataArray(
np.random.rand(10),
dims=("time",),
coords={"time": np.arange(10), "space": 1},
)
dc = combine_by_coords([da1, da2], dim="space", squeeze=True)
assert dc.shape == (2, 10)
assert dc.dims == ("space", "time")
assert dc.coords["space"].values.tolist() == [0, 1]
1 change: 1 addition & 0 deletions xdas/core/routines.py
Original file line number Diff line number Diff line change
Expand Up @@ -524,6 +524,7 @@ def combine_by_coords(
bags.append(bag)
bag = Bag(dim)
bag.append(da)
bags.append(bag)

# concatenate each bag
collection = DataCollection(
Expand Down

0 comments on commit 0c24614

Please sign in to comment.