Skip to content

Commit ffc6bed

Browse files
authored
feat: fitting sphere to data case study
2 parents 90607a7 + 668649e commit ffc6bed

14 files changed

+407
-47
lines changed

CvxLean/Command/Solve/Float/Coeffs.lean

+30-2
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,19 @@ import CvxLean.Lib.Cones.All
33
import CvxLean.Command.Solve.Float.ProblemData
44
import CvxLean.Command.Solve.Float.RealToFloat
55

6+
7+
/-!
8+
# Extract coefficients from problem to generate problem data
9+
10+
TODO
11+
12+
## TODO
13+
14+
* This is probably a big source of inefficency for the `solve` command. We should come up with
15+
a better way to extract the numerical values from the Lean expressions.
16+
* A first step is to not `unrollVectors` and turn thos expressions into floats directly.
17+
-/
18+
619
namespace CvxLean
720

821
open Lean Meta Elab Tactic
@@ -173,13 +186,24 @@ unsafe def unrollVectors (constraints : Expr) : MetaM (Array Expr) := do
173186
res := res.push (← mkAppM ``Real.expCone #[ai, bi, ci])
174187
-- Vector second-order cone.
175188
| .app (.app (.app (.app (.app (.const ``Real.Vec.soCone _)
176-
exprN@(.app (.const ``Fin _) n)) (.app (.const ``Fin _) m)) finTypeN) t) X =>
189+
exprN@(.app (.const ``Fin _) _n)) (.app (.const ``Fin _) m)) finTypeN) t) X =>
177190
let m : Nat ← evalExpr Nat (mkConst ``Nat) m
178191
for i in [:m] do
179192
let idxExpr ← mkFinIdxExpr i m
180193
let ti := mkApp t idxExpr
181194
let Xi := mkApp X idxExpr
182195
res := res.push (mkAppN (mkConst ``Real.soCone) #[exprN, finTypeN, ti, Xi])
196+
-- Vector rotated second-order cone.
197+
-- Vector second-order cone.
198+
| .app (.app (.app (.app (.app (.app (.const ``Real.Vec.rotatedSoCone _)
199+
exprN@(.app (.const ``Fin _) _n)) (.app (.const ``Fin _) m)) finTypeN) v) w) X =>
200+
let m : Nat ← evalExpr Nat (mkConst ``Nat) m
201+
for i in [:m] do
202+
let idxExpr ← mkFinIdxExpr i m
203+
let vi := mkApp v idxExpr
204+
let wi := mkApp w idxExpr
205+
let Xi := mkApp X idxExpr
206+
res := res.push (mkAppN (mkConst ``Real.rotatedSoCone) #[exprN, finTypeN, vi, wi, Xi])
183207
| _ =>
184208
res := res.push c
185209

@@ -254,7 +278,10 @@ unsafe def determineCoeffsFromExpr (minExpr : Meta.MinimizationExpr) :
254278
let mut idx := 0
255279
for c in cs do
256280
trace[Meta.debug] "Coeffs going through constraint {c}."
281+
let mut isTrivial := false
257282
match Expr.consumeMData c with
283+
| .const ``True _ => do
284+
isTrivial := true
258285
| .app (.const ``Real.zeroCone _) e => do
259286
let e ← realToFloat e
260287
let res ← determineScalarCoeffsAux e p floatDomain
@@ -331,7 +358,8 @@ unsafe def determineCoeffsFromExpr (minExpr : Meta.MinimizationExpr) :
331358
idx := idx + 1
332359
| _ => throwError "No match: {c}."
333360
-- New group, add idx.
334-
sections := sections.push idx
361+
if !isTrivial then
362+
sections := sections.push idx
335363
return (data, sections)
336364

337365
let (objectiveDataA, objectiveDataB) := objectiveData

CvxLean/Command/Solve/Float/RealToFloat.lean

+31-6
Original file line numberDiff line numberDiff line change
@@ -24,15 +24,25 @@ partial def realToFloat (e : Expr) : MetaM Expr := do
2424
for translation in translations do
2525
let (mvars, _, pattern) ← lambdaMetaTelescope translation.real
2626
if ← isDefEq pattern e then
27-
-- TODO: Search for conditions.
27+
-- TODO: Search for conditions.
2828
let args ← mvars.mapM instantiateMVars
2929
return mkAppNBeta translation.float args
3030
else
3131
trace[Meta.debug] "`real-to-float` error: no match for \n{pattern} \n{e}"
3232
match e with
3333
| Expr.app a b => return mkApp (← realToFloat a) (← realToFloat b)
34-
| Expr.lam n ty b d => return mkLambda n d (← realToFloat ty) (← realToFloat b)
35-
| Expr.forallE n ty b d => return mkForall n d (← realToFloat ty) (← realToFloat b)
34+
| Expr.lam n ty b d => do
35+
withLocalDecl n d (← realToFloat ty) fun fvar => do
36+
let b := b.instantiate1 fvar
37+
let bF ← realToFloat b
38+
mkLambdaFVars #[fvar] bF
39+
-- return mkLambda n d (← realToFloat ty) (← realToFloat b)
40+
| Expr.forallE n ty b d => do
41+
withLocalDecl n d (← realToFloat ty) fun fvar => do
42+
let b := b.instantiate1 fvar
43+
let bF ← realToFloat b
44+
mkForallFVars #[fvar] bF
45+
-- return mkForall n d (← realToFloat ty) (← realToFloat b)
3646
| Expr.mdata m e => return mkMData m (← realToFloat e)
3747
| Expr.letE n ty t b _ => return mkLet n (← realToFloat ty) (← realToFloat t) (← realToFloat b)
3848
| Expr.proj typeName idx struct => return mkProj typeName idx (← realToFloat struct)
@@ -183,12 +193,18 @@ addRealToFloat (i) : @instHDiv Real i :=
183193
addRealToFloat (i) : @HPow.hPow Real Nat Real i :=
184194
fun f n => Float.pow f (Float.ofNat n)
185195

186-
addRealToFloat (i) : @HPow.hPow Real Real Real i :=
196+
addRealToFloat (i) : @HPow.hPow.{0, 0, 0} Real Real Real i :=
187197
fun f n => Float.pow f n
188198

189-
addRealToFloat (i) : @instHPow Real i :=
199+
addRealToFloat (β) (i) : @instHPow Real β i :=
190200
@HPow.mk Float Float Float Float.pow
191201

202+
addRealToFloat (n) (i) : @HPow.hPow (Fin n → Real) Real (Fin n → Real) i :=
203+
fun (x : Fin n → Float) (p : Float) (i : Fin n) => Float.pow (x i) p
204+
205+
addRealToFloat (n) (β) (i) : @instHPow (Fin n → Real) β i :=
206+
@HPow.mk (Fin n → Float) Float (Fin n → Float) (fun x p i => Float.pow (x i) p)
207+
192208
addRealToFloat (i) : @LE.le Real i :=
193209
Float.le
194210

@@ -210,6 +226,12 @@ addRealToFloat : @Real.sqrt :=
210226
addRealToFloat : @Real.log :=
211227
Float.log
212228

229+
def Float.norm {n : ℕ} (x : Fin n → Float) : Float :=
230+
Float.sqrt (Vec.Computable.sum (fun i => (Float.pow (x i) 2)))
231+
232+
addRealToFloat (n) (i) : @Norm.norm.{0} (Fin n → ℝ) i :=
233+
@Float.norm n
234+
213235
addRealToFloat (i) : @OfScientific.ofScientific Real i :=
214236
Float.ofScientific
215237

@@ -260,9 +282,12 @@ addRealToFloat (n) (i) (hn) : @Vec.sum.{0} ℝ i (Fin n) hn :=
260282
addRealToFloat (n) (i) (hn) : @Matrix.sum (Fin n) Real hn i :=
261283
@Matrix.Computable.sum n
262284

263-
addRealToFloat (n) : @Vec.cumsum n :=
285+
addRealToFloat (n) (i) : @Vec.cumsum.{0} ℝ i n :=
264286
@Vec.Computable.cumsum n
265287

288+
addRealToFloat : @Vec.norm :=
289+
@Vec.Computable.norm
290+
266291
addRealToFloat (n) (i1) (i2) (i3) : @Matrix.dotProduct (Fin n) ℝ i1 i2 i3 :=
267292
@Matrix.Computable.dotProduct n
268293

CvxLean/Examples/All.lean

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import CvxLean.Examples.CovarianceEstimation
22
import CvxLean.Examples.FittingSphere
33
import CvxLean.Examples.HypersonicShapeDesign
4-
import CvxLean.Examples.OptimalVehicleSpeed
54
import CvxLean.Examples.TrussDesign
5+
import CvxLean.Examples.VehicleSpeedScheduling

CvxLean/Examples/FittingSphere.lean

+199-3
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,68 @@ import CvxLean
22

33
noncomputable section
44

5-
namespace FittingSphere
5+
open CvxLean Minimization Real BigOperators Matrix Finset
6+
7+
section LeastSquares
8+
9+
def leastSquares {n : ℕ} (a : Fin n → ℝ) :=
10+
optimization (x : ℝ)
11+
minimize (∑ i, ((a i - x) ^ 2) : ℝ)
12+
13+
@[reducible]
14+
def mean {n : ℕ} (a : Fin n → ℝ) : ℝ := (1 / n) * ∑ i, (a i)
15+
16+
/-- It is useful to rewrite the sum of squares in the following way to prove
17+
`leastSquares_optimal_eq_mean`, following Marty Cohen's answer in
18+
https://math.stackexchange.com/questions/2554243. -/
19+
lemma leastSquares_alt_objFun {n : ℕ} (hn : 0 < n) (a : Fin n → ℝ) (x : ℝ) :
20+
(∑ i, ((a i - x) ^ 2)) = n * ((x - mean a) ^ 2 + (mean (a ^ 2) - (mean a) ^ 2)) := by
21+
calc
22+
-- 1) Σ (aᵢ - x)² = Σ (aᵢ² - 2aᵢx + x²)
23+
_ = ∑ i, ((a i) ^ 2 - 2 * (a i) * x + (x ^ 2)) := by
24+
congr; funext i; simp; ring
25+
-- 2) ... = Σ aᵢ² - 2xΣ aᵢ + nx²
26+
_ = ∑ i, ((a i) ^ 2) - 2 * x * ∑ i, (a i) + n * (x ^ 2) := by
27+
rw [sum_add_distrib, sum_sub_distrib, ← sum_mul, ← mul_sum]; simp [sum_const]; ring
28+
-- 3) ... = n{a²} - 2xn{a} + nx²
29+
_ = n * mean (a ^ 2) - 2 * x * n * mean a + n * (x ^ 2) := by
30+
simp [mean]; field_simp; ring
31+
-- 4) ... = n((x - {a})² + ({a²} - {a}²))
32+
_ = n * ((x - mean a) ^ 2 + (mean (a ^ 2) - (mean a) ^ 2)) := by
33+
simp [mean]; field_simp; ring
34+
35+
/-- Key result about least squares: `x* = mean a`. -/
36+
lemma leastSquares_optimal_eq_mean {n : ℕ} (hn : 0 < n) (a : Fin n → ℝ) (x : ℝ)
37+
(h : (leastSquares a).optimal x) : x = mean a := by
38+
simp [optimal, feasible, leastSquares] at h
39+
replace h : ∀ y, (x - mean a) ^ 2 ≤ (y - mean a) ^ 2 := by
40+
intros y
41+
have hy := h y
42+
have h_rw_x := leastSquares_alt_objFun hn a x
43+
have h_rw_y := leastSquares_alt_objFun hn a y
44+
simp only [rpow_two] at h_rw_x h_rw_y ⊢
45+
rwa [h_rw_x, h_rw_y, mul_le_mul_left (by positivity), add_le_add_iff_right] at hy
46+
have hmean := h (mean a)
47+
simp at hmean
48+
have hz := le_antisymm hmean (sq_nonneg _)
49+
rwa [sq_eq_zero_iff, sub_eq_zero] at hz
50+
51+
def Vec.leastSquares {n : ℕ} (a : Fin n → ℝ) :=
52+
optimization (x : ℝ)
53+
minimize (Vec.sum ((a - Vec.const n x) ^ 2) : ℝ)
54+
55+
/-- Same as `leastSquares_optimal_eq_mean` in vector notation. -/
56+
lemma vec_leastSquares_optimal_eq_mean {n : ℕ} (hn : 0 < n) (a : Fin n → ℝ) (x : ℝ)
57+
(h : (Vec.leastSquares a).optimal x) : x = mean a := by
58+
apply leastSquares_optimal_eq_mean hn a
59+
simp [Vec.leastSquares, leastSquares, optimal, feasible] at h ⊢
60+
intros y
61+
simp only [Vec.sum, Pi.pow_apply, Pi.sub_apply, Vec.const, rpow_two] at h
62+
exact h y
663

7-
open CvxLean Minimization Real BigOperators Matrix
64+
end LeastSquares
65+
66+
namespace FittingSphere
867

968
-- Dimension.
1069
variable (n : ℕ)
@@ -19,7 +78,144 @@ def fittingSphere :=
1978
optimization (c : Fin n → ℝ) (r : ℝ)
2079
minimize (∑ i, (‖(x i) - c‖ ^ 2 - r ^ 2) ^ 2 : ℝ)
2180
subject to
22-
_ : True
81+
h₁ : 0 < r
82+
83+
instance : ChangeOfVariables fun (ct : (Fin n → ℝ) × ℝ) => (ct.1, sqrt (ct.2 + ‖ct.1‖ ^ 2)) :=
84+
{ inv := fun (c, r) => (c, r ^ 2 - ‖c‖ ^ 2),
85+
condition := fun (_, t) => 0 ≤ t,
86+
property := fun ⟨c, t⟩ h => by simp [sqrt_sq h] }
87+
88+
set_option trace.Meta.debug true
89+
90+
equivalence' eqv/fittingSphereT (n m : ℕ) (x : Fin m → Fin n → ℝ) : fittingSphere n m x := by
91+
-- Change of variables.
92+
equivalence_step =>
93+
apply ChangeOfVariables.toEquivalence
94+
(fun (ct : (Fin n → ℝ) × ℝ) => (ct.1, sqrt (ct.2 + ‖ct.1‖ ^ 2)))
95+
. rintro _ h; exact le_of_lt h
96+
rename_vars [c, t]
97+
-- Clean up.
98+
conv_constr h₁ => dsimp
99+
conv_obj => dsimp
100+
-- Rewrite objective.
101+
equivalence_step =>
102+
apply Equivalence.rewrite_objFun
103+
(g := fun (ct : (Fin n → ℝ) × ℝ) =>
104+
Vec.sum (((Vec.norm x) ^ 2 - 2 * (Matrix.mulVec x ct.1) - Vec.const m ct.2) ^ 2))
105+
. rintro ⟨c, t⟩ h
106+
dsimp at h ⊢; simp [Vec.sum, Vec.norm, Vec.const]
107+
congr; funext i; congr 1;
108+
rw [@norm_sub_sq ℝ (Fin n → ℝ) _ (PiLp.normedAddCommGroup _ _) (PiLp.innerProductSpace _)]
109+
rw [sq_sqrt (rpow_two _ ▸ le_of_lt (sqrt_pos.mp <| h))]
110+
simp [mulVec, inner, dotProduct]
111+
rename_vars [c, t]
112+
113+
#print fittingSphereT
114+
115+
relaxation rel/fittingSphereConvex (n m : ℕ) (x : Fin m → Fin n → ℝ) : fittingSphereT n m x := by
116+
relaxation_step =>
117+
apply Relaxation.weaken_constraint (cs' := fun _ => True)
118+
. rintro ⟨c, t⟩ _; trivial
119+
120+
/-- If the squared error is zero, then `aᵢ = x`. -/
121+
lemma vec_squared_norm_error_eq_zero_iff {n m : ℕ} (a : Fin m → Fin n → ℝ) (x : Fin n → ℝ) :
122+
∑ i, ‖a i - x‖ ^ 2 = 0 ↔ ∀ i, a i = x := by
123+
simp [rpow_two]
124+
rw [sum_eq_zero_iff_of_nonneg (fun _ _ => sq_nonneg _)]
125+
constructor
126+
. intros h i
127+
have hi := h i (by simp)
128+
rw [sq_eq_zero_iff, @norm_eq_zero _ (PiLp.normedAddCommGroup _ _).toNormedAddGroup] at hi
129+
rwa [sub_eq_zero] at hi
130+
. intros h i _
131+
rw [sq_eq_zero_iff, @norm_eq_zero _ (PiLp.normedAddCommGroup _ _).toNormedAddGroup, sub_eq_zero]
132+
exact h i
133+
134+
/-- This tells us that solving the relaxed problem is sufficient for optimal points if the solution
135+
is non-trivial. -/
136+
lemma optimal_relaxed_implies_optimal (hm : 0 < m) (c : Fin n → ℝ) (t : ℝ)
137+
(h_nontrivial : x ≠ Vec.const m c)
138+
(h_opt : (fittingSphereConvex n m x).optimal (c, t)) : (fittingSphereT n m x).optimal (c, t) := by
139+
simp [fittingSphereT, fittingSphereConvex, optimal, feasible] at h_opt ⊢
140+
constructor
141+
. let a := Vec.norm x ^ 2 - 2 * mulVec x c
142+
have h_ls : optimal (Vec.leastSquares a) t := by
143+
refine ⟨trivial, ?_⟩
144+
intros y _
145+
simp [objFun, Vec.leastSquares]
146+
exact h_opt c y
147+
-- Apply key result about least squares to `a` and `t`.
148+
have ht_eq := vec_leastSquares_optimal_eq_mean hm a t h_ls
149+
have hc2_eq : ‖c‖ ^ 2 = (1 / m) * ∑ i : Fin m, ‖c‖ ^ 2 := by
150+
simp [sum_const]
151+
field_simp; ring
152+
have ht : t + ‖c‖ ^ 2 = (1 / m) * ∑ i, ‖(x i) - c‖ ^ 2 := by
153+
rw [ht_eq]; dsimp [mean]
154+
rw [hc2_eq, mul_sum, mul_sum, mul_sum, ← sum_add_distrib]
155+
congr; funext i; rw [← mul_add]
156+
congr; simp [Vec.norm]
157+
rw [@norm_sub_sq ℝ (Fin n → ℝ) _ (PiLp.normedAddCommGroup _ _) (PiLp.innerProductSpace _)]
158+
congr
159+
-- We use the result to establish that `t + ‖c‖ ^ 2` is non-negative.
160+
have h_tc2_nonneg : 0 ≤ t + ‖c‖ ^ 2 := by
161+
rw [ht]
162+
apply mul_nonneg (by norm_num)
163+
apply sum_nonneg
164+
intros i _
165+
rw [rpow_two]
166+
exact sq_nonneg _
167+
cases (lt_or_eq_of_le h_tc2_nonneg) with
168+
| inl h_tc2_lt_zero =>
169+
-- If it is positive, we are done.
170+
convert h_tc2_lt_zero; simp
171+
| inr h_tc2_eq_zero =>
172+
-- Otherwise, it contradicts the non-triviality assumption.
173+
exfalso
174+
rw [ht, zero_eq_mul] at h_tc2_eq_zero
175+
rcases h_tc2_eq_zero with (hc | h_sum_eq_zero)
176+
. simp at hc; linarith
177+
rw [vec_squared_norm_error_eq_zero_iff] at h_sum_eq_zero
178+
apply h_nontrivial
179+
funext i
180+
exact h_sum_eq_zero i
181+
. intros c' x' _
182+
exact h_opt c' x'
183+
184+
#print fittingSphereConvex
185+
186+
-- We proceed to solve the problem on a concrete example.
187+
-- https://github.com/cvxgrp/cvxbook_additional_exercises/blob/main/python/sphere_fit_data.py
188+
189+
@[optimization_param]
190+
def nₚ := 2
191+
192+
@[optimization_param]
193+
def mₚ := 10
194+
195+
@[optimization_param]
196+
def xₚ : Fin mₚ → Fin nₚ → ℝ := Matrix.transpose <| ![
197+
![1.824183228637652032e+00, 1.349093690455489103e+00, 6.966316403935147727e-01,
198+
7.599387854623529392e-01, 2.388321695850912363e+00, 8.651370608981923116e-01,
199+
1.863922545015865406e+00, 7.099743941474848663e-01, 6.005484882320809570e-01,
200+
4.561429569892232472e-01],
201+
![-9.644136284187876385e-01, 1.069547315003422927e+00, 6.733229334437943470e-01,
202+
7.788072961810316164e-01, -9.467465278344706636e-01, -8.591303443863639311e-01,
203+
1.279527420871080956e+00, 5.314829019311283487e-01, 6.975676079749143499e-02,
204+
-4.641873429414754559e-01]]
205+
206+
-- We use the `solve` command on the data above.
207+
208+
set_option maxHeartbeats 1000000
209+
210+
solve fittingSphereConvex nₚ mₚ xₚ
211+
212+
-- Finally, we recover the solution to the original problem.
213+
214+
def sol := eqv.backward_map nₚ mₚ xₚ.float fittingSphereConvex.solution
215+
216+
#print eqv.backward_map
217+
218+
#eval sol -- (![1.664863, 0.031932], 1.159033)
23219

24220
end FittingSphere
25221

0 commit comments

Comments
 (0)