Skip to content

Commit ba3d8ea

Browse files
authored
Fix LeastSquares visualize on unsorted x-values (#1084)
Closes #1069
1 parent 08c997a commit ba3d8ea

File tree

4 files changed

+26
-12
lines changed

4 files changed

+26
-12
lines changed

.github/workflows/coverage.yml

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,9 @@ name: Coverage
33
on:
44
pull_request:
55
paths:
6-
- src
7-
- extern
8-
- tests
6+
- src/**
7+
- extern/**
8+
- tests/**
99
- pyproject.toml
1010
- noxfile.py
1111
- CMakeLists.txt
@@ -16,9 +16,9 @@ on:
1616
- develop
1717
- beta/*
1818
paths:
19-
- src
20-
- extern
21-
- tests
19+
- src/**
20+
- extern/**
21+
- tests/**
2222
- pyproject.toml
2323
- noxfile.py
2424
- CMakeLists.txt

.github/workflows/test.yml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,9 @@ name: Test
33
on:
44
pull_request:
55
paths:
6-
- src
7-
- extern
8-
- tests
6+
- src/**
7+
- extern/**
8+
- tests/**
99
- pyproject.toml
1010
- CMakeLists.txt
1111
- .github/workflows/test.yml

src/iminuit/cost.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2306,17 +2306,21 @@ def visualize(
23062306

23072307
x, y, ye = self._masked.T
23082308
plt.errorbar(x, y, ye, fmt="ok")
2309+
2310+
xmin = np.min(x)
2311+
xmax = np.max(x)
23092312
if isinstance(model_points, Iterable):
23102313
xm = np.array(model_points)
23112314
ym = self.model(xm, *args)
23122315
elif model_points > 0:
2316+
# beware, x may not be sorted
23132317
if _detect_log_spacing(x):
2314-
xm = np.geomspace(x[0], x[-1], model_points)
2318+
xm = np.geomspace(xmin, xmax, model_points)
23152319
else:
2316-
xm = np.linspace(x[0], x[-1], model_points)
2320+
xm = np.linspace(xmin, xmax, model_points)
23172321
ym = self.model(xm, *args)
23182322
else:
2319-
xm, ym = _smart_sampling(lambda x: self.model(x, *args), x[0], x[-1])
2323+
xm, ym = _smart_sampling(lambda x: self.model(x, *args), xmin, xmax)
23202324
plt.plot(xm, ym)
23212325
return (x, y, ye), (xm, ym)
23222326

tests/test_cost.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1389,6 +1389,16 @@ def test_LeastSquares_visualize():
13891389
assert_equal(xm, np.linspace(1, 100))
13901390

13911391

1392+
def test_LeastSquares_visualize_unsorted():
1393+
# make sure that xm uses max and min of x, not assuming x is sorted
1394+
pytest.importorskip("matplotlib")
1395+
1396+
c = LeastSquares([1, 0, 3, 2], [2, 1, 4, 3], 0.1, line)
1397+
xm, _ = c.visualize((1, 1))[1]
1398+
assert xm[0] == 0
1399+
assert xm[-1] == 3
1400+
1401+
13921402
def test_LeastSquares_visualize_par_array():
13931403
pytest.importorskip("matplotlib")
13941404

0 commit comments

Comments
 (0)