Skip to content

Commit 1497fc2

Browse files
committed
Test overlap
1 parent 57b6e9c commit 1497fc2

File tree

2 files changed

+27
-47
lines changed

2 files changed

+27
-47
lines changed

src/phyjax2d/impl.py

+12
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,15 @@
1717
T = TypeVar("T")
1818
TWO_PI = jnp.pi * 2
1919

20+
_ALL_POLYGON_KEYS = [
21+
"triangle",
22+
"static_triangle",
23+
"quadrangle",
24+
"static_quadrangle",
25+
"pentagon",
26+
"static_pentagon",
27+
]
28+
2029

2130
def then(x: Any, f: Callable[[Any], Any]) -> Any:
2231
if x is None:
@@ -511,6 +520,9 @@ def zeros(n: int) -> Self:
511520
label=jnp.zeros(n, dtype=jnp.uint8),
512521
)
513522

523+
def is_empty(self) -> bool:
524+
return self.p.batch_size() == 0
525+
514526
def apply_force_global(self, point: jax.Array, force: jax.Array) -> Self:
515527
chex.assert_equal_shape((self.f.xy, force))
516528
xy = self.f.xy + force

src/phyjax2d/utils.py

+15-47
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import numpy as np
1414

1515
from phyjax2d.impl import (
16+
_ALL_POLYGON_KEYS,
1617
Capsule,
1718
Circle,
1819
Polygon,
@@ -529,7 +530,7 @@ def _circle_polygon_overlap(
529530
# Suppose that pstate.p.xy.shape == (N, 2) and xy.shape == (2,)
530531
cxy = pstate.p.inv_transform(jnp.expand_dims(xy, axis=0))
531532
p2cxy = jnp.expand_dims(cxy, axis=1) - polygon.points
532-
separation = _vmap_dot(polygon.normals, (p2cxy - polygon.points)) # (N, NP)
533+
separation = _vmap_dot(polygon.normals, p2cxy) # (N, NP)
533534
max_sep = jnp.max(separation, axis=1)
534535
i1 = jnp.argmax(separation, axis=1)
535536
i2 = (i1 + 1) % n_vertices
@@ -553,7 +554,7 @@ def circle_overlap(
553554
) -> jax.Array:
554555
# Circle overlap
555556
overlap = jnp.array(False)
556-
if stated.circle is not None and shaped.circle is not None:
557+
if not stated.circle.is_empty():
557558
cpos = stated.circle.p.xy
558559
# Suppose that cpos.shape == (N, 2) and xy.shape == (2,)
559560
dist = jnp.linalg.norm(cpos - jnp.expand_dims(xy, axis=0), axis=-1)
@@ -562,7 +563,7 @@ def circle_overlap(
562563
overlap = jnp.any(has_overlap)
563564

564565
# Static_circle overlap
565-
if stated.static_circle is not None and shaped.static_circle is not None:
566+
if not stated.static_circle.is_empty():
566567
cpos = stated.static_circle.p.xy
567568
# Suppose that cpos.shape == (N, 2) and xy.shape == (2,)
568569
dist = jnp.linalg.norm(cpos - jnp.expand_dims(xy, axis=0), axis=-1)
@@ -571,7 +572,7 @@ def circle_overlap(
571572
overlap = jnp.logical_or(jnp.any(has_overlap), overlap)
572573

573574
# Circle-segment overlap
574-
if stated.segment is not None and shaped.segment is not None:
575+
if not stated.segment.is_empty():
575576
spos = stated.segment.p
576577
# Suppose that spos.shape == (N, 2) and xy.shape == (2,)
577578
pb = spos.inv_transform(jnp.expand_dims(xy, axis=0))
@@ -587,49 +588,16 @@ def circle_overlap(
587588
has_overlap = jnp.logical_and(stated.segment.is_active, penetration >= 0)
588589
overlap = jnp.logical_or(jnp.any(has_overlap), overlap)
589590

590-
# Circle-segment overlap
591-
if stated.segment is not None and shaped.segment is not None:
592-
spos = stated.segment.p
593-
# Suppose that cpos.shape == (N, 2) and xy.shape == (2,)
594-
pb = spos.inv_transform(jnp.expand_dims(xy, axis=0))
595-
p1, p2 = shaped.segment.point1, shaped.segment.point2
596-
edge = p2 - p1
597-
s1 = jnp.expand_dims(_vmap_dot(pb - p1, edge), axis=1)
598-
s2 = jnp.expand_dims(_vmap_dot(p2 - pb, edge), axis=1)
599-
in_segment = jnp.logical_and(s1 >= 0.0, s2 >= 0.0)
600-
ee = jnp.sum(jnp.square(edge), axis=-1, keepdims=True)
601-
pa = jnp.where(in_segment, p1 + edge * s1 / ee, jnp.where(s1 < 0.0, p1, p2))
602-
dist = jnp.linalg.norm(pb - pa, axis=-1)
603-
penetration = radius - dist
604-
has_overlap = jnp.logical_and(stated.segment.is_active, penetration >= 0)
605-
overlap = jnp.logical_or(jnp.any(has_overlap), overlap)
606-
607591
# Circle-polygon overlap
608-
if stated.triangle is not None and shaped.triangle is not None:
609-
has_overlap = _circle_polygon_overlap(
610-
shaped.triangle,
611-
stated.triangle,
612-
xy,
613-
radius,
614-
)
615-
overlap = jnp.logical_or(jnp.any(has_overlap), overlap)
616-
617-
if stated.quadrangle is not None and shaped.quadrangle is not None:
618-
has_overlap = _circle_polygon_overlap(
619-
shaped.quadrangle,
620-
stated.quadrangle,
621-
xy,
622-
radius,
623-
)
624-
overlap = jnp.logical_or(jnp.any(has_overlap), overlap)
625-
626-
if stated.pentagon is not None and shaped.pentagon is not None:
627-
has_overlap = _circle_polygon_overlap(
628-
shaped.pentagon,
629-
stated.pentagon,
630-
xy,
631-
radius,
632-
)
633-
overlap = jnp.logical_or(jnp.any(has_overlap), overlap)
592+
print("before", overlap)
593+
for key in _ALL_POLYGON_KEYS:
594+
if not stated[key].is_empty(): # type: ignore
595+
has_overlap = _circle_polygon_overlap(
596+
shaped[key], # type: ignore
597+
stated[key], # type: ignore
598+
xy,
599+
radius,
600+
)
601+
overlap = jnp.logical_or(jnp.any(has_overlap), overlap)
634602

635603
return overlap

0 commit comments

Comments
 (0)