Skip to content

Commit eb2b414

Browse files
Saransh-cppagoose77jpivarski
authored
feat: migrate to awkward v2 (and keep supporting v1) (#284)
* chore: test on awkward v2.0.0rc1 * `RecordArrayType.recordlookup` -> `RecordArrayType.fields` for awkward v2 Co-authored-by: Angus Hollands <goosey15@gmail.com> * Try splitting up CI jobs * Fix doctests * Better job names Co-authored-by: Angus Hollands <goosey15@gmail.com> * Even better job names * Don't depend on pre-commit + --xdoctest * Refactor looking up fields Co-authored-by: Angus Hollands <goosey15@gmail.com> * Use importlib_metadata only * Remove python version restriction * Better error handling Co-authored-by: Angus Hollands <goosey15@gmail.com> * Revert importlib.metdata changes * Explicitly specify awkward v2 Co-authored-by: Angus Hollands <goosey15@gmail.com> * Typo Co-authored-by: Angus Hollands <goosey15@gmail.com> * Typo Co-authored-by: Angus Hollands <goosey15@gmail.com> * Revert mypy config * Try quotes * Private function + uproot-awkward sync * Changelog entry * Try Python 3.8.15. * I've confirmed that uncompyle6 loads for Python 3.8.13. * Trigger the tests. Co-authored-by: Angus Hollands <goosey15@gmail.com> Co-authored-by: Jim Pivarski <jpivarski@gmail.com>
1 parent 1aef1b7 commit eb2b414

File tree

13 files changed

+117
-73
lines changed

13 files changed

+117
-73
lines changed

.github/CONTRIBUTING.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,12 @@ The doctests can be executed using the `test` dependencies of `vector` in the fo
122122
xdoctest ./src/vector/
123123
```
124124

125+
or, one can run the doctests along with the unit tests in the following way -
126+
127+
```bash
128+
python -m pytest --xdoctest .
129+
```
130+
125131
A much more detailed guide on testing with `pytest` for `Scikit-HEP` packages is available [here](https://scikit-hep.org/developer/pytest).
126132

127133
### Running notebook tests

.github/PULL_REQUEST_TEMPLATE.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,12 @@ _Please describe the purpose of this pull request. Reference and link to any rel
77
## Checklist
88

99
- [ ] Have you followed the guidelines in our Contributing document?
10-
- [ ] Have you checked to ensure there aren't any other open Pull Request for the required change ?
10+
- [ ] Have you checked to ensure there aren't any other open Pull Requests for the required change?
1111
- [ ] Does your submission pass pre-commit? (`$ pre-commit run --all-files` or `$ nox -s lint`)
1212
- [ ] Does your submission pass tests? (`$ pytest` or `$ nox -s tests`)
1313
- [ ] Does the documentation build with your changes? (`$ cd docs; make clean; make html` or `$ nox -s docs`)
14-
- [ ] Does your submission pass doctests? (`$ xdoctest ./src/vector` or `$ nox -s doctests`)
14+
- [ ] Does your submission pass the doctests? (`$ xdoctest ./src/vector` or `$ nox -s doctests`)
1515

1616
## Before Merging
1717

18-
- [ ] Summarize commit messages into a brief review of the Pull request.
18+
- [ ] Summarize the commit messages into a brief review of the Pull request.

.github/workflows/ci.yml

Lines changed: 62 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ jobs:
2929
echo "::add-matcher::$GITHUB_WORKSPACE/.github/matchers/pylint.json"
3030
pipx run nox -s pylint
3131
32-
checks:
32+
check-light:
3333
runs-on: ubuntu-latest
3434
strategy:
3535
fail-fast: false
@@ -40,7 +40,7 @@ jobs:
4040
- "3.9"
4141
- "3.10"
4242
- "3.11"
43-
name: Check Python ${{ matrix.python-version }}
43+
name: Python ${{ matrix.python-version }} - Light
4444
steps:
4545
- uses: actions/checkout@v3
4646

@@ -57,20 +57,69 @@ jobs:
5757
- name: Test light package
5858
run: python -m pytest -ra --ignore tests/test_notebooks.py
5959

60-
- name: Install develop extras
60+
check-awkward-v1:
61+
needs: [check-light]
62+
runs-on: ubuntu-latest
63+
strategy:
64+
fail-fast: false
65+
matrix:
66+
python-version:
67+
- "3.7"
68+
- "3.8"
69+
- "3.9"
70+
- "3.10"
71+
- "3.11"
72+
name: Python ${{ matrix.python-version }} - Awkward v1
73+
steps:
74+
- uses: actions/checkout@v3
75+
76+
- uses: actions/setup-python@v4
77+
with:
78+
python-version: ${{ matrix.python-version }}
79+
80+
- name: Requirements check
81+
run: python -m pip list
82+
83+
- name: Install package
6184
run: python -m pip install -e .[dev]
6285

86+
- name: Install awkward v1
87+
run: python -m pip install -U "awkward<2"
88+
6389
- name: Test package with awkward v1.x
64-
run: python -m pytest -ra --cov=vector --ignore tests/test_notebooks.py tests/
90+
run: python -m pytest -ra --cov=vector --xdoctest --ignore tests/test_notebooks.py .
6591

66-
- name: Use awkward v1.10.x for v2 support
67-
run: python -m pip install -U --pre "awkward<2"
92+
check-awkward-v2:
93+
needs: [check-light]
94+
runs-on: ubuntu-latest
95+
strategy:
96+
fail-fast: false
97+
matrix:
98+
python-version:
99+
- "3.7"
100+
- "3.8"
101+
- "3.9"
102+
- "3.10"
103+
- "3.11"
104+
name: Python ${{ matrix.python-version }} - Awkward v2
105+
steps:
106+
- uses: actions/checkout@v3
107+
108+
- uses: actions/setup-python@v4
109+
with:
110+
python-version: ${{ matrix.python-version }}
111+
112+
- name: Requirements check
113+
run: python -m pip list
114+
115+
- name: Install package
116+
run: python -m pip install -e .[dev]
68117

69-
- name: Test package with awkward._v2
70-
run: VECTOR_USE_AWKWARDV2=1 python -m pytest -ra --cov=vector --ignore tests/test_notebooks.py tests/
118+
- name: Install awkward v2
119+
run: python -m pip install -U --pre "awkward>=2.0.0rc1"
71120

72-
- name: Run doctests
73-
run: xdoctest ./src/vector/
121+
- name: Test package with awkward v2.x
122+
run: python -m pytest -ra --cov=vector --xdoctest --ignore tests/test_notebooks.py .
74123

75124
- name: Upload coverage report
76125
uses: codecov/codecov-action@v3.1.1
@@ -83,7 +132,7 @@ jobs:
83132

84133
- uses: actions/setup-python@v4
85134
with:
86-
python-version: 3.8.8
135+
python-version: 3.8.13
87136

88137
- name: Requirements check
89138
run: python -m pip list
@@ -95,7 +144,8 @@ jobs:
95144
run: python -m pytest -ra -m dis --ignore tests/test_notebooks.py
96145

97146
pass:
98-
needs: [pre-commit, checks, discheck]
147+
needs:
148+
[pre-commit, check-light, check-awkward-v1, check-awkward-v2, discheck]
99149
runs-on: ubuntu-latest
100150
steps:
101151
- run: echo "All jobs passed"

docs/changelog.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
- Fix missing backslash in latex for readme [#285][]
2424
- chore: ignore flake8 B905 + improve bug report template [#297][]
2525
- chore: better and long term fix for flake8-bugbear [#298][]
26-
- feat: migrate to awkward v2 [#284][]
26+
- feat: migrate to awkward v2 (and keep supporting v1) [#284][]
2727

2828
[#256]: https://github.com/scikit-hep/vector/pull/256
2929
[#254]: https://github.com/scikit-hep/vector/pull/254

environment.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ dependencies:
1010
- numpy >=1.13.3
1111
- root >=6.18.04
1212
- pip:
13-
- "awkward>=1.2.0"
14-
- "uproot==4.*"
13+
- "awkward>=2.0.0"
14+
- "uproot>=5.0.0"
1515
- "scikit-hep-testdata>=0.2.0"
1616
- -e .

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,6 @@ module = [
201201
"numba.*",
202202
"awkward.*",
203203
]
204+
ignore_missing_imports = true
204205
disallow_untyped_defs = false
205206
disallow_untyped_calls = false
206-
ignore_missing_imports = true

src/vector/__init__.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,13 @@ def _import_awkward() -> None:
7373
raise ImportError(msg)
7474

7575

76+
_is_awkward_v2: bool | None
77+
try:
78+
_is_awkward_v2 = packaging.version.Version(
79+
importlib_metadata.version("awkward")
80+
) >= packaging.version.Version("2.0.0rc1")
81+
except importlib_metadata.PackageNotFoundError:
82+
_is_awkward_v2 = None
7683
try:
7784
import awkward
7885

src/vector/backends/awkward.py

Lines changed: 34 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1598,6 +1598,16 @@ class MomentumRecord4D(MomentumAwkward4D, ak.Record): # type: ignore[misc]
15981598

15991599
# implementation of behaviors in Numba ########################################
16001600

1601+
if vector._is_awkward_v2:
1602+
1603+
def _lookup_field(record_type: typing.Any, name: str) -> int:
1604+
return record_type.fields.index(name)
1605+
1606+
else:
1607+
1608+
def _lookup_field(record_type: typing.Any, name: str) -> int:
1609+
return record_type.recordlookup.index(name)
1610+
16011611

16021612
def _arraytype_of(awkwardtype: typing.Any, component: str) -> typing.Any:
16031613
import numba
@@ -1633,36 +1643,36 @@ def _aztype_of(recordarraytype: typing.Any, is_momentum: bool) -> typing.Any:
16331643

16341644
if is_momentum:
16351645
try:
1636-
x_index = recordarraytype.recordlookup.index("px")
1646+
x_index = _lookup_field(recordarraytype, "px")
16371647
except ValueError:
16381648
x_index = None
16391649
if x_index is None:
16401650
try:
1641-
x_index = recordarraytype.recordlookup.index("x")
1651+
x_index = _lookup_field(recordarraytype, "x")
16421652
except ValueError:
16431653
x_index = None
16441654
if is_momentum:
16451655
try:
1646-
y_index = recordarraytype.recordlookup.index("py")
1656+
y_index = _lookup_field(recordarraytype, "py")
16471657
except ValueError:
16481658
y_index = None
16491659
if y_index is None:
16501660
try:
1651-
y_index = recordarraytype.recordlookup.index("y")
1661+
y_index = _lookup_field(recordarraytype, "y")
16521662
except ValueError:
16531663
y_index = None
16541664
if is_momentum:
16551665
try:
1656-
rho_index = recordarraytype.recordlookup.index("pt")
1666+
rho_index = _lookup_field(recordarraytype, "pt")
16571667
except ValueError:
16581668
rho_index = None
16591669
if rho_index is None:
16601670
try:
1661-
rho_index = recordarraytype.recordlookup.index("rho")
1671+
rho_index = _lookup_field(recordarraytype, "rho")
16621672
except ValueError:
16631673
rho_index = None
16641674
try:
1665-
phi_index = recordarraytype.recordlookup.index("phi")
1675+
phi_index = _lookup_field(recordarraytype, "phi")
16661676
except ValueError:
16671677
phi_index = None
16681678

@@ -1702,20 +1712,20 @@ def _ltype_of(recordarraytype: typing.Any, is_momentum: bool) -> typing.Any:
17021712

17031713
if is_momentum:
17041714
try:
1705-
z_index = recordarraytype.recordlookup.index("pz")
1715+
z_index = _lookup_field(recordarraytype, "pz")
17061716
except ValueError:
17071717
z_index = None
17081718
if z_index is None:
17091719
try:
1710-
z_index = recordarraytype.recordlookup.index("z")
1720+
z_index = _lookup_field(recordarraytype, "z")
17111721
except ValueError:
17121722
z_index = None
17131723
try:
1714-
theta_index = recordarraytype.recordlookup.index("theta")
1724+
theta_index = _lookup_field(recordarraytype, "theta")
17151725
except ValueError:
17161726
theta_index = None
17171727
try:
1718-
eta_index = recordarraytype.recordlookup.index("eta")
1728+
eta_index = _lookup_field(recordarraytype, "eta")
17191729
except ValueError:
17201730
eta_index = None
17211731

@@ -1754,42 +1764,42 @@ def _ttype_of(recordarraytype: typing.Any, is_momentum: bool) -> typing.Any:
17541764

17551765
if is_momentum:
17561766
try:
1757-
t_index = recordarraytype.recordlookup.index("E")
1767+
t_index = _lookup_field(recordarraytype, "E")
17581768
except ValueError:
17591769
t_index = None
17601770
if is_momentum and t_index is None:
17611771
try:
1762-
t_index = recordarraytype.recordlookup.index("e")
1772+
t_index = _lookup_field(recordarraytype, "e")
17631773
except ValueError:
17641774
t_index = None
17651775
if is_momentum and t_index is None:
17661776
try:
1767-
t_index = recordarraytype.recordlookup.index("energy")
1777+
t_index = _lookup_field(recordarraytype, "energy")
17681778
except ValueError:
17691779
t_index = None
17701780
if t_index is None:
17711781
try:
1772-
t_index = recordarraytype.recordlookup.index("t")
1782+
t_index = _lookup_field(recordarraytype, "t")
17731783
except ValueError:
17741784
t_index = None
17751785
if is_momentum:
17761786
try:
1777-
tau_index = recordarraytype.recordlookup.index("M")
1787+
tau_index = _lookup_field(recordarraytype, "M")
17781788
except ValueError:
17791789
tau_index = None
17801790
if is_momentum and tau_index is None:
17811791
try:
1782-
tau_index = recordarraytype.recordlookup.index("m")
1792+
tau_index = _lookup_field(recordarraytype, "m")
17831793
except ValueError:
17841794
tau_index = None
17851795
if is_momentum and tau_index is None:
17861796
try:
1787-
tau_index = recordarraytype.recordlookup.index("mass")
1797+
tau_index = _lookup_field(recordarraytype, "mass")
17881798
except ValueError:
17891799
tau_index = None
17901800
if tau_index is None:
17911801
try:
1792-
tau_index = recordarraytype.recordlookup.index("tau")
1802+
tau_index = _lookup_field(recordarraytype, "tau")
17931803
except ValueError:
17941804
tau_index = None
17951805

@@ -1895,7 +1905,11 @@ def _numba_lower(
18951905

18961906
vectorcls = sig.return_type.instance_class
18971907

1898-
fields = sig.args[0].arrayviewtype.type.recordlookup
1908+
fields = (
1909+
sig.args[0].arrayviewtype.type.fields
1910+
if vector._is_awkward_v2
1911+
else sig.args[0].arrayviewtype.type.recordlookup
1912+
)
18991913

19001914
if issubclass(vectorcls, (VectorObject2D, VectorObject3D, VectorObject4D)):
19011915
if issubclass(sig.return_type.azimuthaltype.instance_class, AzimuthalXY):

tests/backends/test_awkward.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,6 @@
55

66
from __future__ import annotations
77

8-
import os
9-
108
import pytest
119

1210
import vector
@@ -81,12 +79,6 @@ def test_rotateZ():
8179
assert out.wow.tolist() == [[99], [], [123]]
8280

8381

84-
# awkward._v2 has not yet registered NumPy dispatch mechanisms
85-
# see https://github.com/scikit-hep/awkward/issues/1638
86-
# TODO: ensure this passes once awkward v2 is out
87-
@pytest.mark.xfail(
88-
strict=True if os.environ.get("VECTOR_USE_AWKWARDV2") is not None else False
89-
)
9082
def test_projection():
9183
array = vector.Array(
9284
[

tests/backends/test_awkward_numba.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,6 @@
55

66
from __future__ import annotations
77

8-
import os
9-
108
import pytest
119

1210
import vector
@@ -19,12 +17,6 @@
1917
pytestmark = [pytest.mark.numba, pytest.mark.awkward]
2018

2119

22-
# awkward._v2 has not yet registered Numba dispatch mechanisms
23-
# see https://github.com/scikit-hep/awkward/discussions/1639
24-
# TODO: ensure this passes once awkward v2 is out
25-
@pytest.mark.xfail(
26-
strict=True if os.environ.get("VECTOR_USE_AWKWARDV2") is not None else False
27-
)
2820
def test():
2921
@numba.njit
3022
def extract(x):

tests/conftest.py

Lines changed: 0 additions & 10 deletions
This file was deleted.

tests/test_compute_features.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,6 @@
4545
spark_parser = pytest.importorskip("spark_parser")
4646
pytestmark = pytest.mark.dis
4747

48-
4948
Context = collections.namedtuple("Context", ["name", "closure"])
5049

5150

0 commit comments

Comments
 (0)