Skip to content

Commit e8ec6e0

Browse files
Adding function to generate trimesh geometry generation from height profile
Adding function to generate trimesh geometry generation from height profile Updated changelog Added unit test in test_geometry.py
1 parent 76681a0 commit e8ec6e0

File tree

3 files changed

+265
-2
lines changed

3 files changed

+265
-2
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
1515
- Added `thickness` parameter to `LossyMetalMedium` for computing surface impedance of a thin conductor.
1616
- `priority` field in `Structure` and `MeshOverrideStructure` for setting the behavior in structure overlapping region. When its value is `None`, the priority is automatically determined based on the material property and simulation's `structure_priority_mode`.
1717
- Automatically apply `matplotlib` styles when importing `tidy3d` which can be reverted via the `td.restore_matplotlib_rcparams()` function.
18+
- Added `TriangleMesh.from_height_expression` class method to create a mesh from an analytical height function defined on a 2D grid and `TriangleMesh.from_height_grid` class method to create a mesh from height values sampled on a 2D grid.
1819

1920
### Fixed
2021
- Fixed bug in broadband adjoint source creation when forward simulation had a pulse amplitude greater than 1 or a nonzero pulse phase.

tests/test_components/test_geometry.py

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1163,3 +1163,86 @@ def test_triangulation_with_collinear_vertices():
11631163
xr = np.linspace(0, 1, 6)
11641164
a = np.array([[x, -0.5] for x in xr] + [[x, 0.5] for x in xr[::-1]])
11651165
assert len(td.components.geometry.triangulation.triangulate(a)) == 10
1166+
1167+
1168+
def test_triangle_mesh_from_height():
1169+
"""Test the TriangleMesh.from_height_function and from_height_grid constructors."""
1170+
1171+
# Test successful creation with a valid height function
1172+
def valid_height_func(x, y):
1173+
return 0.5 + 0.2 * np.sin(4 * (x + 1)) * np.cos(3 * y)
1174+
1175+
axis = 2
1176+
direction = "+"
1177+
base = 0.0
1178+
center = [0, 0]
1179+
size = [1.5, 2]
1180+
grid_size = [20, 15]
1181+
1182+
geometry_from_func = td.TriangleMesh.from_height_function(
1183+
axis=axis,
1184+
direction=direction,
1185+
base=base,
1186+
center=center,
1187+
size=size,
1188+
grid_size=grid_size,
1189+
height_func=valid_height_func,
1190+
)
1191+
1192+
assert isinstance(geometry_from_func, td.TriangleMesh)
1193+
1194+
# Test equivalence with from_height_grid method
1195+
x = np.linspace(center[0] - 0.5 * size[0], center[0] + 0.5 * size[0], grid_size[0])
1196+
y = np.linspace(center[1] - 0.5 * size[1], center[1] + 0.5 * size[1], grid_size[1])
1197+
x_mesh, y_mesh = np.meshgrid(x, y, indexing="ij")
1198+
1199+
geometry_from_grid = td.TriangleMesh.from_height_grid(
1200+
axis=axis,
1201+
direction=direction,
1202+
base=base,
1203+
grid=(x, y),
1204+
height=valid_height_func(x_mesh, y_mesh),
1205+
)
1206+
1207+
# Check if the two TriangleMesh objects are equivalent
1208+
assert geometry_from_func == geometry_from_grid
1209+
1210+
# Test ValueError for negative height values
1211+
def negative_height_func(x, y):
1212+
return 0.5 + 0.2 * np.sin(4 * (x + 1)) * np.cos(3 * y) - 2
1213+
1214+
with pytest.raises(
1215+
ValueError,
1216+
match="All height values must be non-negative.",
1217+
):
1218+
td.TriangleMesh.from_height_function(
1219+
axis=axis,
1220+
direction=direction,
1221+
base=base,
1222+
center=center,
1223+
size=size,
1224+
grid_size=grid_size,
1225+
height_func=negative_height_func,
1226+
)
1227+
1228+
# Test ValueError for height_func returning ndarray with wrong shape
1229+
def wrong_shape_height_func(x, y):
1230+
return np.zeros((3, 3)) # Incorrect shape
1231+
1232+
expected_shape = (grid_size[0], grid_size[1])
1233+
1234+
# Test for the presence of key parts of the error message
1235+
with pytest.raises(ValueError) as excinfo:
1236+
td.TriangleMesh.from_height_function(
1237+
axis=axis,
1238+
direction=direction,
1239+
base=base,
1240+
center=center,
1241+
size=size,
1242+
grid_size=grid_size,
1243+
height_func=wrong_shape_height_func,
1244+
)
1245+
# Check that the error message contains the expected information
1246+
error_message = str(excinfo.value)
1247+
assert f"shape {expected_shape}" in error_message
1248+
assert "shape (3, 3)" in error_message

tidy3d/components/geometry/mesh.py

Lines changed: 181 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,12 @@
33
from __future__ import annotations
44

55
from abc import ABC
6-
from typing import List, Optional, Tuple, Union
6+
from typing import Callable, List, Literal, Optional, Tuple, Union
77

88
import numpy as np
99
import pydantic.v1 as pydantic
1010

11-
from ...constants import inf
11+
from ...constants import fp_eps, inf
1212
from ...exceptions import DataError, ValidationError
1313
from ...log import log
1414
from ...packaging import verify_packages_import
@@ -308,6 +308,185 @@ def _triangles_to_trimesh(
308308

309309
return trimesh.Trimesh(**trimesh.triangles.to_kwargs(triangles))
310310

311+
@classmethod
312+
def from_height_grid(
313+
cls,
314+
axis: Ax,
315+
direction: Literal["-", "+"],
316+
base: float,
317+
grid: Tuple[np.ndarray, np.ndarray],
318+
height: np.ndarray,
319+
) -> TriangleMesh:
320+
"""Construct a TriangleMesh object from grid based height information.
321+
322+
Parameters
323+
----------
324+
axis : Ax
325+
Axis of extrusion.
326+
direction : Literal["-", "+"]
327+
Direction of extrusion.
328+
base : float
329+
Coordinate of the base surface along the geometry's axis.
330+
grid : Tuple[np.ndarray, np.ndarray]
331+
Tuple of two one-dimensional arrays representing the sampling grid (XY, YZ, or ZX
332+
corresponding to values of axis)
333+
height : np.ndarray
334+
Height values sampled on the given grid. Can be 1D (raveled) or 2D (matching grid mesh).
335+
336+
Returns
337+
-------
338+
TriangleMesh
339+
The resulting TriangleMesh geometry object.
340+
"""
341+
342+
x_coords = grid[0]
343+
y_coords = grid[1]
344+
345+
nx = len(x_coords)
346+
ny = len(y_coords)
347+
nt = nx * ny
348+
349+
x_mesh, y_mesh = np.meshgrid(x_coords, y_coords, indexing="ij")
350+
351+
sign = 1
352+
if direction == "-":
353+
sign = -1
354+
355+
flat_height = np.ravel(height)
356+
if flat_height.shape[0] != nt:
357+
raise ValueError(
358+
f"Shape of flattened height array {flat_height.shape} does not match "
359+
f"the number of grid points {nt}."
360+
)
361+
362+
if np.any(flat_height < 0):
363+
raise ValueError("All height values must be non-negative.")
364+
365+
max_h = np.max(flat_height)
366+
min_h_clip = fp_eps * max_h
367+
flat_height = np.clip(flat_height, min_h_clip, inf)
368+
369+
vertices_raw_list = [
370+
[np.ravel(x_mesh), np.ravel(y_mesh), base + sign * flat_height], # Alpha surface
371+
[np.ravel(x_mesh), np.ravel(y_mesh), base * np.ones(nt)],
372+
]
373+
374+
if direction == "-":
375+
vertices_raw_list = vertices_raw_list[::-1]
376+
377+
vertices = np.hstack(vertices_raw_list).T
378+
vertices = np.roll(vertices, shift=axis - 2, axis=1)
379+
380+
q0 = (np.arange(nx - 1)[:, None] * ny + np.arange(ny - 1)[None, :]).ravel()
381+
q1 = (np.arange(1, nx)[:, None] * ny + np.arange(ny - 1)[None, :]).ravel()
382+
q2 = (np.arange(1, nx)[:, None] * ny + np.arange(1, ny)[None, :]).ravel()
383+
q3 = (np.arange(nx - 1)[:, None] * ny + np.arange(1, ny)[None, :]).ravel()
384+
385+
q0_b = nt + q0
386+
q1_b = nt + q1
387+
q2_b = nt + q2
388+
q3_b = nt + q3
389+
390+
top_quads = np.stack((q0, q1, q2, q3), axis=-1)
391+
bottom_quads = np.stack((q0_b, q3_b, q2_b, q1_b), axis=-1)
392+
393+
s1_q0 = (0 * ny + np.arange(ny - 1)).ravel()
394+
s1_q1 = (0 * ny + np.arange(1, ny)).ravel()
395+
s1_q2 = (nt + 0 * ny + np.arange(1, ny)).ravel()
396+
s1_q3 = (nt + 0 * ny + np.arange(ny - 1)).ravel()
397+
side1_quads = np.stack((s1_q0, s1_q1, s1_q2, s1_q3), axis=-1)
398+
399+
s2_q0 = ((nx - 1) * ny + np.arange(ny - 1)).ravel()
400+
s2_q1 = (nt + (nx - 1) * ny + np.arange(ny - 1)).ravel()
401+
s2_q2 = (nt + (nx - 1) * ny + np.arange(1, ny)).ravel()
402+
s2_q3 = ((nx - 1) * ny + np.arange(1, ny)).ravel()
403+
side2_quads = np.stack((s2_q0, s2_q1, s2_q2, s2_q3), axis=-1)
404+
405+
s3_q0 = (np.arange(nx - 1) * ny + 0).ravel()
406+
s3_q1 = (nt + np.arange(nx - 1) * ny + 0).ravel()
407+
s3_q2 = (nt + np.arange(1, nx) * ny + 0).ravel()
408+
s3_q3 = (np.arange(1, nx) * ny + 0).ravel()
409+
side3_quads = np.stack((s3_q0, s3_q1, s3_q2, s3_q3), axis=-1)
410+
411+
s4_q0 = (np.arange(nx - 1) * ny + ny - 1).ravel()
412+
s4_q1 = (np.arange(1, nx) * ny + ny - 1).ravel()
413+
s4_q2 = (nt + np.arange(1, nx) * ny + ny - 1).ravel()
414+
s4_q3 = (nt + np.arange(nx - 1) * ny + ny - 1).ravel()
415+
side4_quads = np.stack((s4_q0, s4_q1, s4_q2, s4_q3), axis=-1)
416+
417+
all_quads = np.vstack(
418+
(top_quads, bottom_quads, side1_quads, side2_quads, side3_quads, side4_quads)
419+
)
420+
421+
triangles_list = [
422+
np.stack((all_quads[:, 0], all_quads[:, 1], all_quads[:, 3]), axis=-1),
423+
np.stack((all_quads[:, 3], all_quads[:, 1], all_quads[:, 2]), axis=-1),
424+
]
425+
tri_faces = np.vstack(triangles_list)
426+
427+
return cls.from_vertices_faces(vertices=vertices, faces=tri_faces)
428+
429+
@classmethod
430+
def from_height_function(
431+
cls,
432+
axis: Ax,
433+
direction: Literal["-", "+"],
434+
base: float,
435+
center: Tuple[float, float],
436+
size: Tuple[float, float],
437+
grid_size: Tuple[int, int],
438+
height_func: Callable[[np.ndarray, np.ndarray], np.ndarray],
439+
) -> TriangleMesh:
440+
"""Construct a TriangleMesh object from analytical expression of height function.
441+
The height function should be vectorized to accept 2D meshgrid arrays.
442+
443+
Parameters
444+
----------
445+
axis : Ax
446+
Axis of extrusion.
447+
direction : Literal["-", "+"]
448+
Direction of extrusion.
449+
base : float
450+
Coordinate of the base rectangle along the geometry's axis.
451+
center : Tuple[float, float]
452+
Center of the base rectangle in the plane perpendicular to the extrusion axis
453+
(XY, YZ, or ZX corresponding to values of axis).
454+
size : Tuple[float, float]
455+
Size of the base rectangle in the plane perpendicular to the extrusion axis
456+
(XY, YZ, or ZX corresponding to values of axis).
457+
grid_size : Tuple[int, int]
458+
Number of grid points for discretization of the base rectangle
459+
(XY, YZ, or ZX corresponding to values of axis).
460+
height_func : Callable[[np.ndarray, np.ndarray], np.ndarray]
461+
Vectorized function to compute height values from 2D meshgrid coordinate arrays.
462+
It should take two ndarrays (x_mesh, y_mesh) and return an ndarray of heights.
463+
464+
Returns
465+
-------
466+
TriangleMesh
467+
The resulting TriangleMesh geometry object.
468+
"""
469+
x_lin = np.linspace(center[0] - 0.5 * size[0], center[0] + 0.5 * size[0], grid_size[0])
470+
y_lin = np.linspace(center[1] - 0.5 * size[1], center[1] + 0.5 * size[1], grid_size[1])
471+
472+
x_mesh, y_mesh = np.meshgrid(x_lin, y_lin, indexing="ij")
473+
474+
height_values = height_func(x_mesh, y_mesh)
475+
476+
if not (isinstance(height_values, np.ndarray) and height_values.shape == x_mesh.shape):
477+
raise ValueError(
478+
f"The 'height_func' must return a NumPy array with shape {x_mesh.shape}, "
479+
f"but got shape {getattr(height_values, 'shape', type(height_values))}."
480+
)
481+
482+
return cls.from_height_grid(
483+
axis=axis,
484+
direction=direction,
485+
base=base,
486+
grid=(x_lin, y_lin),
487+
height=height_values,
488+
)
489+
311490
@cached_property
312491
@verify_packages_import(["trimesh"])
313492
def trimesh(

0 commit comments

Comments
 (0)