<section id="hello-kan"> | ||
<span id="id1"></span><h1>Hello, KAN!<a class="headerlink" href="#hello-kan" title="Link to this heading"></a></h1> | ||
<section id="kolmogorov-arnold-representation-theorem"> | ||
<h2>Kolmogorov-Arnold representation theorem<a class="headerlink" href="#kolmogorov-arnold-representation-theorem" title="Link to this heading"></a></h2> | ||
The Kolmogorov-Arnold representation theorem states that if \(f\) is a multivariate continuous function on a bounded domain, then it can be written as a finite composition of continuous functions of a single variable and the binary operation of addition. More specifically, for a smooth \(f : [0,1]^n \to \mathbb{R}\),

\[f(x) = f(x_1,...,x_n)=\sum_{q=1}^{2n+1}\Phi_q\left(\sum_{p=1}^n \phi_{q,p}(x_p)\right)\]

where \(\phi_{q,p}:[0,1]\to\mathbb{R}\) and \(\Phi_q:\mathbb{R}\to\mathbb{R}\). In a sense, they showed that the only truly multivariate function is addition, since every other function can be written using univariate functions and sums. However, this two-layer, width-\((2n+1)\) Kolmogorov-Arnold representation may not be smooth due to its limited expressive power. We augment its expressive power by generalizing it to arbitrary depths and widths.
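To make the nested structure concrete, here is a small numerical sketch of the form \(\sum_{q=1}^{2n+1}\Phi_q(\sum_{p=1}^n \phi_{q,p}(x_p))\) for \(n=2\). The univariate functions below are arbitrary placeholders chosen for illustration, not an actual Kolmogorov-Arnold decomposition of any particular \(f\).

```python
import numpy as np

# Structure of the Kolmogorov-Arnold representation for n = 2 inputs,
# i.e. 2n+1 = 5 outer terms. The inner functions phi[q][p] and the outer
# functions Phi[q] are placeholders, not a real decomposition.
n = 2
phi = [[np.sin, np.cos],
       [np.tanh, np.sin],
       [np.cos, np.tanh],
       [np.sin, np.tanh],
       [np.cos, np.sin]]          # shape (2n+1, n): one phi_{q,p} per (q, p)
Phi = [np.tanh] * (2 * n + 1)     # one outer function Phi_q per q

def ka_form(x):
    """Evaluate sum_q Phi_q( sum_p phi_{q,p}(x_p) ) at a point x of length n."""
    inner = [sum(phi[q][p](x[p]) for p in range(n)) for q in range(2 * n + 1)]
    return sum(Phi[q](inner[q]) for q in range(2 * n + 1))

print(ka_form(np.array([0.3, 0.7])))
```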
<section id="kolmogorov-arnold-network-kan"> | ||
<h2>Kolmogorov-Arnold Network (KAN)<a class="headerlink" href="#kolmogorov-arnold-network-kan" title="Link to this heading"></a></h2> | ||
The Kolmogorov-Arnold representation can be written in matrix form

\[f(x)={\bf \Phi}_{\rm out}\circ{\bf \Phi}_{\rm in}\circ {\bf x}\]

where

\[{\bf \Phi}_{\rm in}= \begin{pmatrix} \phi_{1,1}(\cdot) & \cdots & \phi_{1,n}(\cdot) \\ \vdots & & \vdots \\ \phi_{2n+1,1}(\cdot) & \cdots & \phi_{2n+1,n}(\cdot) \end{pmatrix},\quad {\bf \Phi}_{\rm out}=\begin{pmatrix} \Phi_1(\cdot) & \cdots & \Phi_{2n+1}(\cdot)\end{pmatrix}\]

We notice that both \({\bf \Phi}_{\rm in}\) and \({\bf \Phi}_{\rm out}\) are special cases of the following function matrix \({\bf \Phi}\) (with \(n_{\rm in}\) inputs and \(n_{\rm out}\) outputs), which we call a Kolmogorov-Arnold layer:

\[{\bf \Phi}= \begin{pmatrix} \phi_{1,1}(\cdot) & \cdots & \phi_{1,n_{\rm in}}(\cdot) \\ \vdots & & \vdots \\ \phi_{n_{\rm out},1}(\cdot) & \cdots & \phi_{n_{\rm out},n_{\rm in}}(\cdot) \end{pmatrix}\]

\({\bf \Phi}_{\rm in}\) corresponds to \(n_{\rm in}=n, n_{\rm out}=2n+1\), and \({\bf \Phi}_{\rm out}\) corresponds to \(n_{\rm in}=2n+1, n_{\rm out}=1\).

After defining the layer, we can construct a Kolmogorov-Arnold network simply by stacking layers! Let's say we have \(L\) layers, with the \(l^{\rm th}\) layer \({\bf \Phi}_l\) having shape \((n_{l+1}, n_{l})\). Then the whole network is

\[{\rm KAN}({\bf x})={\bf \Phi}_{L-1}\circ\cdots \circ{\bf \Phi}_1\circ{\bf \Phi}_0\circ {\bf x}\]

In contrast, a Multi-Layer Perceptron interleaves linear layers \({\bf W}_l\) and nonlinearities \(\sigma\):

\[{\rm MLP}({\bf x})={\bf W}_{L-1}\circ\sigma\circ\cdots\circ {\bf W}_1\circ\sigma\circ {\bf W}_0\circ {\bf x}\]

A KAN is easy to visualize: (1) a KAN is simply a stack of KAN layers, and (2) each KAN layer can be visualized as a fully-connected layer with a 1D function placed on each edge. Let's see an example below.
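As a conceptual sketch of the layer definition above (not the actual pykan implementation, which parametrizes each \(\phi\) as a learnable spline), a Kolmogorov-Arnold layer is just an \((n_{\rm out}, n_{\rm in})\) matrix of univariate functions applied edge-wise and summed over inputs, and a KAN is their composition:

```python
import numpy as np

# Conceptual sketch only: fixed univariate functions stand in for the
# learnable splines of a real KAN layer.
def ka_layer(phi_matrix):
    """phi_matrix: n_out rows, each a list of n_in univariate functions."""
    def apply(x):                               # x has length n_in
        return np.array([sum(phi(x[p]) for p, phi in enumerate(row))
                         for row in phi_matrix])
    return apply

layer0 = ka_layer([[np.sin, np.tanh]] * 5)      # n_in = 2 -> n_out = 5
layer1 = ka_layer([[np.tanh] * 5])              # n_in = 5 -> n_out = 1
kan = lambda x: layer1(layer0(x))               # KAN(x) = Phi_1 ∘ Phi_0 (x)
print(kan(np.array([0.3, 0.7])))
```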
<section id="get-started-with-kans"> | ||
<h2>Get started with KANs<a class="headerlink" href="#get-started-with-kans" title="Link to this heading"></a></h2> | ||
<p>Initialize KAN</p> | ||
<div class="highlight-ipython3 notranslate"><div class="highlight"><pre><span></span>from kan import * | ||
# create a KAN: 2D inputs, 1D output, and 5 hidden neurons. cubic spline (k=3), 5 grid intervals (grid=5). | ||
model = KAN(width=[2,5,1], grid=5, k=3, seed=0) | ||
</pre></div> | ||
</div> | ||
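The `width` list gives the number of nodes in each layer, so `[2,5,1]` matches the two-layer example above. As an illustrative variant (a sketch not used in the rest of this walkthrough), a deeper KAN is just a longer list:

```python
from kan import KAN

# A deeper KAN: 2 inputs, two hidden layers of 5 nodes each, 1 output.
deep_model = KAN(width=[2,5,5,1], grid=5, k=3, seed=0)
```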
Create dataset:

```python
# create dataset f(x,y) = exp(sin(pi*x)+y^2)
f = lambda x: torch.exp(torch.sin(torch.pi*x[:,[0]]) + x[:,[1]]**2)
dataset = create_dataset(f, n_var=2)
dataset['train_input'].shape, dataset['train_label'].shape
```

```
(torch.Size([1000, 2]), torch.Size([1000, 1]))
```
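If you prefer to control the sampling yourself, a minimal hand-rolled stand-in for `create_dataset` might look like the sketch below. It assumes the dictionary layout suggested by the output above plus matching `test_input`/`test_label` splits (implied by the test loss reported during training); the real helper may differ in defaults such as the sampling range.

```python
import torch

def make_dataset(f, n_var=2, n_train=1000, n_test=1000, lo=-1.0, hi=1.0):
    # Uniformly sample inputs in [lo, hi]^n_var and label them with f.
    train_x = torch.rand(n_train, n_var) * (hi - lo) + lo
    test_x  = torch.rand(n_test,  n_var) * (hi - lo) + lo
    return {'train_input': train_x, 'train_label': f(train_x),
            'test_input':  test_x,  'test_label':  f(test_x)}

f = lambda x: torch.exp(torch.sin(torch.pi*x[:,[0]]) + x[:,[1]]**2)
dataset_manual = make_dataset(f)
print(dataset_manual['train_input'].shape, dataset_manual['train_label'].shape)
```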
Plot KAN at initialization:

```python
# plot KAN at initialization
model(dataset['train_input']);
model.plot(beta=100)
```

[Figure: intro_15_0.png — KAN at initialization]
Train KAN with sparsity regularization:

```python
# train the model
model.train(dataset, opt="LBFGS", steps=20, lamb=0.01, lamb_entropy=10.);
```

```
train loss: 1.57e-01 | test loss: 1.31e-01 | reg: 2.05e+01 : 100%|██| 20/20 [00:18<00:00, 1.06it/s]
```
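As a quick sanity check of generalization (a sketch, not part of the original tutorial), you can compute a test RMSE by hand, assuming the dataset carries `test_input`/`test_label` as the progress bar's test loss suggests; the exact number may differ from the loss reported by `train()`:

```python
import torch

# Manual test RMSE on the held-out split (assumed keys: 'test_input', 'test_label').
with torch.no_grad():
    pred = model(dataset['test_input'])
    rmse = torch.sqrt(torch.mean((pred - dataset['test_label'])**2))
print(f"test RMSE: {rmse.item():.3e}")
```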
Plot trained KAN:

```python
model.plot()
```

[Figure: intro_19_0.png — trained KAN]
Prune KAN and replot (keep the original shape):

```python
model.prune()
model.plot(mask=True)
```

[Figure: intro_21_0.png — pruned KAN, original shape]
Prune KAN and replot (get a smaller shape):

```python
model = model.prune()
model(dataset['train_input'])
model.plot()
```

[Figure: intro_23_0.png — pruned KAN, smaller shape]
Continue training and replot:

```python
model.train(dataset, opt="LBFGS", steps=50);
```

```
train loss: 4.74e-03 | test loss: 4.80e-03 | reg: 2.98e+00 : 100%|██| 50/50 [00:07<00:00, 7.03it/s]
```

```python
model.plot()
```

[Figure: intro_26_0.png — KAN after continued training]
Automatically or manually set activation functions to be symbolic:

```python
mode = "auto"  # "manual"

if mode == "manual":
    # manual mode
    model.fix_symbolic(0,0,0,'sin');
    model.fix_symbolic(0,1,0,'x^2');
    model.fix_symbolic(1,0,0,'exp');
elif mode == "auto":
    # automatic mode
    lib = ['x','x^2','x^3','x^4','exp','log','sqrt','tanh','sin','abs']
    model.auto_symbolic(lib=lib)
```

```
fixing (0,0,0) with sin, r2=0.999987252534279
fixing (0,1,0) with x^2, r2=0.9999996536741071
fixing (1,0,0) with exp, r2=0.9999988529417926
```
Continue training to almost machine precision:

```python
model.train(dataset, opt="LBFGS", steps=50);
```

```
train loss: 2.02e-10 | test loss: 1.13e-10 | reg: 2.98e+00 : 100%|██| 50/50 [00:02<00:00, 22.59it/s]
```
Obtain the symbolic formula:

```python
model.symbolic_formula()[0][0]
```

\[1.0 e^{1.0 x_{2}^{2} + 1.0 \sin{\left(3.14 x_{1} \right)}}\]
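As a final numerical sanity check (a sketch added here, not part of the original tutorial), you can compare the trained model against the ground truth \(f(x,y)=e^{\sin(\pi x)+y^2}\) on fresh points:

```python
import torch

# Fresh points in [-1, 1]^2; the sampling range of create_dataset is an
# assumption here.
x = torch.rand(100, 2) * 2 - 1
truth = torch.exp(torch.sin(torch.pi*x[:,[0]]) + x[:,[1]]**2)
with torch.no_grad():
    pred = model(x)
print("max abs error:", torch.max(torch.abs(pred - truth)).item())
```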