Skip to content

Commit

Permalink
Add additional assert for input arrays
Browse files Browse the repository at this point in the history
  • Loading branch information
amontoison committed Apr 21, 2024
1 parent 3dabc56 commit 363d177
Showing 1 changed file with 77 additions and 27 deletions.
104 changes: 77 additions & 27 deletions src/wrapper/qr_mumps_api.jl
Original file line number Diff line number Diff line change
Expand Up @@ -186,21 +186,31 @@ for (fname, lname, elty, subty) in (("sqrm_solve_c", libsqrm, Float32 , Float3
@eval begin
function qrm_solve!(spfct :: qrm_spfct{$elty}, b :: Vector{$elty}, x :: Vector{$elty}; transp :: Char='n')
nrhs = 1
if transp == 'n'
@assert length(x) == spfct.fct.n
else
@assert length(x) == spfct.fct.m
end
err = ccall(($fname, $lname), Cint, (Ref{c_spfct{$elty}}, UInt8, Ptr{$elty}, Ptr{$elty}, Cint), spfct, transp, b, x, nrhs)
(err 0) && throw(ErrorException(error_handling(err)))
return nothing
end

function qrm_solve!(spfct :: qrm_spfct{$elty}, b :: Matrix{$elty}, x :: Matrix{$elty}; transp :: Char='n')
nrhs = size(b, 2)
if transp == 'n'
@assert size(x) == (spfct.fct.n, nrhs)
else
@assert size(x) == (spfct.fct.m, nrhs)
end
err = ccall(($fname, $lname), Cint, (Ref{c_spfct{$elty}}, UInt8, Ptr{$elty}, Ptr{$elty}, Cint), spfct, transp, b, x, nrhs)
(err 0) && throw(ErrorException(error_handling(err)))
return nothing
end

function qrm_solve(spfct :: qrm_spfct{$elty}, b :: Vector{$elty}; transp :: Char='n')
nrhs = 1
if(transp=='n')
if transp == 'n'
x = zeros($elty, spfct.fct.n)
else
x = zeros($elty, spfct.fct.m)
Expand All @@ -212,7 +222,7 @@ for (fname, lname, elty, subty) in (("sqrm_solve_c", libsqrm, Float32 , Float3

function qrm_solve(spfct :: qrm_spfct{$elty}, b :: Matrix{$elty}; transp :: Char='n')
nrhs = size(b, 2)
if(transp=='n')
if transp == 'n'
x = zeros($elty, spfct.fct.n, nrhs)
else
x = zeros($elty, spfct.fct.m, nrhs)
Expand Down Expand Up @@ -374,13 +384,23 @@ for (fname, lname, elty, subty) in (("sqrm_spbackslash_c", libsqrm, Float32 ,
@eval begin
function qrm_spbackslash!(spmat :: qrm_spmat{$elty}, b :: Vector{$elty}, x :: Vector{$elty}; transp :: Char='n')
nrhs = 1
if transp == 'n'
@assert length(x) == spmat.mat.n
else
@assert length(x) == spmat.mat.m
end
err = ccall(($fname, $lname), Cint, (Ref{c_spmat{$elty}}, Ptr{$elty}, Ptr{$elty}, Cint, UInt8), spmat, b, x, nrhs, transp)
(err 0) && throw(ErrorException(error_handling(err)))
return nothing
end

function qrm_spbackslash!(spmat :: qrm_spmat{$elty}, b :: Matrix{$elty}, x :: Matrix{$elty}; transp :: Char='n')
nrhs = size(b, 2)
if transp == 'n'
@assert size(x) == (spmat.mat.n, nrhs)
else
@assert size(x) == (spmat.mat.m, nrhs)
end
err = ccall(($fname, $lname), Cint, (Ref{c_spmat{$elty}}, Ptr{$elty}, Ptr{$elty}, Cint, UInt8), spmat, b, x, nrhs, transp)
(err 0) && throw(ErrorException(error_handling(err)))
return nothing
Expand All @@ -389,7 +409,7 @@ for (fname, lname, elty, subty) in (("sqrm_spbackslash_c", libsqrm, Float32 ,

function qrm_spbackslash(spmat :: qrm_spmat{$elty}, b :: Vector{$elty}; transp :: Char='n')
nrhs = 1
if(transp=='n')
if transp == 'n'
x = zeros($elty, spmat.mat.n)
else
x = zeros($elty, spmat.mat.m)
Expand All @@ -402,7 +422,7 @@ for (fname, lname, elty, subty) in (("sqrm_spbackslash_c", libsqrm, Float32 ,

function qrm_spbackslash(spmat :: qrm_spmat{$elty}, b :: Matrix{$elty}; transp :: Char='n')
nrhs = size(b, 2)
if(transp=='n')
if transp == 'n'
x = zeros($elty, spmat.mat.n, nrhs)
else
x = zeros($elty, spmat.mat.m, nrhs)
Expand Down Expand Up @@ -446,13 +466,23 @@ for (fname, lname, elty, subty) in (("sqrm_spfct_backslash_c", libsqrm, Float32
@eval begin
function qrm_spbackslash!(spfct :: qrm_spfct{$elty}, b :: Vector{$elty}, x :: Vector{$elty}; transp :: Char='n')
nrhs = 1
if transp == 'n'
@assert length(x) == spfct.fct.n
else
@assert length(x) == spfct.fct.m
end
err = ccall(($fname, $lname), Cint, (Ref{c_spfct{$elty}}, Ptr{$elty}, Ptr{$elty}, Cint, UInt8), spfct, b, x, nrhs, transp)
(err 0) && throw(ErrorException(error_handling(err)))
return nothing
end

function qrm_spbackslash!(spfct :: qrm_spfct{$elty}, b :: Matrix{$elty}, x :: Matrix{$elty}; transp :: Char='n')
nrhs = size(b, 2)
if transp == 'n'
@assert size(x) == (spfct.fct.n, nrhs)
else
@assert size(x) == (spfct.fct.m, nrhs)
end
err = ccall(($fname, $lname), Cint, (Ref{c_spfct{$elty}}, Ptr{$elty}, Ptr{$elty}, Cint, UInt8), spfct, b, x, nrhs, transp)
(err 0) && throw(ErrorException(error_handling(err)))
return nothing
Expand All @@ -461,10 +491,10 @@ for (fname, lname, elty, subty) in (("sqrm_spfct_backslash_c", libsqrm, Float32

function qrm_spbackslash(spfct :: qrm_spfct{$elty}, b :: Vector{$elty}; transp :: Char='n')
nrhs = 1
if(transp=='n')
x = zeros($elty, spfct.fct.n)
if transp == 'n'
x = zeros($elty, spfct.fct.n)
else
x = zeros($elty, spfct.fct.m)
x = zeros($elty, spfct.fct.m)
end
bcopy = (spfct.fct.m spfct.fct.n) ? copy(b) : b
err = ccall(($fname, $lname), Cint, (Ref{c_spfct{$elty}}, Ptr{$elty}, Ptr{$elty}, Cint, UInt8), spfct, bcopy, x, nrhs, transp)
Expand All @@ -474,10 +504,10 @@ for (fname, lname, elty, subty) in (("sqrm_spfct_backslash_c", libsqrm, Float32

function qrm_spbackslash(spfct :: qrm_spfct{$elty}, b :: Matrix{$elty}; transp :: Char='n')
nrhs = size(b, 2)
if(transp=='n')
x = zeros($elty, spfct.fct.n, nrhs)
if transp == 'n'
x = zeros($elty, spfct.fct.n, nrhs)
else
x = zeros($elty, spfct.fct.m, nrhs)
x = zeros($elty, spfct.fct.m, nrhs)
end
bcopy = (spfct.fct.m spfct.fct.n) ? copy(b) : b
err = ccall(($fname, $lname), Cint, (Ref{c_spfct{$elty}}, Ptr{$elty}, Ptr{$elty}, Cint, UInt8), spfct, bcopy, x, nrhs, transp)
Expand Down Expand Up @@ -555,24 +585,34 @@ for (fname, lname, elty, subty) in (("sqrm_least_squares_c", libsqrm, Float32
@eval begin
function qrm_least_squares!(spmat :: qrm_spmat{$elty}, b :: Vector{$elty}, x :: Vector{$elty}; transp :: Char='n')
nrhs = 1
if transp == 'n'
@assert length(x) == spmat.mat.n
else
@assert length(x) == spmat.mat.m
end
err = ccall(($fname, $lname), Cint, (Ref{c_spmat{$elty}}, Ptr{$elty}, Ptr{$elty}, Cint, UInt8), spmat, b, x, nrhs, transp)
(err 0) && throw(ErrorException(error_handling(err)))
return nothing
end

function qrm_least_squares!(spmat :: qrm_spmat{$elty}, b :: Matrix{$elty}, x :: Matrix{$elty}; transp :: Char='n')
nrhs = size(b, 2)
if transp == 'n'
@assert size(x) == (spmat.mat.n, nrhs)
else
@assert size(x) == (spmat.mat.m, nrhs)
end
err = ccall(($fname, $lname), Cint, (Ref{c_spmat{$elty}}, Ptr{$elty}, Ptr{$elty}, Cint, UInt8), spmat, b, x, nrhs, transp)
(err 0) && throw(ErrorException(error_handling(err)))
return nothing
end

function qrm_least_squares(spmat :: qrm_spmat{$elty}, b :: Vector{$elty}; transp :: Char='n')
nrhs = 1
if(transp=='n')
x = zeros($elty, spmat.mat.n)
if transp == 'n'
x = zeros($elty, spmat.mat.n)
else
x = zeros($elty, spmat.mat.m)
x = zeros($elty, spmat.mat.m)
end
bcopy = copy(b)
err = ccall(($fname, $lname), Cint, (Ref{c_spmat{$elty}}, Ptr{$elty}, Ptr{$elty}, Cint, UInt8), spmat, bcopy, x, nrhs, transp)
Expand All @@ -582,10 +622,10 @@ for (fname, lname, elty, subty) in (("sqrm_least_squares_c", libsqrm, Float32

function qrm_least_squares(spmat :: qrm_spmat{$elty}, b :: Matrix{$elty}; transp :: Char='n')
nrhs = size(b, 2)
if(transp=='n')
x = zeros($elty, spmat.mat.n, nrhs)
if transp == 'n'
x = zeros($elty, spmat.mat.n, nrhs)
else
x = zeros($elty, spmat.mat.n, nrhs)
x = zeros($elty, spmat.mat.n, nrhs)
end
bcopy = copy(b)
err = ccall(($fname, $lname), Cint, (Ref{c_spmat{$elty}}, Ptr{$elty}, Ptr{$elty}, Cint, UInt8), spmat, bcopy, x, nrhs, transp)
Expand Down Expand Up @@ -615,24 +655,34 @@ for (fname, lname, elty, subty) in (("sqrm_min_norm_c", libsqrm, Float32 , Flo
@eval begin
function qrm_min_norm!(spmat :: qrm_spmat{$elty}, b :: Vector{$elty}, x :: Vector{$elty}; transp :: Char='n')
nrhs = 1
if transp == 'n'
@assert length(x) == spmat.mat.n
else
@assert length(x) == spmat.mat.m
end
err = ccall(($fname, $lname), Cint, (Ref{c_spmat{$elty}}, Ptr{$elty}, Ptr{$elty}, Cint, UInt8), spmat, b, x, nrhs, transp)
(err 0) && throw(ErrorException(error_handling(err)))
return nothing
end

function qrm_min_norm!(spmat :: qrm_spmat{$elty}, b :: Matrix{$elty}, x :: Matrix{$elty}; transp :: Char='n')
nrhs = size(b, 2)
if transp == 'n'
@assert size(x) == (spmat.mat.n, nrhs)
else
@assert size(x) == (spmat.mat.m, nrhs)
end
err = ccall(($fname, $lname), Cint, (Ref{c_spmat{$elty}}, Ptr{$elty}, Ptr{$elty}, Cint, UInt8), spmat, b, x, nrhs, transp)
(err 0) && throw(ErrorException(error_handling(err)))
return nothing
end

function qrm_min_norm(spmat :: qrm_spmat{$elty}, b :: Vector{$elty}; transp :: Char='n')
nrhs = 1
if(transp=='n')
x = zeros($elty, spmat.mat.n)
if transp == 'n'
x = zeros($elty, spmat.mat.n)
else
x = zeros($elty, spmat.mat.m)
x = zeros($elty, spmat.mat.m)
end
err = ccall(($fname, $lname), Cint, (Ref{c_spmat{$elty}}, Ptr{$elty}, Ptr{$elty}, Cint, UInt8), spmat, b, x, nrhs, transp)
(err 0) && throw(ErrorException(error_handling(err)))
Expand All @@ -641,10 +691,10 @@ for (fname, lname, elty, subty) in (("sqrm_min_norm_c", libsqrm, Float32 , Flo

function qrm_min_norm(spmat :: qrm_spmat{$elty}, b :: Matrix{$elty}; transp :: Char='n')
nrhs = size(b, 2)
if(transp=='n')
x = zeros($elty, spmat.mat.n, nrhs)
if transp == 'n'
x = zeros($elty, spmat.mat.n, nrhs)
else
x = zeros($elty, spmat.mat.n, nrhs)
x = zeros($elty, spmat.mat.n, nrhs)
end
err = ccall(($fname, $lname), Cint, (Ref{c_spmat{$elty}}, Ptr{$elty}, Ptr{$elty}, Cint, UInt8), spmat, b, x, nrhs, transp)
(err 0) && throw(ErrorException(error_handling(err)))
Expand Down Expand Up @@ -697,11 +747,11 @@ for (fname, lname, elty, subty) in (("sqrm_residual_norm_c", libsqrm, Float32
@inline qrm_residual_norm!(spmat :: Transpose{$elty,qrm_spmat{$elty}}, b :: Matrix{$elty}, x :: Matrix{$elty}, nrm :: Vector{$subty}) = qrm_residual_norm!(spmat.parent, b, x, transp='t')
@inline qrm_residual_norm!(spmat :: Adjoint{$elty,qrm_spmat{$elty}} , b :: Matrix{$elty}, x :: Matrix{$elty}, nrm :: Vector{$subty}) = qrm_residual_norm!(spmat.parent, b, x, transp='c')

@inline qrm_residual_norm(spmat :: Transpose{$elty,qrm_spmat{$elty}}, b :: Vector{$elty}, x :: Vector{$elty}) = qrm_residual_norm(spmat.parent, b, x, transp='t')
@inline qrm_residual_norm(spmat :: Transpose{$elty,qrm_spmat{$elty}}, b :: Matrix{$elty}, x :: Matrix{$elty}) = qrm_residual_norm(spmat.parent, b, x, transp='t')
@inline qrm_residual_norm(spmat :: Transpose{$elty,qrm_spmat{$elty}}, b :: Vector{$elty}, x :: Vector{$elty}) = qrm_residual_norm(spmat.parent, b, x, transp='t')
@inline qrm_residual_norm(spmat :: Transpose{$elty,qrm_spmat{$elty}}, b :: Matrix{$elty}, x :: Matrix{$elty}) = qrm_residual_norm(spmat.parent, b, x, transp='t')

@inline qrm_residual_norm(spmat :: Adjoint{$elty,qrm_spmat{$elty}} , b :: Vector{$elty}, x :: Vector{$elty}) = qrm_residual_norm(spmat.parent, b, x, transp='c')
@inline qrm_residual_norm(spmat :: Adjoint{$elty,qrm_spmat{$elty}} , b :: Matrix{$elty}, x :: Matrix{$elty}) = qrm_residual_norm(spmat.parent, b, x, transp='c')
@inline qrm_residual_norm(spmat :: Adjoint{$elty,qrm_spmat{$elty}}, b :: Vector{$elty}, x :: Vector{$elty}) = qrm_residual_norm(spmat.parent, b, x, transp='c')
@inline qrm_residual_norm(spmat :: Adjoint{$elty,qrm_spmat{$elty}}, b :: Matrix{$elty}, x :: Matrix{$elty}) = qrm_residual_norm(spmat.parent, b, x, transp='c')


end
Expand Down Expand Up @@ -747,7 +797,7 @@ for (fname, lname, elty, subty) in (("sqrm_residual_orth_c", libsqrm, Float32
end

function qrm_set(str :: String, val :: Number)
if (str GICNTL) || (str PICNTL)
if (str GICNTL) || (str PICNTL)
err = ccall(("qrm_glob_set_i4_c", libqrm_common), Cint, (Cstring, Cint), str, val)
elseif str RCNTL
err = ccall(("qrm_glob_set_r4_c", libqrm_common), Cint, (Cstring, Cfloat), str, val)
Expand Down

0 comments on commit 363d177

Please sign in to comment.