@@ -2,9 +2,68 @@ import CvxLean
2
2
3
3
noncomputable section
4
4
5
- namespace FittingSphere
5
+ open CvxLean Minimization Real BigOperators Matrix Finset
6
+
7
+ section LeastSquares
8
+
9
+ def leastSquares {n : ℕ} (a : Fin n → ℝ) :=
10
+ optimization (x : ℝ)
11
+ minimize (∑ i, ((a i - x) ^ 2 ) : ℝ)
12
+
13
+ @[reducible]
14
+ def mean {n : ℕ} (a : Fin n → ℝ) : ℝ := (1 / n) * ∑ i, (a i)
15
+
16
+ /-- It is useful to rewrite the sum of squares in the following way to prove
17
+ `leastSquares_optimal_eq_mean`, following Marty Cohen's answer in
18
+ https://math.stackexchange.com/questions/2554243. -/
19
+ lemma leastSquares_alt_objFun {n : ℕ} (hn : 0 < n) (a : Fin n → ℝ) (x : ℝ) :
20
+ (∑ i, ((a i - x) ^ 2 )) = n * ((x - mean a) ^ 2 + (mean (a ^ 2 ) - (mean a) ^ 2 )) := by
21
+ calc
22
+ -- 1) Σ (aᵢ - x)² = Σ (aᵢ² - 2aᵢx + x²)
23
+ _ = ∑ i, ((a i) ^ 2 - 2 * (a i) * x + (x ^ 2 )) := by
24
+ congr; funext i; simp; ring
25
+ -- 2) ... = Σ aᵢ² - 2xΣ aᵢ + nx²
26
+ _ = ∑ i, ((a i) ^ 2 ) - 2 * x * ∑ i, (a i) + n * (x ^ 2 ) := by
27
+ rw [sum_add_distrib, sum_sub_distrib, ← sum_mul, ← mul_sum]; simp [sum_const]; ring
28
+ -- 3) ... = n{a²} - 2xn{a} + nx²
29
+ _ = n * mean (a ^ 2 ) - 2 * x * n * mean a + n * (x ^ 2 ) := by
30
+ simp [mean]; field_simp; ring
31
+ -- 4) ... = n((x - {a})² + ({a²} - {a}²))
32
+ _ = n * ((x - mean a) ^ 2 + (mean (a ^ 2 ) - (mean a) ^ 2 )) := by
33
+ simp [mean]; field_simp; ring
34
+
35
+ /-- Key result about least squares: `x* = mean a`. -/
36
+ lemma leastSquares_optimal_eq_mean {n : ℕ} (hn : 0 < n) (a : Fin n → ℝ) (x : ℝ)
37
+ (h : (leastSquares a).optimal x) : x = mean a := by
38
+ simp [optimal, feasible, leastSquares] at h
39
+ replace h : ∀ y, (x - mean a) ^ 2 ≤ (y - mean a) ^ 2 := by
40
+ intros y
41
+ have hy := h y
42
+ have h_rw_x := leastSquares_alt_objFun hn a x
43
+ have h_rw_y := leastSquares_alt_objFun hn a y
44
+ simp only [rpow_two] at h_rw_x h_rw_y ⊢
45
+ rwa [h_rw_x, h_rw_y, mul_le_mul_left (by positivity), add_le_add_iff_right] at hy
46
+ have hmean := h (mean a)
47
+ simp at hmean
48
+ have hz := le_antisymm hmean (sq_nonneg _)
49
+ rwa [sq_eq_zero_iff, sub_eq_zero] at hz
50
+
51
+ def Vec.leastSquares {n : ℕ} (a : Fin n → ℝ) :=
52
+ optimization (x : ℝ)
53
+ minimize (Vec.sum ((a - Vec.const n x) ^ 2 ) : ℝ)
54
+
55
+ /-- Same as `leastSquares_optimal_eq_mean` in vector notation. -/
56
+ lemma vec_leastSquares_optimal_eq_mean {n : ℕ} (hn : 0 < n) (a : Fin n → ℝ) (x : ℝ)
57
+ (h : (Vec.leastSquares a).optimal x) : x = mean a := by
58
+ apply leastSquares_optimal_eq_mean hn a
59
+ simp [Vec.leastSquares, leastSquares, optimal, feasible] at h ⊢
60
+ intros y
61
+ simp only [Vec.sum, Pi.pow_apply, Pi.sub_apply, Vec.const, rpow_two] at h
62
+ exact h y
6
63
7
- open CvxLean Minimization Real BigOperators Matrix
64
+ end LeastSquares
65
+
66
+ namespace FittingSphere
8
67
9
68
-- Dimension.
10
69
variable (n : ℕ)
@@ -19,7 +78,144 @@ def fittingSphere :=
19
78
optimization (c : Fin n → ℝ) (r : ℝ)
20
79
minimize (∑ i, (‖(x i) - c‖ ^ 2 - r ^ 2 ) ^ 2 : ℝ)
21
80
subject to
22
- _ : True
81
+ h₁ : 0 < r
82
+
83
+ instance : ChangeOfVariables fun (ct : (Fin n → ℝ) × ℝ) => (ct.1 , sqrt (ct.2 + ‖ct.1 ‖ ^ 2 )) :=
84
+ { inv := fun (c, r) => (c, r ^ 2 - ‖c‖ ^ 2 ),
85
+ condition := fun (_, t) => 0 ≤ t,
86
+ property := fun ⟨c, t⟩ h => by simp [sqrt_sq h] }
87
+
88
+ set_option trace.Meta.debug true
89
+
90
+ equivalence' eqv/fittingSphereT (n m : ℕ) (x : Fin m → Fin n → ℝ) : fittingSphere n m x := by
91
+ -- Change of variables.
92
+ equivalence_step =>
93
+ apply ChangeOfVariables.toEquivalence
94
+ (fun (ct : (Fin n → ℝ) × ℝ) => (ct.1 , sqrt (ct.2 + ‖ct.1 ‖ ^ 2 )))
95
+ . rintro _ h; exact le_of_lt h
96
+ rename_vars [c, t]
97
+ -- Clean up.
98
+ conv_constr h₁ => dsimp
99
+ conv_obj => dsimp
100
+ -- Rewrite objective.
101
+ equivalence_step =>
102
+ apply Equivalence.rewrite_objFun
103
+ (g := fun (ct : (Fin n → ℝ) × ℝ) =>
104
+ Vec.sum (((Vec.norm x) ^ 2 - 2 * (Matrix.mulVec x ct.1 ) - Vec.const m ct.2 ) ^ 2 ))
105
+ . rintro ⟨c, t⟩ h
106
+ dsimp at h ⊢; simp [Vec.sum, Vec.norm, Vec.const]
107
+ congr; funext i; congr 1 ;
108
+ rw [@norm_sub_sq ℝ (Fin n → ℝ) _ (PiLp.normedAddCommGroup _ _) (PiLp.innerProductSpace _)]
109
+ rw [sq_sqrt (rpow_two _ ▸ le_of_lt (sqrt_pos.mp <| h))]
110
+ simp [mulVec, inner, dotProduct]
111
+ rename_vars [c, t]
112
+
113
+ #print fittingSphereT
114
+
115
+ relaxation rel/fittingSphereConvex (n m : ℕ) (x : Fin m → Fin n → ℝ) : fittingSphereT n m x := by
116
+ relaxation_step =>
117
+ apply Relaxation.weaken_constraint (cs' := fun _ => True)
118
+ . rintro ⟨c, t⟩ _; trivial
119
+
120
+ /-- If the squared error is zero, then `aᵢ = x`. -/
121
+ lemma vec_squared_norm_error_eq_zero_iff {n m : ℕ} (a : Fin m → Fin n → ℝ) (x : Fin n → ℝ) :
122
+ ∑ i, ‖a i - x‖ ^ 2 = 0 ↔ ∀ i, a i = x := by
123
+ simp [rpow_two]
124
+ rw [sum_eq_zero_iff_of_nonneg (fun _ _ => sq_nonneg _)]
125
+ constructor
126
+ . intros h i
127
+ have hi := h i (by simp)
128
+ rw [sq_eq_zero_iff, @norm_eq_zero _ (PiLp.normedAddCommGroup _ _).toNormedAddGroup] at hi
129
+ rwa [sub_eq_zero] at hi
130
+ . intros h i _
131
+ rw [sq_eq_zero_iff, @norm_eq_zero _ (PiLp.normedAddCommGroup _ _).toNormedAddGroup, sub_eq_zero]
132
+ exact h i
133
+
134
+ /-- This tells us that solving the relaxed problem is sufficient for optimal points if the solution
135
+ is non-trivial. -/
136
+ lemma optimal_relaxed_implies_optimal (hm : 0 < m) (c : Fin n → ℝ) (t : ℝ)
137
+ (h_nontrivial : x ≠ Vec.const m c)
138
+ (h_opt : (fittingSphereConvex n m x).optimal (c, t)) : (fittingSphereT n m x).optimal (c, t) := by
139
+ simp [fittingSphereT, fittingSphereConvex, optimal, feasible] at h_opt ⊢
140
+ constructor
141
+ . let a := Vec.norm x ^ 2 - 2 * mulVec x c
142
+ have h_ls : optimal (Vec.leastSquares a) t := by
143
+ refine ⟨trivial, ?_⟩
144
+ intros y _
145
+ simp [objFun, Vec.leastSquares]
146
+ exact h_opt c y
147
+ -- Apply key result about least squares to `a` and `t`.
148
+ have ht_eq := vec_leastSquares_optimal_eq_mean hm a t h_ls
149
+ have hc2_eq : ‖c‖ ^ 2 = (1 / m) * ∑ i : Fin m, ‖c‖ ^ 2 := by
150
+ simp [sum_const]
151
+ field_simp; ring
152
+ have ht : t + ‖c‖ ^ 2 = (1 / m) * ∑ i, ‖(x i) - c‖ ^ 2 := by
153
+ rw [ht_eq]; dsimp [mean]
154
+ rw [hc2_eq, mul_sum, mul_sum, mul_sum, ← sum_add_distrib]
155
+ congr; funext i; rw [← mul_add]
156
+ congr; simp [Vec.norm]
157
+ rw [@norm_sub_sq ℝ (Fin n → ℝ) _ (PiLp.normedAddCommGroup _ _) (PiLp.innerProductSpace _)]
158
+ congr
159
+ -- We use the result to establish that `t + ‖c‖ ^ 2` is non-negative.
160
+ have h_tc2_nonneg : 0 ≤ t + ‖c‖ ^ 2 := by
161
+ rw [ht]
162
+ apply mul_nonneg (by norm_num)
163
+ apply sum_nonneg
164
+ intros i _
165
+ rw [rpow_two]
166
+ exact sq_nonneg _
167
+ cases (lt_or_eq_of_le h_tc2_nonneg) with
168
+ | inl h_tc2_lt_zero =>
169
+ -- If it is positive, we are done.
170
+ convert h_tc2_lt_zero; simp
171
+ | inr h_tc2_eq_zero =>
172
+ -- Otherwise, it contradicts the non-triviality assumption.
173
+ exfalso
174
+ rw [ht, zero_eq_mul] at h_tc2_eq_zero
175
+ rcases h_tc2_eq_zero with (hc | h_sum_eq_zero)
176
+ . simp at hc; linarith
177
+ rw [vec_squared_norm_error_eq_zero_iff] at h_sum_eq_zero
178
+ apply h_nontrivial
179
+ funext i
180
+ exact h_sum_eq_zero i
181
+ . intros c' x' _
182
+ exact h_opt c' x'
183
+
184
+ #print fittingSphereConvex
185
+
186
+ -- We proceed to solve the problem on a concrete example.
187
+ -- https://github.com/cvxgrp/cvxbook_additional_exercises/blob/main/python/sphere_fit_data.py
188
+
189
+ @[optimization_param]
190
+ def nₚ := 2
191
+
192
+ @[optimization_param]
193
+ def mₚ := 10
194
+
195
+ @[optimization_param]
196
+ def xₚ : Fin mₚ → Fin nₚ → ℝ := Matrix.transpose <| ![
197
+ ![1 .824183228637652032e+00 , 1 .349093690455489103e+00 , 6 .966316403935147727e-01 ,
198
+ 7 .599387854623529392e-01 , 2 .388321695850912363e+00 , 8 .651370608981923116e-01 ,
199
+ 1 .863922545015865406e+00 , 7 .099743941474848663e-01 , 6 .005484882320809570e-01 ,
200
+ 4 .561429569892232472e-01 ],
201
+ ![-9 .644136284187876385e-01 , 1 .069547315003422927e+00 , 6 .733229334437943470e-01 ,
202
+ 7 .788072961810316164e-01 , -9 .467465278344706636e-01 , -8 .591303443863639311e-01 ,
203
+ 1 .279527420871080956e+00 , 5 .314829019311283487e-01 , 6 .975676079749143499e-02 ,
204
+ -4 .641873429414754559e-01 ]]
205
+
206
+ -- We use the `solve` command on the data above.
207
+
208
+ set_option maxHeartbeats 1000000
209
+
210
+ solve fittingSphereConvex nₚ mₚ xₚ
211
+
212
+ -- Finally, we recover the solution to the original problem.
213
+
214
+ def sol := eqv.backward_map nₚ mₚ xₚ.float fittingSphereConvex.solution
215
+
216
+ #print eqv.backward_map
217
+
218
+ #eval sol -- (![1.664863, 0.031932], 1.159033)
23
219
24
220
end FittingSphere
25
221
0 commit comments