Skip to content

Commit

Permalink
Fix input types, improve readability (#369)
Browse files Browse the repository at this point in the history
* Fix input types, improve readability

* Add missing bit

* Add doc string

* Fix mistake

* Add docstring to docs

* Reformulate

* Bump version

* Update src/matrix/kernelkroneckermat.jl

Co-authored-by: Théo Galy-Fajou <[email protected]>

* Bump version further, rename api section

* Apply format suggestions from code review

Co-authored-by: willtebbutt <[email protected]>

* Improve error handling.

Co-authored-by: st-- <[email protected]>

* Formatter

Co-authored-by: Théo Galy-Fajou <[email protected]>
Co-authored-by: willtebbutt <[email protected]>
Co-authored-by: st-- <[email protected]>
  • Loading branch information
4 people authored Sep 29, 2021
1 parent 3356fa6 commit 6e7ca17
Show file tree
Hide file tree
Showing 6 changed files with 61 additions and 38 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "KernelFunctions"
uuid = "ec8451be-7e33-11e9-00cf-bbf324bd1392"
version = "0.10.19"
version = "0.10.20"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand Down
9 changes: 9 additions & 0 deletions docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -88,3 +88,12 @@ kernelpdmat
nystrom
NystromFact
```

## Conditional Utilities
To keep the dependencies of KernelFunctions lean, some functionality is only available if specific other packages are explicitly loaded (`using`).

### Kronecker.jl
[*https://github.com/MichielStock/Kronecker.jl*](https://github.com/MichielStock/Kronecker.jl)
```@docs
kronecker_kernelmatrix
```
34 changes: 23 additions & 11 deletions src/matrix/kernelkroneckermat.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,31 +27,43 @@ end
"""
@inline iskroncompatible::Kernel) = false # Default return for kernels

function _kernelmatrix_kroneckerjl_helper(::MOInputIsotopicByFeatures, Kfeatures, Koutputs)
function _kernelmatrix_kroneckerjl_helper(
::Type{<:MOInputIsotopicByFeatures}, Kfeatures, Koutputs
)
return Kronecker.kronecker(Kfeatures, Koutputs)
end

function _kernelmatrix_kroneckerjl_helper(::MOInputIsotopicByOutputs, Kfeatures, Koutputs)
function _kernelmatrix_kroneckerjl_helper(
::Type{<:MOInputIsotopicByOutputs}, Kfeatures, Koutputs
)
return Kronecker.kronecker(Koutputs, Kfeatures)
end

"""
kronecker_kernelmatrix(
k::Union{IndependentMOKernel,IntrinsicCoregionMOKernel}, x::MOI, y::MOI
) where {MOI<:IsotopicMOInputsUnion}
Requires Kronecker.jl: Computes the `kernelmatrix` for the `IndependentMOKernel` and the
`IntrinsicCoregionMOKernel`, but returns a lazy kronecker product. This object can be very
efficiently inverted or decomposed. See also [`kernelmatrix`](@ref).
"""
function kronecker_kernelmatrix(
k::Union{IndependentMOKernel,IntrinsicCoregionMOKernel},
x::IsotopicMOInputsUnion,
y::IsotopicMOInputsUnion,
)
@assert x.out_dim == y.out_dim
k::Union{IndependentMOKernel,IntrinsicCoregionMOKernel}, x::MOI, y::MOI
) where {MOI<:IsotopicMOInputsUnion}
x.out_dim == y.out_dim ||
throw(DimensionMismatch("`x` and `y` must have the same `out_dim`"))
Kfeatures = kernelmatrix(k.kernel, x.x, y.x)
Koutputs = _mo_output_covariance(k, x.out_dim)
return _kernelmatrix_kroneckerjl_helper(x, Kfeatures, Koutputs)
return _kernelmatrix_kroneckerjl_helper(MOI, Kfeatures, Koutputs)
end

function kronecker_kernelmatrix(
k::Union{IndependentMOKernel,IntrinsicCoregionMOKernel}, x::IsotopicMOInputsUnion
)
k::Union{IndependentMOKernel,IntrinsicCoregionMOKernel}, x::MOI
) where {MOI<:IsotopicMOInputsUnion}
Kfeatures = kernelmatrix(k.kernel, x.x)
Koutputs = _mo_output_covariance(k, x.out_dim)
return _kernelmatrix_kroneckerjl_helper(x, Kfeatures, Koutputs)
return _kernelmatrix_kroneckerjl_helper(MOI, Kfeatures, Koutputs)
end

function kronecker_kernelmatrix(
Expand Down
21 changes: 10 additions & 11 deletions src/mokernels/independent.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,25 +30,24 @@ end
_mo_output_covariance(k::IndependentMOKernel, out_dim) = Eye{Bool}(out_dim)

function kernelmatrix(
k::IndependentMOKernel, x::IsotopicMOInputsUnion, y::IsotopicMOInputsUnion
)
@assert x.out_dim == y.out_dim
k::IndependentMOKernel, x::MOI, y::MOI
) where {MOI<:IsotopicMOInputsUnion}
x.out_dim == y.out_dim ||
throw(DimensionMismatch("`x` and `y` must have the same `out_dim`"))
Kfeatures = kernelmatrix(k.kernel, x.x, y.x)
Koutputs = _mo_output_covariance(k, x.out_dim)
return _kernelmatrix_kron_helper(x, Kfeatures, Koutputs)
return _kernelmatrix_kron_helper(MOI, Kfeatures, Koutputs)
end

if VERSION >= v"1.6"
function kernelmatrix!(
K::AbstractMatrix,
k::IndependentMOKernel,
x::IsotopicMOInputsUnion,
y::IsotopicMOInputsUnion,
)
@assert x.out_dim == y.out_dim
K::AbstractMatrix, k::IndependentMOKernel, x::MOI, y::MOI
) where {MOI<:IsotopicMOInputsUnion}
x.out_dim == y.out_dim ||
throw(DimensionMismatch("`x` and `y` must have the same `out_dim`"))
Kfeatures = kernelmatrix(k.kernel, x.x, y.x)
Koutputs = _mo_output_covariance(k, x.out_dim)
return _kernelmatrix_kron_helper!(K, x, Kfeatures, Koutputs)
return _kernelmatrix_kron_helper!(K, MOI, Kfeatures, Koutputs)
end
end

Expand Down
21 changes: 10 additions & 11 deletions src/mokernels/intrinsiccoregion.jl
Original file line number Diff line number Diff line change
Expand Up @@ -48,25 +48,24 @@ function _mo_output_covariance(k::IntrinsicCoregionMOKernel, out_dim)
end

function kernelmatrix(
k::IntrinsicCoregionMOKernel, x::IsotopicMOInputsUnion, y::IsotopicMOInputsUnion
)
@assert x.out_dim == y.out_dim
k::IntrinsicCoregionMOKernel, x::MOI, y::MOI
) where {MOI<:IsotopicMOInputsUnion}
x.out_dim == y.out_dim ||
throw(DimensionMismatch("`x` and `y` must have the same `out_dim`"))
Kfeatures = kernelmatrix(k.kernel, x.x, y.x)
Koutputs = _mo_output_covariance(k, x.out_dim)
return _kernelmatrix_kron_helper(x, Kfeatures, Koutputs)
return _kernelmatrix_kron_helper(MOI, Kfeatures, Koutputs)
end

if VERSION >= v"1.6"
function kernelmatrix!(
K::AbstractMatrix,
k::IntrinsicCoregionMOKernel,
x::IsotopicMOInputsUnion,
y::IsotopicMOInputsUnion,
)
@assert x.out_dim == y.out_dim
K::AbstractMatrix, k::IntrinsicCoregionMOKernel, x::MOI, y::MOI
) where {MOI<:IsotopicMOInputsUnion}
x.out_dim == y.out_dim ||
throw(DimensionMismatch("`x` and `y` must have the same `out_dim`"))
Kfeatures = kernelmatrix(k.kernel, x.x, y.x)
Koutputs = _mo_output_covariance(k, x.out_dim)
return _kernelmatrix_kron_helper!(K, x, Kfeatures, Koutputs)
return _kernelmatrix_kron_helper!(K, MOI, Kfeatures, Koutputs)
end
end

Expand Down
12 changes: 8 additions & 4 deletions src/mokernels/mokernel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,20 +5,24 @@ Abstract type for kernels with multiple outpus.
"""
abstract type MOKernel <: Kernel end

function _kernelmatrix_kron_helper(::MOInputIsotopicByFeatures, Kfeatures, Koutputs)
function _kernelmatrix_kron_helper(::Type{<:MOInputIsotopicByFeatures}, Kfeatures, Koutputs)
return kron(Kfeatures, Koutputs)
end

function _kernelmatrix_kron_helper(::MOInputIsotopicByOutputs, Kfeatures, Koutputs)
function _kernelmatrix_kron_helper(::Type{<:MOInputIsotopicByOutputs}, Kfeatures, Koutputs)
return kron(Koutputs, Kfeatures)
end

if VERSION >= v"1.6"
function _kernelmatrix_kron_helper!(K, ::MOInputIsotopicByFeatures, Kfeatures, Koutputs)
function _kernelmatrix_kron_helper!(
K, ::Type{<:MOInputIsotopicByFeatures}, Kfeatures, Koutputs
)
return kron!(K, Kfeatures, Koutputs)
end

function _kernelmatrix_kron_helper!(K, ::MOInputIsotopicByOutputs, Kfeatures, Koutputs)
function _kernelmatrix_kron_helper!(
K, ::Type{<:MOInputIsotopicByOutputs}, Kfeatures, Koutputs
)
return kron!(K, Koutputs, Kfeatures)
end
end

2 comments on commit 6e7ca17

@Crown421
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator register()

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/45780

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.10.20 -m "<description of version>" 6e7ca17987d3e0a8d7f1724b8639c438befcb2d0
git push origin v0.10.20

Please sign in to comment.