Skip to content

Commit 78cd02f

Browse files
authored
perf: reduce python overhead for awkward backend (#554)
* awkward backend: reduce python overhead by applying multiple operations in one single broadcasting traversal * add a fast path for noops, otherwise ak.transform does a non-negligible overhead...
1 parent a61d2ef commit 78cd02f

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

88 files changed

+402
-118
lines changed

src/vector/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ def _import_awkward() -> None:
7676
if not typing.TYPE_CHECKING:
7777
VectorAwkward = None
7878
else:
79-
from vector.backends.awkward import VectorAwkward
79+
from vector.backends.awkward import VectorAwkward, awkward_transform
8080

8181
try:
8282
import sympy # type: ignore[import-untyped]
@@ -143,6 +143,7 @@ def _import_awkward() -> None:
143143
"arr",
144144
"array",
145145
"awk",
146+
"awkward_transform",
146147
"dim",
147148
"obj",
148149
"register_awkward",

src/vector/_compute/lorentz/Et.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ def dispatch(v: typing.Any) -> typing.Any:
114114
with numpy.errstate(all="ignore"):
115115
return v._wrap_result(
116116
_flavor_of(v),
117-
function(
117+
v._wrap_dispatched_function(function)(
118118
v.lib,
119119
*v.azimuthal.elements,
120120
*v.longitudinal.elements,

src/vector/_compute/lorentz/Et2.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ def dispatch(v: typing.Any) -> typing.Any:
116116
with numpy.errstate(all="ignore"):
117117
return v._wrap_result(
118118
_flavor_of(v),
119-
function(
119+
v._wrap_dispatched_function(function)(
120120
v.lib,
121121
*v.azimuthal.elements,
122122
*v.longitudinal.elements,

src/vector/_compute/lorentz/Mt.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ def dispatch(v: typing.Any) -> typing.Any:
110110
with numpy.errstate(all="ignore"):
111111
return v._wrap_result(
112112
_flavor_of(v),
113-
function(
113+
v._wrap_dispatched_function(function)(
114114
v.lib,
115115
*v.azimuthal.elements,
116116
*v.longitudinal.elements,

src/vector/_compute/lorentz/Mt2.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ def dispatch(v: typing.Any) -> typing.Any:
111111
with numpy.errstate(all="ignore"):
112112
return v._wrap_result(
113113
_flavor_of(v),
114-
function(
114+
v._wrap_dispatched_function(function)(
115115
v.lib,
116116
*v.azimuthal.elements,
117117
*v.longitudinal.elements,

src/vector/_compute/lorentz/add.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -201,9 +201,10 @@ def dispatch(v1: typing.Any, v2: typing.Any) -> typing.Any:
201201
),
202202
)
203203
with numpy.errstate(all="ignore"):
204-
return _handler_of(v1, v2)._wrap_result(
204+
handler = _handler_of(v1, v2)
205+
return handler._wrap_result(
205206
_flavor_of(v1, v2),
206-
function(
207+
handler._wrap_dispatched_function(function)(
207208
_lib_of(v1, v2),
208209
*v1.azimuthal.elements,
209210
*v1.longitudinal.elements,

src/vector/_compute/lorentz/beta.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,7 @@ def dispatch(v: typing.Any) -> typing.Any:
153153
with numpy.errstate(all="ignore"):
154154
return v._wrap_result(
155155
_flavor_of(v),
156-
function(
156+
v._wrap_dispatched_function(function)(
157157
v.lib,
158158
*v.azimuthal.elements,
159159
*v.longitudinal.elements,

src/vector/_compute/lorentz/boostX_beta.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -243,7 +243,7 @@ def dispatch(beta: typing.Any, v: typing.Any) -> typing.Any:
243243
with numpy.errstate(all="ignore"):
244244
return v._wrap_result(
245245
_flavor_of(v),
246-
function(
246+
v._wrap_dispatched_function(function)(
247247
v.lib,
248248
beta,
249249
*v.azimuthal.elements,

src/vector/_compute/lorentz/boostX_gamma.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -243,7 +243,7 @@ def dispatch(gamma: typing.Any, v: typing.Any) -> typing.Any:
243243
with numpy.errstate(all="ignore"):
244244
return v._wrap_result(
245245
_flavor_of(v),
246-
function(
246+
v._wrap_dispatched_function(function)(
247247
v.lib,
248248
gamma,
249249
*v.azimuthal.elements,

src/vector/_compute/lorentz/boostY_beta.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -243,7 +243,7 @@ def dispatch(beta: typing.Any, v: typing.Any) -> typing.Any:
243243
with numpy.errstate(all="ignore"):
244244
return v._wrap_result(
245245
_flavor_of(v),
246-
function(
246+
v._wrap_dispatched_function(function)(
247247
v.lib,
248248
beta,
249249
*v.azimuthal.elements,

src/vector/_compute/lorentz/boostY_gamma.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -243,7 +243,7 @@ def dispatch(gamma: typing.Any, v: typing.Any) -> typing.Any:
243243
with numpy.errstate(all="ignore"):
244244
return v._wrap_result(
245245
_flavor_of(v),
246-
function(
246+
v._wrap_dispatched_function(function)(
247247
v.lib,
248248
gamma,
249249
*v.azimuthal.elements,

src/vector/_compute/lorentz/boostZ_beta.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -218,7 +218,7 @@ def dispatch(beta: typing.Any, v: typing.Any) -> typing.Any:
218218
with numpy.errstate(all="ignore"):
219219
return v._wrap_result(
220220
_flavor_of(v),
221-
function(
221+
v._wrap_dispatched_function(function)(
222222
v.lib,
223223
beta,
224224
*v.azimuthal.elements,

src/vector/_compute/lorentz/boostZ_gamma.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -218,7 +218,7 @@ def dispatch(gamma: typing.Any, v: typing.Any) -> typing.Any:
218218
with numpy.errstate(all="ignore"):
219219
return v._wrap_result(
220220
_flavor_of(v),
221-
function(
221+
v._wrap_dispatched_function(function)(
222222
v.lib,
223223
gamma,
224224
*v.azimuthal.elements,

src/vector/_compute/lorentz/boost_beta3.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -391,9 +391,10 @@ def dispatch(v1: typing.Any, v2: typing.Any) -> typing.Any:
391391
),
392392
)
393393
with numpy.errstate(all="ignore"):
394-
return _handler_of(v1, v2)._wrap_result(
394+
handler = _handler_of(v1, v2)
395+
return handler._wrap_result(
395396
_flavor_of(v1, v2),
396-
function(
397+
handler._wrap_dispatched_function(function)(
397398
_lib_of(v1, v2),
398399
*v1.azimuthal.elements,
399400
*v1.longitudinal.elements,

src/vector/_compute/lorentz/boost_p4.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -781,9 +781,10 @@ def dispatch(v1: typing.Any, v2: typing.Any) -> typing.Any:
781781
),
782782
)
783783
with numpy.errstate(all="ignore"):
784-
return _handler_of(v1, v2)._wrap_result(
784+
handler = _handler_of(v1, v2)
785+
return handler._wrap_result(
785786
_flavor_of(v1, v2),
786-
function(
787+
handler._wrap_dispatched_function(function)(
787788
_lib_of(v1, v2),
788789
*v1.azimuthal.elements,
789790
*v1.longitudinal.elements,

src/vector/_compute/lorentz/deltaRapidityPhi.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -110,9 +110,10 @@ def dispatch(
110110
),
111111
)
112112
with numpy.errstate(all="ignore"):
113-
return _handler_of(v1, v2)._wrap_result(
113+
handler = _handler_of(v1, v2)
114+
return handler._wrap_result(
114115
_flavor_of(v1, v2),
115-
function(
116+
handler._wrap_dispatched_function(function)(
116117
_lib_of(v1, v2),
117118
*v1.azimuthal.elements,
118119
*v1.longitudinal.elements,

src/vector/_compute/lorentz/deltaRapidityPhi2.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -106,9 +106,10 @@ def dispatch(
106106
),
107107
)
108108
with numpy.errstate(all="ignore"):
109-
return _handler_of(v1, v2)._wrap_result(
109+
handler = _handler_of(v1, v2)
110+
return handler._wrap_result(
110111
_flavor_of(v1, v2),
111-
function(
112+
handler._wrap_dispatched_function(function)(
112113
_lib_of(v1, v2),
113114
*v1.azimuthal.elements,
114115
*v1.longitudinal.elements,

src/vector/_compute/lorentz/dot.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -155,9 +155,10 @@ def dispatch(v1: typing.Any, v2: typing.Any) -> typing.Any:
155155
),
156156
)
157157
with numpy.errstate(all="ignore"):
158-
return _handler_of(v1, v2)._wrap_result(
158+
handler = _handler_of(v1, v2)
159+
return handler._wrap_result(
159160
_flavor_of(v1, v2),
160-
function(
161+
handler._wrap_dispatched_function(function)(
161162
_lib_of(v1, v2),
162163
*v1.azimuthal.elements,
163164
*v1.longitudinal.elements,

src/vector/_compute/lorentz/equal.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -168,9 +168,10 @@ def dispatch(v1: typing.Any, v2: typing.Any) -> typing.Any:
168168
),
169169
)
170170
with numpy.errstate(all="ignore"):
171-
return _handler_of(v1, v2)._wrap_result(
171+
handler = _handler_of(v1, v2)
172+
return handler._wrap_result(
172173
_flavor_of(v1, v2),
173-
function(
174+
handler._wrap_dispatched_function(function)(
174175
_lib_of(v1, v2),
175176
*v1.azimuthal.elements,
176177
*v1.longitudinal.elements,

src/vector/_compute/lorentz/gamma.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,7 @@ def dispatch(v: typing.Any) -> typing.Any:
153153
with numpy.errstate(all="ignore"):
154154
return v._wrap_result(
155155
_flavor_of(v),
156-
function(
156+
v._wrap_dispatched_function(function)(
157157
v.lib,
158158
*v.azimuthal.elements,
159159
*v.longitudinal.elements,

src/vector/_compute/lorentz/is_lightlike.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ def dispatch(tolerance: typing.Any, v: typing.Any) -> typing.Any:
6868
with numpy.errstate(all="ignore"):
6969
return v._wrap_result(
7070
_flavor_of(v),
71-
function(
71+
v._wrap_dispatched_function(function)(
7272
v.lib,
7373
tolerance,
7474
*v.azimuthal.elements,

src/vector/_compute/lorentz/is_spacelike.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ def dispatch(tolerance: typing.Any, v: typing.Any) -> typing.Any:
6666
with numpy.errstate(all="ignore"):
6767
return v._wrap_result(
6868
_flavor_of(v),
69-
function(
69+
v._wrap_dispatched_function(function)(
7070
v.lib,
7171
tolerance,
7272
*v.azimuthal.elements,

src/vector/_compute/lorentz/is_timelike.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ def dispatch(tolerance: typing.Any, v: typing.Any) -> typing.Any:
6666
with numpy.errstate(all="ignore"):
6767
return v._wrap_result(
6868
_flavor_of(v),
69-
function(
69+
v._wrap_dispatched_function(function)(
7070
v.lib,
7171
tolerance,
7272
*v.azimuthal.elements,

src/vector/_compute/lorentz/isclose.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -221,9 +221,10 @@ def dispatch(
221221
),
222222
)
223223
with numpy.errstate(all="ignore"):
224-
return _handler_of(v1, v2)._wrap_result(
224+
handler = _handler_of(v1, v2)
225+
return handler._wrap_result(
225226
_flavor_of(v1, v2),
226-
function(
227+
handler._wrap_dispatched_function(function)(
227228
_lib_of(v1, v2),
228229
rtol,
229230
atol,

src/vector/_compute/lorentz/not_equal.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -168,9 +168,10 @@ def dispatch(v1: typing.Any, v2: typing.Any) -> typing.Any:
168168
),
169169
)
170170
with numpy.errstate(all="ignore"):
171-
return _handler_of(v1, v2)._wrap_result(
171+
handler = _handler_of(v1, v2)
172+
return handler._wrap_result(
172173
_flavor_of(v1, v2),
173-
function(
174+
handler._wrap_dispatched_function(function)(
174175
_lib_of(v1, v2),
175176
*v1.azimuthal.elements,
176177
*v1.longitudinal.elements,

src/vector/_compute/lorentz/rapidity.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,7 @@ def dispatch(v: typing.Any) -> typing.Any:
127127
with numpy.errstate(all="ignore"):
128128
return v._wrap_result(
129129
_flavor_of(v),
130-
function(
130+
v._wrap_dispatched_function(function)(
131131
v.lib,
132132
*v.azimuthal.elements,
133133
*v.longitudinal.elements,

src/vector/_compute/lorentz/scale.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,7 @@ def dispatch(factor: typing.Any, v: typing.Any) -> typing.Any:
181181
with numpy.errstate(all="ignore"):
182182
return v._wrap_result(
183183
_flavor_of(v),
184-
function(
184+
v._wrap_dispatched_function(function)(
185185
v.lib,
186186
factor,
187187
*v.azimuthal.elements,

src/vector/_compute/lorentz/subtract.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -201,9 +201,10 @@ def dispatch(v1: typing.Any, v2: typing.Any) -> typing.Any:
201201
),
202202
)
203203
with numpy.errstate(all="ignore"):
204-
return _handler_of(v1, v2)._wrap_result(
204+
handler = _handler_of(v1, v2)
205+
return handler._wrap_result(
205206
_flavor_of(v1, v2),
206-
function(
207+
handler._wrap_dispatched_function(function)(
207208
_lib_of(v1, v2),
208209
*v1.azimuthal.elements,
209210
*v1.longitudinal.elements,

src/vector/_compute/lorentz/t.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,9 @@ def xy_z_t(lib, x, y, z, t):
3737
return t
3838

3939

40+
xy_z_t.__awkward_transform_allowed__ = False # type:ignore[attr-defined]
41+
42+
4043
def xy_z_tau(lib, x, y, z, tau):
4144
return lib.sqrt(t2.xy_z_tau(lib, x, y, z, tau))
4245

@@ -45,6 +48,9 @@ def xy_theta_t(lib, x, y, theta, t):
4548
return t
4649

4750

51+
xy_theta_t.__awkward_transform_allowed__ = False # type:ignore[attr-defined]
52+
53+
4854
def xy_theta_tau(lib, x, y, theta, tau):
4955
return lib.sqrt(t2.xy_theta_tau(lib, x, y, theta, tau))
5056

@@ -53,6 +59,9 @@ def xy_eta_t(lib, x, y, eta, t):
5359
return t
5460

5561

62+
xy_eta_t.__awkward_transform_allowed__ = False # type:ignore[attr-defined]
63+
64+
5665
def xy_eta_tau(lib, x, y, eta, tau):
5766
return lib.sqrt(t2.xy_eta_tau(lib, x, y, eta, tau))
5867

@@ -61,6 +70,9 @@ def rhophi_z_t(lib, rho, phi, z, t):
6170
return t
6271

6372

73+
rhophi_z_t.__awkward_transform_allowed__ = False # type:ignore[attr-defined]
74+
75+
6476
def rhophi_z_tau(lib, rho, phi, z, tau):
6577
return lib.sqrt(t2.rhophi_z_tau(lib, rho, phi, z, tau))
6678

@@ -69,6 +81,9 @@ def rhophi_theta_t(lib, rho, phi, theta, t):
6981
return t
7082

7183

84+
rhophi_theta_t.__awkward_transform_allowed__ = False # type:ignore[attr-defined]
85+
86+
7287
def rhophi_theta_tau(lib, rho, phi, theta, tau):
7388
return lib.sqrt(t2.rhophi_theta_tau(lib, rho, phi, theta, tau))
7489

@@ -77,6 +92,9 @@ def rhophi_eta_t(lib, rho, phi, eta, t):
7792
return t
7893

7994

95+
rhophi_eta_t.__awkward_transform_allowed__ = False # type:ignore[attr-defined]
96+
97+
8098
def rhophi_eta_tau(lib, rho, phi, eta, tau):
8199
return lib.sqrt(t2.rhophi_eta_tau(lib, rho, phi, eta, tau))
82100

@@ -110,7 +128,7 @@ def dispatch(v: typing.Any) -> typing.Any:
110128
with numpy.errstate(all="ignore"):
111129
return v._wrap_result(
112130
_flavor_of(v),
113-
function(
131+
v._wrap_dispatched_function(function)(
114132
v.lib,
115133
*v.azimuthal.elements,
116134
*v.longitudinal.elements,

src/vector/_compute/lorentz/t2.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ def dispatch(v: typing.Any) -> typing.Any:
125125
with numpy.errstate(all="ignore"):
126126
return v._wrap_result(
127127
_flavor_of(v),
128-
function(
128+
v._wrap_dispatched_function(function)(
129129
v.lib,
130130
*v.azimuthal.elements,
131131
*v.longitudinal.elements,

0 commit comments

Comments
 (0)