Skip to content

Commit 42f2de1

Browse files
committed
subroutine interface
1 parent 1773de8 commit 42f2de1

File tree

2 files changed

+192
-41
lines changed

2 files changed

+192
-41
lines changed

src/stdlib_linalg.fypp

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -268,6 +268,58 @@ module stdlib_linalg
268268
#:endfor
269269
end interface lstsq
270270

271+
! Least squares solution to system Ax=b, i.e. such that the 2-norm abs(b-Ax) is minimized.
272+
interface solve_lstsq
273+
!! version: experimental
274+
!!
275+
!! Computes the squares solution to system \( A \cdot x = b \).
276+
!! ([Specification](../page/specs/stdlib_linalg.html#XXXXXxXxxxxxxx))xxxxx
277+
!!
278+
!!### Summary
279+
!! Subroutine interface for computing least squares, i.e. the 2-norm \( || (b-A \cdot x ||_2 \) minimizing solution.
280+
!!
281+
!!### Description
282+
!!
283+
!! This interface provides methods for computing the least squares of a linear matrix system using
284+
!! a subroutine. Supported data types include `real` and `complex`. If pre-allocated work spaces
285+
!! are provided, no internal memory allocations take place when using this interface.
286+
!!
287+
!!@note The solution is based on LAPACK's singular value decomposition `*GELSD` methods.
288+
!!@note BLAS/LAPACK backends do not currently support extended precision (``xdp``).
289+
!!
290+
#:for nd,ndsuf,nde in ALL_RHS
291+
#:for rk,rt,ri in RC_KINDS_TYPES
292+
#:if rk!="xdp"
293+
module subroutine stdlib_linalg_${ri}$_solve_lstsq_${ndsuf}$(a,b,x,real_storage,int_storage,&
294+
#{if rt.startswith('c')}#cmpl_storage,#{endif}#cond,overwrite_a,rank,err)
295+
!> Input matrix a[n,n]
296+
${rt}$, intent(inout), target :: a(:,:)
297+
!> Right hand side vector or array, b[n] or b[n,nrhs]
298+
${rt}$, intent(in) :: b${nd}$
299+
!> Result array/matrix x[n] or x[n,nrhs]
300+
${rt}$, intent(inout), contiguous, target :: x${nd}$
301+
!> [optional] real working storage space
302+
real(${rk}$), optional, intent(inout), target :: real_storage(:)
303+
!> [optional] integer working storage space
304+
integer(ilp), optional, intent(inout), target :: int_storage(:)
305+
#:if rt.startswith('complex')
306+
!> [optional] complex working storage space
307+
${rt}$, optional, intent(inout), target :: cmpl_storage(:)
308+
#:endif
309+
!> [optional] cutoff for rank evaluation: singular values s(i)<=cond*maxval(s) are considered 0.
310+
real(${rk}$), optional, intent(in) :: cond
311+
!> [optional] Can A,b data be overwritten and destroyed?
312+
logical(lk), optional, intent(in) :: overwrite_a
313+
!> [optional] Return rank of A
314+
integer(ilp), optional, intent(out) :: rank
315+
!> [optional] state return flag. On error if not requested, the code will stop
316+
type(linalg_state_type), optional, intent(out) :: err
317+
end subroutine stdlib_linalg_${ri}$_solve_lstsq_${ndsuf}$
318+
#:endif
319+
#:endfor
320+
#:endfor
321+
end interface solve_lstsq
322+
271323
interface det
272324
!! version: experimental
273325
!!

src/stdlib_linalg_least_squares.fypp

Lines changed: 140 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,28 @@ submodule (stdlib_linalg) stdlib_linalg_least_squares
1616

1717
contains
1818

19+
elemental subroutine handle_gelsd_info(info,lda,n,ldb,nrhs,err)
20+
integer(ilp), intent(in) :: info,lda,n,ldb,nrhs
21+
type(linalg_state_type), intent(out) :: err
22+
23+
! Process output
24+
select case (info)
25+
case (0)
26+
! Success
27+
case (:-1)
28+
err = linalg_state_type(this,LINALG_VALUE_ERROR,'invalid problem size a=',[lda,n], &
29+
', b=',[ldb,nrhs])
30+
case (1:)
31+
err = linalg_state_type(this,LINALG_ERROR,'SVD did not converge.')
32+
case default
33+
err = linalg_state_type(this,LINALG_INTERNAL_ERROR,'catastrophic error')
34+
35+
end subroutine handle_gelsd_info
36+
1937
#:for rk,rt,ri in RC_KINDS_TYPES
2038
#:if rk!="xdp"
21-
! Workspace needed by gesv
22-
elemental subroutine ${ri}$gesv_space(m,n,nrhs,lrwork,liwork,lcwork)
39+
! Workspace needed by gelsd
40+
elemental subroutine ${ri}$gelsd_space(m,n,nrhs,lrwork,liwork,lcwork)
2341
integer(ilp), intent(in) :: m,n,nrhs
2442
integer(ilp), intent(out) :: lrwork,liwork,lcwork
2543

@@ -53,7 +71,7 @@ submodule (stdlib_linalg) stdlib_linalg_least_squares
5371
lcwork = ceiling(1.25*lcwork,kind=ilp)
5472
liwork = ceiling(1.25*liwork,kind=ilp)
5573

56-
end subroutine ${ri}$gesv_space
74+
end subroutine ${ri}$gelsd_space
5775

5876
#:endif
5977
#:endfor
@@ -93,33 +111,87 @@ submodule (stdlib_linalg) stdlib_linalg_least_squares
93111
!> [optional] state return flag. On error if not requested, the code will stop
94112
type(linalg_state_type), optional, intent(out) :: err
95113
!> Result array/matrix x[n] or x[n,nrhs]
96-
${rt}$, allocatable, target :: x${nd}$
114+
${rt}$, allocatable, target :: x${nd}$
115+
116+
! Initialize solution with the shape of the rhs
117+
allocate(x,mold=b)
118+
119+
call stdlib_linalg_${ri}$_solve_lstsq_${ndsuf}$(a,b,x,&
120+
cond=cond,overwrite_a=overwrite_a,rank=rank,err=err)
121+
122+
end function stdlib_linalg_${ri}$_lstsq_${ndsuf}$
123+
124+
! Compute the least-squares solution to a real system of linear equations Ax = b
125+
module subroutine stdlib_linalg_${ri}$_solve_lstsq_${ndsuf}$(a,b,x, &
126+
real_storage,int_storage#{if rt.startswith('c')}#,cmpl_storage#{endif}#,cond,overwrite_a,rank,err)
127+
128+
!!### Summary
129+
!! Compute least-squares solution to a real system of linear equations \( Ax = b \)
130+
!!
131+
!!### Description
132+
!!
133+
!! This function computes the least-squares solution of a linear matrix problem.
134+
!!
135+
!! param: a Input matrix of size [m,n].
136+
!! param: b Right-hand-side vector of size [n] or matrix of size [n,nrhs].
137+
!! param: cond [optional] Real input threshold indicating that singular values `s_i <= cond*maxval(s)`
138+
!! do not contribute to the matrix rank.
139+
!! param: overwrite_a [optional] Flag indicating if the input matrix can be overwritten.
140+
!! param: rank [optional] integer flag returning matrix rank.
141+
!! param: err [optional] State return flag.
142+
!! return: x Solution vector of size [n] or solution matrix of size [n,nrhs].
143+
!!
144+
!> Input matrix a[n,n]
145+
${rt}$, intent(inout), target :: a(:,:)
146+
!> Right hand side vector or array, b[n] or b[n,nrhs]
147+
${rt}$, intent(in) :: b${nd}$
148+
!> Result array/matrix x[n] or x[n,nrhs]
149+
${rt}$, intent(inout), contiguous, target :: x${nd}$
150+
!> [optional] real working storage space
151+
real(${rk}$), optional, intent(inout), target :: real_storage(:)
152+
!> [optional] integer working storage space
153+
integer(ilp), optional, intent(inout), target :: int_storage(:)
154+
#:if rt.startswith('c')
155+
!> [optional] complex working storage space
156+
${rt}$, optional, intent(inout), target :: cmpl_storage(:)
157+
#:endif
158+
!> [optional] cutoff for rank evaluation: singular values s(i)<=cond*maxval(s) are considered 0.
159+
real(${rk}$), optional, intent(in) :: cond
160+
!> [optional] Can A,b data be overwritten and destroyed?
161+
logical(lk), optional, intent(in) :: overwrite_a
162+
!> [optional] Return rank of A
163+
integer(ilp), optional, intent(out) :: rank
164+
!> [optional] state return flag. On error if not requested, the code will stop
165+
type(linalg_state_type), optional, intent(out) :: err
97166

98167
!! Local variables
99168
type(linalg_state_type) :: err0
100-
integer(ilp) :: m,n,lda,ldb,nrhs,info,mnmin,mnmax,arank,lrwork,liwork,lcwork
101-
integer(ilp), allocatable :: iwork(:)
169+
integer(ilp) :: m,n,lda,ldb,nrhs,ldx,nrhsx,info,mnmin,mnmax,arank,lrwork,liwork,lcwork
170+
integer(ilp) :: nrs,nis,ncs
171+
integer(ilp), pointer :: iwork(:)
102172
logical(lk) :: copy_a
103173
real(${rk}$) :: acond,rcond
104-
real(${rk}$), allocatable :: singular(:),rwork(:)
105-
${rt}$, pointer :: xmat(:,:),amat(:,:)
106-
${rt}$, allocatable :: cwork(:)
174+
real(${rk}$), allocatable :: singular(:)
175+
real(${rk}$), pointer :: rwork(:)
176+
${rt}$, pointer :: xmat(:,:),amat(:,:),cwork(:)
107177

108178
! Problem sizes
109179
m = size(a,1,kind=ilp)
110180
lda = size(a,1,kind=ilp)
111181
n = size(a,2,kind=ilp)
112182
ldb = size(b,1,kind=ilp)
113183
nrhs = size(b ,kind=ilp)/ldb
184+
ldx = size(x,1,kind=ilp)
185+
nrhsx = size(x ,kind=ilp)/ldx
114186
mnmin = min(m,n)
115187
mnmax = max(m,n)
116188

117-
if (lda<1 .or. n<1 .or. ldb<1 .or. ldb/=m) then
189+
if (lda<1 .or. n<1 .or. ldb<1 .or. ldb/=m .or. ldx/=m) then
118190
err0 = linalg_state_type(this,LINALG_VALUE_ERROR,'invalid sizes: a=',[lda,n], &
119-
'b=',[ldb,nrhs])
120-
allocate(x${nde}$)
191+
'b=',[ldb,nrhs],' x=',[ldx,nrhsx])
121192
call linalg_error_handling(err0,err)
122193
if (present(rank)) rank = 0
194+
return
123195
end if
124196

125197
! Can A be overwritten? By default, do not overwrite
@@ -137,7 +209,7 @@ submodule (stdlib_linalg) stdlib_linalg_least_squares
137209
endif
138210

139211
! Initialize solution with the rhs
140-
allocate(x,source=b)
212+
x = b
141213
xmat(1:n,1:nrhs) => x
142214

143215
! Singular values array (in decreasing order)
@@ -153,44 +225,71 @@ submodule (stdlib_linalg) stdlib_linalg_least_squares
153225
endif
154226
if (rcond<0) rcond = epsilon(0.0_${rk}$)*mnmax
155227

156-
! Allocate working space
157-
call ${ri}$gesv_space(m,n,nrhs,lrwork,liwork,lcwork)
158-
#:if rt.startswith('complex')
159-
allocate(rwork(lrwork),cwork(lcwork),iwork(liwork))
160-
#:else
161-
allocate(rwork(lrwork),iwork(liwork))
162-
#:endif
228+
! Get working space size
229+
call ${ri}$gelsd_space(m,n,nrhs,lrwork,liwork,lcwork)
230+
231+
! Real working space
232+
if (present(real_storage)) then
233+
rwork => real_storage
234+
else
235+
allocate(rwork(lrwork))
236+
endif
237+
nrs = size(rwork,kind=ilp)
163238

164-
! Solve system using singular value decomposition
165-
#:if rt.startswith('complex')
166-
call gelsd(m,n,nrhs,amat,lda,xmat,ldb,singular,rcond,arank,cwork,lrwork,rwork,iwork,info)
167-
#:else
168-
call gelsd(m,n,nrhs,amat,lda,xmat,ldb,singular,rcond,arank,rwork,lrwork,iwork,info)
169-
#:endif
239+
! Integer working space
240+
if (present(int_storage)) then
241+
iwork => int_storage
242+
else
243+
allocate(iwork(liwork))
244+
endif
245+
nis = size(iwork,kind=ilp)
246+
247+
#:if rt.startswith('complex')
248+
! Complex working space
249+
if (present(cmpl_storage)) then
250+
cwork => cmpl_storage
251+
else
252+
allocate(cwork(lcwork))
253+
endif
254+
ncs = size(iwork,kind=ilp)
255+
#:endif
256+
257+
if (nrs<lrwork .or. nis<liwork#{if rt.startswith('c')}# .or. ncs<lcwork#{endif}#) then
258+
! Halt on insufficient space
259+
err0 = linalg_state_type(this,LINALG_VALUE_ERROR,'insufficient working space: ',&
260+
'real=',nrs,' should be >=',lrwork, &
261+
', int=',nis,' should be >=',liwork &
262+
#{if rt.startswith('complex')}#,', cmplx=',ncs,' should be >=',lcwork#{endif}#)
263+
264+
else
170265

266+
! Solve system using singular value decomposition
267+
call gelsd(m,n,nrhs,amat,lda,xmat,ldb,singular,rcond,arank, &
268+
#:if rt.startswith('complex')
269+
cwork,nrs,rwork,iwork,info)
270+
#:else
271+
rwork,nrs,iwork,info)
272+
#:endif
273+
274+
endif
275+
171276
! The condition number of A in the 2-norm = S(1)/S(min(m,n)).
172277
acond = singular(1)/singular(mnmin)
173278

174279
! Process output
175-
select case (info)
176-
case (0)
177-
! Success
178-
case (:-1)
179-
err0 = linalg_state_type(this,LINALG_VALUE_ERROR,'invalid problem size a=',[lda,n], &
180-
', b=',[ldb,nrhs])
181-
case (1:)
182-
err0 = linalg_state_type(this,LINALG_ERROR,'SVD did not converge.')
183-
case default
184-
err0 = linalg_state_type(this,LINALG_INTERNAL_ERROR,'catastrophic error')
185-
end select
186-
187-
if (copy_a) deallocate(amat)
280+
call handle_gelsd_info(info,lda,n,ldb,nrhs,err0)
188281

189282
! Process output and return
190-
call linalg_error_handling(err0,err)
283+
1 if (copy_a) deallocate(amat)
191284
if (present(rank)) rank = arank
285+
if (.not.present(real_storage)) deallocate(rwork)
286+
if (.not.present(int_storage)) deallocate(iwork)
287+
#:if rt.startswith('complex')
288+
if (.not.present(cmpl_storage)) deallocate(cwork)
289+
#:endif
290+
call linalg_error_handling(err0,err)
192291

193-
end function stdlib_linalg_${ri}$_lstsq_${ndsuf}$
292+
end subroutine stdlib_linalg_${ri}$_solve_lstsq_${ndsuf}$
194293

195294
#:endif
196295
#:endfor

0 commit comments

Comments
 (0)