@@ -16,10 +16,28 @@ submodule (stdlib_linalg) stdlib_linalg_least_squares
16
16
17
17
contains
18
18
19
+ elemental subroutine handle_gelsd_info(info,lda,n,ldb,nrhs,err)
20
+ integer(ilp), intent(in) :: info,lda,n,ldb,nrhs
21
+ type(linalg_state_type), intent(out) :: err
22
+
23
+ ! Process output
24
+ select case (info)
25
+ case (0)
26
+ ! Success
27
+ case (:-1)
28
+ err = linalg_state_type(this,LINALG_VALUE_ERROR,'invalid problem size a=',[lda,n], &
29
+ ', b=',[ldb,nrhs])
30
+ case (1:)
31
+ err = linalg_state_type(this,LINALG_ERROR,'SVD did not converge.')
32
+ case default
33
+ err = linalg_state_type(this,LINALG_INTERNAL_ERROR,'catastrophic error')
34
+
35
+ end subroutine handle_gelsd_info
36
+
19
37
#:for rk,rt,ri in RC_KINDS_TYPES
20
38
#:if rk!="xdp"
21
- ! Workspace needed by gesv
22
- elemental subroutine ${ri}$gesv_space (m,n,nrhs,lrwork,liwork,lcwork)
39
+ ! Workspace needed by gelsd
40
+ elemental subroutine ${ri}$gelsd_space (m,n,nrhs,lrwork,liwork,lcwork)
23
41
integer(ilp), intent(in) :: m,n,nrhs
24
42
integer(ilp), intent(out) :: lrwork,liwork,lcwork
25
43
@@ -53,7 +71,7 @@ submodule (stdlib_linalg) stdlib_linalg_least_squares
53
71
lcwork = ceiling(1.25*lcwork,kind=ilp)
54
72
liwork = ceiling(1.25*liwork,kind=ilp)
55
73
56
- end subroutine ${ri}$gesv_space
74
+ end subroutine ${ri}$gelsd_space
57
75
58
76
#:endif
59
77
#:endfor
@@ -93,33 +111,87 @@ submodule (stdlib_linalg) stdlib_linalg_least_squares
93
111
!> [optional] state return flag. On error if not requested, the code will stop
94
112
type(linalg_state_type), optional, intent(out) :: err
95
113
!> Result array/matrix x[n] or x[n,nrhs]
96
- ${rt}$, allocatable, target :: x${nd}$
114
+ ${rt}$, allocatable, target :: x${nd}$
115
+
116
+ ! Initialize solution with the shape of the rhs
117
+ allocate(x,mold=b)
118
+
119
+ call stdlib_linalg_${ri}$_solve_lstsq_${ndsuf}$(a,b,x,&
120
+ cond=cond,overwrite_a=overwrite_a,rank=rank,err=err)
121
+
122
+ end function stdlib_linalg_${ri}$_lstsq_${ndsuf}$
123
+
124
+ ! Compute the least-squares solution to a real system of linear equations Ax = b
125
+ module subroutine stdlib_linalg_${ri}$_solve_lstsq_${ndsuf}$(a,b,x, &
126
+ real_storage,int_storage#{if rt.startswith('c')}#,cmpl_storage#{endif}#,cond,overwrite_a,rank,err)
127
+
128
+ !!### Summary
129
+ !! Compute least-squares solution to a real system of linear equations \( Ax = b \)
130
+ !!
131
+ !!### Description
132
+ !!
133
+ !! This function computes the least-squares solution of a linear matrix problem.
134
+ !!
135
+ !! param: a Input matrix of size [m,n].
136
+ !! param: b Right-hand-side vector of size [n] or matrix of size [n,nrhs].
137
+ !! param: cond [optional] Real input threshold indicating that singular values `s_i <= cond*maxval(s)`
138
+ !! do not contribute to the matrix rank.
139
+ !! param: overwrite_a [optional] Flag indicating if the input matrix can be overwritten.
140
+ !! param: rank [optional] integer flag returning matrix rank.
141
+ !! param: err [optional] State return flag.
142
+ !! return: x Solution vector of size [n] or solution matrix of size [n,nrhs].
143
+ !!
144
+ !> Input matrix a[n,n]
145
+ ${rt}$, intent(inout), target :: a(:,:)
146
+ !> Right hand side vector or array, b[n] or b[n,nrhs]
147
+ ${rt}$, intent(in) :: b${nd}$
148
+ !> Result array/matrix x[n] or x[n,nrhs]
149
+ ${rt}$, intent(inout), contiguous, target :: x${nd}$
150
+ !> [optional] real working storage space
151
+ real(${rk}$), optional, intent(inout), target :: real_storage(:)
152
+ !> [optional] integer working storage space
153
+ integer(ilp), optional, intent(inout), target :: int_storage(:)
154
+ #:if rt.startswith('c')
155
+ !> [optional] complex working storage space
156
+ ${rt}$, optional, intent(inout), target :: cmpl_storage(:)
157
+ #:endif
158
+ !> [optional] cutoff for rank evaluation: singular values s(i)<=cond*maxval(s) are considered 0.
159
+ real(${rk}$), optional, intent(in) :: cond
160
+ !> [optional] Can A,b data be overwritten and destroyed?
161
+ logical(lk), optional, intent(in) :: overwrite_a
162
+ !> [optional] Return rank of A
163
+ integer(ilp), optional, intent(out) :: rank
164
+ !> [optional] state return flag. On error if not requested, the code will stop
165
+ type(linalg_state_type), optional, intent(out) :: err
97
166
98
167
!! Local variables
99
168
type(linalg_state_type) :: err0
100
- integer(ilp) :: m,n,lda,ldb,nrhs,info,mnmin,mnmax,arank,lrwork,liwork,lcwork
101
- integer(ilp), allocatable :: iwork(:)
169
+ integer(ilp) :: m,n,lda,ldb,nrhs,ldx,nrhsx,info,mnmin,mnmax,arank,lrwork,liwork,lcwork
170
+ integer(ilp) :: nrs,nis,ncs
171
+ integer(ilp), pointer :: iwork(:)
102
172
logical(lk) :: copy_a
103
173
real(${rk}$) :: acond,rcond
104
- real(${rk}$), allocatable :: singular(:),rwork(:)
105
- ${rt}$ , pointer :: xmat(:,:),amat(:, :)
106
- ${rt}$, allocatable :: cwork(:)
174
+ real(${rk}$), allocatable :: singular(:)
175
+ real(${rk}$) , pointer :: rwork( :)
176
+ ${rt}$, pointer :: xmat(:,:),amat(:,:), cwork(:)
107
177
108
178
! Problem sizes
109
179
m = size(a,1,kind=ilp)
110
180
lda = size(a,1,kind=ilp)
111
181
n = size(a,2,kind=ilp)
112
182
ldb = size(b,1,kind=ilp)
113
183
nrhs = size(b ,kind=ilp)/ldb
184
+ ldx = size(x,1,kind=ilp)
185
+ nrhsx = size(x ,kind=ilp)/ldx
114
186
mnmin = min(m,n)
115
187
mnmax = max(m,n)
116
188
117
- if (lda<1 .or. n<1 .or. ldb<1 .or. ldb/=m) then
189
+ if (lda<1 .or. n<1 .or. ldb<1 .or. ldb/=m .or. ldx/=m ) then
118
190
err0 = linalg_state_type(this,LINALG_VALUE_ERROR,'invalid sizes: a=',[lda,n], &
119
- 'b=',[ldb,nrhs])
120
- allocate(x${nde}$)
191
+ 'b=',[ldb,nrhs],' x=',[ldx,nrhsx])
121
192
call linalg_error_handling(err0,err)
122
193
if (present(rank)) rank = 0
194
+ return
123
195
end if
124
196
125
197
! Can A be overwritten? By default, do not overwrite
@@ -137,7 +209,7 @@ submodule (stdlib_linalg) stdlib_linalg_least_squares
137
209
endif
138
210
139
211
! Initialize solution with the rhs
140
- allocate(x,source=b)
212
+ x = b
141
213
xmat(1:n,1:nrhs) => x
142
214
143
215
! Singular values array (in decreasing order)
@@ -153,44 +225,71 @@ submodule (stdlib_linalg) stdlib_linalg_least_squares
153
225
endif
154
226
if (rcond<0) rcond = epsilon(0.0_${rk}$)*mnmax
155
227
156
- ! Allocate working space
157
- call ${ri}$gesv_space(m,n,nrhs,lrwork,liwork,lcwork)
158
- #:if rt.startswith('complex')
159
- allocate(rwork(lrwork),cwork(lcwork),iwork(liwork))
160
- #:else
161
- allocate(rwork(lrwork),iwork(liwork))
162
- #:endif
228
+ ! Get working space size
229
+ call ${ri}$gelsd_space(m,n,nrhs,lrwork,liwork,lcwork)
230
+
231
+ ! Real working space
232
+ if (present(real_storage)) then
233
+ rwork => real_storage
234
+ else
235
+ allocate(rwork(lrwork))
236
+ endif
237
+ nrs = size(rwork,kind=ilp)
163
238
164
- ! Solve system using singular value decomposition
165
- #:if rt.startswith('complex')
166
- call gelsd(m,n,nrhs,amat,lda,xmat,ldb,singular,rcond,arank,cwork,lrwork,rwork,iwork,info)
167
- #:else
168
- call gelsd(m,n,nrhs,amat,lda,xmat,ldb,singular,rcond,arank,rwork,lrwork,iwork,info)
169
- #:endif
239
+ ! Integer working space
240
+ if (present(int_storage)) then
241
+ iwork => int_storage
242
+ else
243
+ allocate(iwork(liwork))
244
+ endif
245
+ nis = size(iwork,kind=ilp)
246
+
247
+ #:if rt.startswith('complex')
248
+ ! Complex working space
249
+ if (present(cmpl_storage)) then
250
+ cwork => cmpl_storage
251
+ else
252
+ allocate(cwork(lcwork))
253
+ endif
254
+ ncs = size(iwork,kind=ilp)
255
+ #:endif
256
+
257
+ if (nrs<lrwork .or. nis<liwork#{if rt.startswith('c')}# .or. ncs<lcwork#{endif}#) then
258
+ ! Halt on insufficient space
259
+ err0 = linalg_state_type(this,LINALG_VALUE_ERROR,'insufficient working space: ',&
260
+ 'real=',nrs,' should be >=',lrwork, &
261
+ ', int=',nis,' should be >=',liwork &
262
+ #{if rt.startswith('complex')}#,', cmplx=',ncs,' should be >=',lcwork#{endif}#)
263
+
264
+ else
170
265
266
+ ! Solve system using singular value decomposition
267
+ call gelsd(m,n,nrhs,amat,lda,xmat,ldb,singular,rcond,arank, &
268
+ #:if rt.startswith('complex')
269
+ cwork,nrs,rwork,iwork,info)
270
+ #:else
271
+ rwork,nrs,iwork,info)
272
+ #:endif
273
+
274
+ endif
275
+
171
276
! The condition number of A in the 2-norm = S(1)/S(min(m,n)).
172
277
acond = singular(1)/singular(mnmin)
173
278
174
279
! Process output
175
- select case (info)
176
- case (0)
177
- ! Success
178
- case (:-1)
179
- err0 = linalg_state_type(this,LINALG_VALUE_ERROR,'invalid problem size a=',[lda,n], &
180
- ', b=',[ldb,nrhs])
181
- case (1:)
182
- err0 = linalg_state_type(this,LINALG_ERROR,'SVD did not converge.')
183
- case default
184
- err0 = linalg_state_type(this,LINALG_INTERNAL_ERROR,'catastrophic error')
185
- end select
186
-
187
- if (copy_a) deallocate(amat)
280
+ call handle_gelsd_info(info,lda,n,ldb,nrhs,err0)
188
281
189
282
! Process output and return
190
- call linalg_error_handling(err0,err)
283
+ 1 if (copy_a) deallocate(amat)
191
284
if (present(rank)) rank = arank
285
+ if (.not.present(real_storage)) deallocate(rwork)
286
+ if (.not.present(int_storage)) deallocate(iwork)
287
+ #:if rt.startswith('complex')
288
+ if (.not.present(cmpl_storage)) deallocate(cwork)
289
+ #:endif
290
+ call linalg_error_handling(err0,err)
192
291
193
- end function stdlib_linalg_${ri}$_lstsq_ ${ndsuf}$
292
+ end subroutine stdlib_linalg_${ri}$_solve_lstsq_ ${ndsuf}$
194
293
195
294
#:endif
196
295
#:endfor
0 commit comments