@@ -87,7 +87,9 @@ def test_fit(self, k=20, d=11, r=3):
87
87
solver = self .Solver (1 )
88
88
A = np .random .standard_normal ((k , d ))
89
89
B = np .random .standard_normal ((r , k ))
90
+ solver .verify ()
90
91
solver .fit (A , B )
92
+ solver .verify ()
91
93
92
94
for attr , shape in [
93
95
("_ZPhi" , (r , d )),
@@ -101,6 +103,8 @@ def test_fit(self, k=20, d=11, r=3):
101
103
assert isinstance (obj , np .ndarray )
102
104
assert obj .shape == shape
103
105
106
+ repr (solver )
107
+
104
108
def test_predict (self , k = 20 , d = 10 , r = 5 ):
105
109
"""Test predict()."""
106
110
solver1D = self .Solver (0 )
@@ -346,6 +350,8 @@ def test_regularizer(self, k=10, d=6, r=3):
346
350
solver .fit (A , B )
347
351
assert solver .r == r
348
352
353
+ repr (solver )
354
+
349
355
# Main methods ------------------------------------------------------------
350
356
def test_predict (self , k = 20 , d = 10 ):
351
357
"""Test predict()."""
@@ -551,13 +557,20 @@ def test_regularizer(self, k=20, d=11, r=3):
551
557
assert solver .regularizer .shape == (d , d )
552
558
assert np .all (solver .regularizer == 2 * np .eye (d ))
553
559
560
+ repr (solver )
561
+ solver .method = "normal"
562
+ repr (solver )
563
+
554
564
# Main methods ------------------------------------------------------------
555
565
def test_fit (self , k = 20 , d = 10 , r = 5 ):
556
566
"""Test fit()."""
557
567
Z = np .zeros ((d , d ))
558
568
A = np .random .standard_normal ((k , d ))
559
569
B = np .random .standard_normal ((r , k ))
560
- solver = self .Solver (Z ).fit (A , B )
570
+ solver = self .Solver (Z )
571
+ solver .verify (d = d , r = r )
572
+ solver .fit (A , B )
573
+ solver .verify ()
561
574
562
575
for attr , shape in [
563
576
("data_matrix" , (k , d )),
@@ -826,6 +839,8 @@ def test_regularizer(self, k=10, d=6, r=3):
826
839
solver .regularizer = [[i ] * d for i in range (1 , r + 1 )]
827
840
assert np .all (solver .regularizer [0 ] == np .eye (d ))
828
841
842
+ repr (solver )
843
+
829
844
# Main methods ------------------------------------------------------------
830
845
def test_predict (self , k = 20 , d = 10 ):
831
846
"""Test lstsq._tikhonov.TikhonovDecoupledSolver.predict()."""
0 commit comments