Skip to content

Commit

Permalink
Merge pull request #1155 from UXARRAY/rajeeja/use_polars_validation
Browse files Browse the repository at this point in the history
o performance improvements for validation functions
  • Loading branch information
aaronzedwick authored Feb 5, 2025
2 parents f5f5a1b + 9af8dc0 commit ba511bc
Showing 1 changed file with 20 additions and 24 deletions.
44 changes: 20 additions & 24 deletions uxarray/grid/validation.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,25 @@
import numpy as np
import polars as pl
from warnings import warn


from uxarray.constants import ERROR_TOLERANCE, INT_DTYPE


# validation helper functions
def _check_connectivity(grid):
    """Check if all nodes are referenced by at least one element.

    If not, the mesh may have hanging nodes and may not be a valid UGRID
    mesh.

    Parameters
    ----------
    grid
        Object exposing ``face_node_connectivity`` (with a ``.values``
        array of node indices) and ``n_node``.

    Returns
    -------
    bool
        ``True`` when the number of distinct non-negative node indices in
        the connectivity equals ``grid.n_node``; ``False`` otherwise, after
        emitting a ``RuntimeWarning``.
    """
    # Convert face_node_connectivity to a Polars Series and get unique values
    nodes_in_conn = pl.Series(grid.face_node_connectivity.values.flatten()).unique()

    # Negative entries are treated as indices/fill values, not real nodes
    nodes_in_conn = nodes_in_conn.filter(nodes_in_conn >= 0)

    # Only the counts are compared, not the index values themselves
    n_referenced = len(nodes_in_conn)
    if n_referenced == grid.n_node:
        return True

    warn(
        f"Some nodes may not be referenced by any element. {n_referenced} and {grid.n_node}",
        RuntimeWarning,
    )
    return False
Expand All @@ -35,18 +28,21 @@ def _check_connectivity(grid):
def _check_duplicate_nodes(grid):
    """Check if there are duplicate nodes in the mesh.

    Two nodes count as duplicates when their ``node_lon`` and ``node_lat``
    values are exactly equal; no tolerance is applied.

    Parameters
    ----------
    grid
        Object exposing ``node_lon`` and ``node_lat``, each with a
        ``.values`` array (one entry per node).

    Returns
    -------
    bool
        ``True`` when every (lon, lat) pair occurs once; ``False`` when at
        least one pair repeats, after emitting a ``RuntimeWarning``.
    """
    # Convert grid to Polars DataFrame
    df = pl.DataFrame({"lon": grid.node_lon.values, "lat": grid.node_lat.values})

    # Keep a single row per distinct (lon, lat) pair
    unique_df = df.unique(subset=["lon", "lat"])

    # Count the surplus rows directly. An anti-join of df against unique_df
    # cannot be used here: every row of df has its key present in unique_df,
    # so that join is always empty and duplicates would never be reported.
    n_duplicates = df.shape[0] - unique_df.shape[0]

    if n_duplicates > 0:
        warn(
            f"Duplicate nodes found in the mesh. {n_duplicates} nodes are duplicates.",
            RuntimeWarning,
        )
        return False

    return True

Expand Down

0 comments on commit ba511bc

Please sign in to comment.