Skip to content

Commit 504f545

Browse files
authored
Added constrained uniform sampling of angles for SE3, SO3, and Quaternions (#139)
1 parent cf115f6 commit 504f545

File tree

5 files changed

+148
-19
lines changed

5 files changed

+148
-19
lines changed

spatialmath/base/quaternions.py

Lines changed: 98 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,14 @@
1414
import math
1515
import numpy as np
1616
import spatialmath.base as smb
17+
from spatialmath.base.argcheck import getunit
1718
from spatialmath.base.types import *
19+
import scipy.interpolate as interpolate
20+
from typing import Optional
21+
from functools import lru_cache
1822

1923
_eps = np.finfo(np.float64).eps
2024

21-
2225
def qeye() -> QuaternionArray:
2326
"""
2427
Create an identity quaternion
@@ -843,29 +846,112 @@ def qslerp(
843846
return q0
844847

845848

846-
def qrand() -> UnitQuaternionArray:
849+
def _compute_cdf_sin_squared(theta: float):
847850
"""
848-
Random unit-quaternion
851+
Computes the CDF for the distribution of angular magnitude for uniformly sampled rotations.
852+
853+
:arg theta: angular magnitude
854+
:rtype: float
855+
:return: cdf of a given angular magnitude
856+
:rtype: float
857+
858+
Helper function for uniform sampling of rotations with constrained angular magnitude.
859+
This function returns the integral of the pdf of angular magnitudes (2/pi * sin^2(theta/2)).
860+
"""
861+
return (theta - np.sin(theta)) / np.pi
849862

863+
@lru_cache(maxsize=1)
864+
def _generate_inv_cdf_sin_squared_interp(num_interpolation_points: int = 256) -> interpolate.interp1d:
865+
"""
866+
Computes an interpolation function for the inverse CDF of the distribution of angular magnitude.
867+
868+
:arg num_interpolation_points: number of points to use in the interpolation function
869+
:rtype: int
870+
:return: interpolation function for the inverse cdf of a given angular magnitude
871+
:rtype: interpolate.interp1d
872+
873+
Helper function for uniform sampling of rotations with constrained angular magnitude.
874+
This function returns interpolation function for the inverse of the integral of the
875+
pdf of angular magnitudes (2/pi * sin^2(theta/2)), which is not analytically defined.
876+
"""
877+
cdf_sin_squared_interp_angles = np.linspace(0, np.pi, num_interpolation_points)
878+
cdf_sin_squared_interp_values = _compute_cdf_sin_squared(cdf_sin_squared_interp_angles)
879+
return interpolate.interp1d(cdf_sin_squared_interp_values, cdf_sin_squared_interp_angles)
880+
881+
def _compute_inv_cdf_sin_squared(x: ArrayLike, num_interpolation_points: int = 256) -> ArrayLike:
882+
"""
883+
Computes the inverse CDF of the distribution of angular magnitude.
884+
885+
:arg x: value for cdf of angular magnitudes
886+
:rtype: ArrayLike
887+
:arg num_interpolation_points: number of points to use in the interpolation function
888+
:rtype: int
889+
:return: angular magnitude associate with cdf value
890+
:rtype: ArrayLike
891+
892+
Helper function for uniform sampling of rotations with constrained angular magnitude.
893+
This function returns the angle associated with the cdf value derived form integral of
894+
the pdf of angular magnitudes (2/pi * sin^2(theta/2)), which is not analytically defined.
895+
"""
896+
inv_cdf_sin_squared_interp = _generate_inv_cdf_sin_squared_interp(num_interpolation_points)
897+
return inv_cdf_sin_squared_interp(x)
898+
899+
def qrand(theta_range:Optional[ArrayLike2] = None, unit: str = "rad", num_interpolation_points: int = 256) -> UnitQuaternionArray:
900+
"""
901+
Random unit-quaternion
902+
903+
:arg theta_range: angular magnitude range [min,max], defaults to None.
904+
:type xrange: 2-element sequence, optional
905+
:arg unit: angular units: 'rad' [default], or 'deg'
906+
:type unit: str
907+
:arg num_interpolation_points: number of points to use in the interpolation function
908+
:rtype: int
909+
:arg num_interpolation_points: number of points to use in the interpolation function
910+
:rtype: int
850911
:return: random unit-quaternion
851912
:rtype: ndarray(4)
852913
853-
Computes a uniformly distributed random unit-quaternion which can be
854-
considered equivalent to a random SO(3) rotation.
914+
Computes a uniformly distributed random unit-quaternion, with in a maximum
915+
angular magnitude, which can be considered equivalent to a random SO(3) rotation.
855916
856917
.. runblock:: pycon
857918
858919
>>> from spatialmath.base import qrand, qprint
859920
>>> qprint(qrand())
860921
"""
861-
u = np.random.uniform(low=0, high=1, size=3) # get 3 random numbers in [0,1]
862-
return np.r_[
863-
math.sqrt(1 - u[0]) * math.sin(2 * math.pi * u[1]),
864-
math.sqrt(1 - u[0]) * math.cos(2 * math.pi * u[1]),
865-
math.sqrt(u[0]) * math.sin(2 * math.pi * u[2]),
866-
math.sqrt(u[0]) * math.cos(2 * math.pi * u[2]),
867-
]
922+
if theta_range is not None:
923+
theta_range = getunit(theta_range, unit)
924+
925+
if(theta_range[0] < 0 or theta_range[1] > np.pi or theta_range[0] > theta_range[1]):
926+
ValueError('Invalid angular range. Must be within the range[0, pi].'
927+
+ f' Recieved {theta_range}.')
928+
929+
# Sample axis and angle independently, respecting the CDF of the
930+
# angular magnitude under uniform sampling.
931+
932+
# Sample angle using inverse transform sampling based on CDF
933+
# of the angular distribution (2/pi * sin^2(theta/2))
934+
theta = _compute_inv_cdf_sin_squared(
935+
np.random.uniform(
936+
low=_compute_cdf_sin_squared(theta_range[0]),
937+
high=_compute_cdf_sin_squared(theta_range[1]),
938+
),
939+
num_interpolation_points=num_interpolation_points,
940+
)
941+
# Sample axis uniformly using 3D normal distributed
942+
v = np.random.randn(3)
943+
v /= np.linalg.norm(v)
868944

945+
return np.r_[math.cos(theta / 2), (math.sin(theta / 2) * v)]
946+
else:
947+
u = np.random.uniform(low=0, high=1, size=3) # get 3 random numbers in [0,1]
948+
return np.r_[
949+
math.sqrt(1 - u[0]) * math.sin(2 * math.pi * u[1]),
950+
math.sqrt(1 - u[0]) * math.cos(2 * math.pi * u[1]),
951+
math.sqrt(u[0]) * math.sin(2 * math.pi * u[2]),
952+
math.sqrt(u[0]) * math.cos(2 * math.pi * u[2]),
953+
]
954+
869955

870956
def qmatrix(q: ArrayLike4) -> R4x4:
871957
"""

spatialmath/pose3d.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434

3535
from spatialmath.twist import Twist3
3636

37-
from typing import TYPE_CHECKING
37+
from typing import TYPE_CHECKING, Optional
3838
if TYPE_CHECKING:
3939
from spatialmath.quaternion import UnitQuaternion
4040

@@ -455,12 +455,16 @@ def Rz(cls, theta, unit: str = "rad") -> Self:
455455
return cls([smb.rotz(x, unit=unit) for x in smb.getvector(theta)], check=False)
456456

457457
@classmethod
458-
def Rand(cls, N: int = 1) -> Self:
458+
def Rand(cls, N: int = 1, *, theta_range:Optional[ArrayLike2] = None, unit: str = "rad") -> Self:
459459
"""
460460
Construct a new SO(3) from random rotation
461461
462462
:param N: number of random rotations
463463
:type N: int
464+
:param theta_range: angular magnitude range [min,max], defaults to None.
465+
:type xrange: 2-element sequence, optional
466+
:param unit: angular units: 'rad' [default], or 'deg'
467+
:type unit: str
464468
:return: SO(3) rotation matrix
465469
:rtype: SO3 instance
466470
@@ -477,7 +481,7 @@ def Rand(cls, N: int = 1) -> Self:
477481
478482
:seealso: :func:`spatialmath.quaternion.UnitQuaternion.Rand`
479483
"""
480-
return cls([smb.q2r(smb.qrand()) for _ in range(0, N)], check=False)
484+
return cls([smb.q2r(smb.qrand(theta_range=theta_range, unit=unit)) for _ in range(0, N)], check=False)
481485

482486
@overload
483487
@classmethod
@@ -1517,6 +1521,8 @@ def Rand(
15171521
xrange: Optional[ArrayLike2] = (-1, 1),
15181522
yrange: Optional[ArrayLike2] = (-1, 1),
15191523
zrange: Optional[ArrayLike2] = (-1, 1),
1524+
theta_range:Optional[ArrayLike2] = None,
1525+
unit: str = "rad",
15201526
) -> SE3: # pylint: disable=arguments-differ
15211527
"""
15221528
Create a random SE(3)
@@ -1527,6 +1533,10 @@ def Rand(
15271533
:type yrange: 2-element sequence, optional
15281534
:param zrange: z-axis range [min,max], defaults to [-1, 1]
15291535
:type zrange: 2-element sequence, optional
1536+
:param theta_range: angular magnitude range [min,max], defaults to None -> [0,pi].
1537+
:type xrange: 2-element sequence, optional
1538+
:param unit: angular units: 'rad' [default], or 'deg'
1539+
:type unit: str
15301540
:param N: number of random transforms
15311541
:type N: int
15321542
:return: SE(3) matrix
@@ -1557,7 +1567,7 @@ def Rand(
15571567
Z = np.random.uniform(
15581568
low=zrange[0], high=zrange[1], size=N
15591569
) # random values in the range
1560-
R = SO3.Rand(N=N)
1570+
R = SO3.Rand(N=N, theta_range=theta_range, unit=unit)
15611571
return cls(
15621572
[smb.transl(x, y, z) @ smb.r2t(r.A) for (x, y, z, r) in zip(X, Y, Z, R)],
15631573
check=False,

spatialmath/quaternion.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1225,12 +1225,16 @@ def Rz(cls, angles: ArrayLike, unit: Optional[str] = "rad") -> UnitQuaternion:
12251225
)
12261226

12271227
@classmethod
1228-
def Rand(cls, N: int = 1) -> UnitQuaternion:
1228+
def Rand(cls, N: int = 1, *, theta_range:Optional[ArrayLike2] = None, unit: str = "rad") -> UnitQuaternion:
12291229
"""
12301230
Construct a new random unit quaternion
12311231
12321232
:param N: number of random rotations
12331233
:type N: int
1234+
:param theta_range: angular magnitude range [min,max], defaults to None -> [0,pi].
1235+
:type xrange: 2-element sequence, optional
1236+
:param unit: angular units: 'rad' [default], or 'deg'
1237+
:type unit: str
12341238
:return: random unit-quaternion
12351239
:rtype: UnitQuaternion instance
12361240
@@ -1248,7 +1252,7 @@ def Rand(cls, N: int = 1) -> UnitQuaternion:
12481252
12491253
:seealso: :meth:`UnitQuaternion.Rand`
12501254
"""
1251-
return cls([smb.qrand() for i in range(0, N)], check=False)
1255+
return cls([smb.qrand(theta_range=theta_range, unit=unit) for i in range(0, N)], check=False)
12521256

12531257
@classmethod
12541258
def Eul(cls, *angles: List[float], unit: Optional[str] = "rad") -> UnitQuaternion:

tests/test_pose3d.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,11 +72,19 @@ def test_constructor(self):
7272
array_compare(R, np.eye(3))
7373
self.assertIsInstance(R, SO3)
7474

75+
np.random.seed(32)
7576
# random
7677
R = SO3.Rand()
7778
nt.assert_equal(len(R), 1)
7879
self.assertIsInstance(R, SO3)
7980

81+
# random constrained
82+
R = SO3.Rand(theta_range=(0.1, 0.7))
83+
self.assertIsInstance(R, SO3)
84+
self.assertEqual(R.A.shape, (3, 3))
85+
self.assertLessEqual(R.angvec()[0], 0.7)
86+
self.assertGreaterEqual(R.angvec()[0], 0.1)
87+
8088
# copy constructor
8189
R = SO3.Rx(pi / 2)
8290
R2 = SO3(R)
@@ -816,12 +824,13 @@ def test_constructor(self):
816824
array_compare(R, np.eye(4))
817825
self.assertIsInstance(R, SE3)
818826

827+
np.random.seed(65)
819828
# random
820829
R = SE3.Rand()
821830
nt.assert_equal(len(R), 1)
822831
self.assertIsInstance(R, SE3)
823832

824-
# random
833+
# random
825834
T = SE3.Rand()
826835
R = T.R
827836
t = T.t
@@ -847,6 +856,13 @@ def test_constructor(self):
847856
nt.assert_equal(TT.y, ones * t[1])
848857
nt.assert_equal(TT.z, ones * t[2])
849858

859+
# random constrained
860+
T = SE3.Rand(theta_range=(0.1, 0.7))
861+
self.assertIsInstance(T, SE3)
862+
self.assertEqual(T.A.shape, (4, 4))
863+
self.assertLessEqual(T.angvec()[0], 0.7)
864+
self.assertGreaterEqual(T.angvec()[0], 0.1)
865+
850866
# copy constructor
851867
R = SE3.Rx(pi / 2)
852868
R2 = SE3(R)

tests/test_quaternion.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,19 @@ def test_constructor_variants(self):
4848
nt.assert_array_almost_equal(
4949
UnitQuaternion.Rz(-90, "deg").vec, np.r_[1, 0, 0, -1] / math.sqrt(2)
5050
)
51+
52+
np.random.seed(73)
53+
q = UnitQuaternion.Rand(theta_range=(0.1, 0.7))
54+
self.assertIsInstance(q, UnitQuaternion)
55+
self.assertLessEqual(q.angvec()[0], 0.7)
56+
self.assertGreaterEqual(q.angvec()[0], 0.1)
57+
58+
59+
q = UnitQuaternion.Rand(theta_range=(0.1, 0.7))
60+
self.assertIsInstance(q, UnitQuaternion)
61+
self.assertLessEqual(q.angvec()[0], 0.7)
62+
self.assertGreaterEqual(q.angvec()[0], 0.1)
63+
5164

5265
def test_constructor(self):
5366
qcompare(UnitQuaternion(), [1, 0, 0, 0])

0 commit comments

Comments
 (0)