deploy: 8b28a28
KindXiaoming committed Aug 11, 2024
1 parent ab8f225 commit 84a2f36
Showing 14 changed files with 2,907 additions and 43 deletions.
253 changes: 253 additions & 0 deletions .ipynb_checkpoints/intro-checkpoint.html
@@ -0,0 +1,253 @@
<!DOCTYPE html>
<html class="writer-html5" lang="en" data-content_root="../">
<head>
<meta charset="utf-8" /><meta name="viewport" content="width=device-width, initial-scale=1" />

<meta name="viewport" content="width=device-width, initial-scale=1.0" />
<title>Hello, KAN! &mdash; Kolmogorov Arnold Network documentation</title>
<link rel="stylesheet" type="text/css" href="../_static/pygments.css?v=80d5e7a1" />
<link rel="stylesheet" type="text/css" href="../_static/css/theme.css?v=19f00094" />


<!--[if lt IE 9]>
<script src="../_static/js/html5shiv.min.js"></script>
<![endif]-->

<script src="../_static/jquery.js?v=5d32c60e"></script>
<script src="../_static/_sphinx_javascript_frameworks_compat.js?v=2cd50e6c"></script>
<script src="../_static/documentation_options.js?v=5929fcd5"></script>
<script src="../_static/doctools.js?v=9a2dae69"></script>
<script src="../_static/sphinx_highlight.js?v=dc90522c"></script>
<script async="async" src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
<script src="../_static/js/theme.js"></script>
<link rel="index" title="Index" href="../genindex.html" />
<link rel="search" title="Search" href="../search.html" />
</head>

<body class="wy-body-for-nav">
<div class="wy-grid-for-nav">
<nav data-toggle="wy-nav-shift" class="wy-nav-side">
<div class="wy-side-scroll">
<div class="wy-side-nav-search" >



<a href="../index.html" class="icon icon-home">
Kolmogorov Arnold Network
</a>
<div role="search">
<form id="rtd-search-form" class="wy-form" action="../search.html" method="get">
<input type="text" name="q" placeholder="Search docs" aria-label="Search docs" />
<input type="hidden" name="check_keywords" value="yes" />
<input type="hidden" name="area" value="default" />
</form>
</div>
</div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
<p class="caption" role="heading"><span class="caption-text">Contents:</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../intro.html">Hello, KAN!</a></li>
<li class="toctree-l1"><a class="reference internal" href="../modules.html">API</a></li>
<li class="toctree-l1"><a class="reference internal" href="../demos.html">API Demos</a></li>
<li class="toctree-l1"><a class="reference internal" href="../examples.html">Examples</a></li>
</ul>

</div>
</div>
</nav>

<section data-toggle="wy-nav-shift" class="wy-nav-content-wrap"><nav class="wy-nav-top" aria-label="Mobile navigation menu" >
<i data-toggle="wy-nav-top" class="fa fa-bars"></i>
<a href="../index.html">Kolmogorov Arnold Network</a>
</nav>

<div class="wy-nav-content">
<div class="rst-content">
<div role="navigation" aria-label="Page navigation">
<ul class="wy-breadcrumbs">
<li><a href="../index.html" class="icon icon-home" aria-label="Home"></a></li>
<li class="breadcrumb-item active">Hello, KAN!</li>
<li class="wy-breadcrumbs-aside">
<a href="../_sources/.ipynb_checkpoints/intro-checkpoint.rst.txt" rel="nofollow"> View page source</a>
</li>
</ul>
<hr/>
</div>
<div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
<div itemprop="articleBody">

<section id="hello-kan">
<span id="id1"></span><h1>Hello, KAN!<a class="headerlink" href="#hello-kan" title="Link to this heading"></a></h1>
<section id="kolmogorov-arnold-representation-theorem">
<h2>Kolmogorov-Arnold representation theorem<a class="headerlink" href="#kolmogorov-arnold-representation-theorem" title="Link to this heading"></a></h2>
<p>The Kolmogorov-Arnold representation theorem states that if <span class="math notranslate nohighlight">\(f\)</span> is a
multivariate continuous function on a bounded domain, then it can be
written as a finite composition of continuous functions of a single
variable and the binary operation of addition. More specifically, for a
continuous <span class="math notranslate nohighlight">\(f : [0,1]^n \to \mathbb{R}\)</span>,</p>
<div class="math notranslate nohighlight">
\[f(x) = f(x_1,...,x_n)=\sum_{q=1}^{2n+1}\Phi_q(\sum_{p=1}^n \phi_{q,p}(x_p))\]</div>
<p>where <span class="math notranslate nohighlight">\(\phi_{q,p}:[0,1]\to\mathbb{R}\)</span> and
<span class="math notranslate nohighlight">\(\Phi_q:\mathbb{R}\to\mathbb{R}\)</span>. In a sense, Kolmogorov and Arnold showed that the
only true multivariate function is addition, since every other function
can be written using univariate functions and sums. However, the univariate functions in this 2-layer,
width-<span class="math notranslate nohighlight">\((2n+1)\)</span> Kolmogorov-Arnold representation may not be smooth,
because such a shallow, narrow form has limited expressive power. We augment its expressive power by
generalizing it to arbitrary depths and widths.</p>
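<p>As a small illustration of the "univariate functions and sums" structure (not the construction used in the theorem's proof), the product <span class="math notranslate nohighlight">\(x_1 x_2\)</span> can be written as <span class="math notranslate nohighlight">\(\Phi_1(x_1+x_2)+\Phi_2(x_1-x_2)\)</span> with <span class="math notranslate nohighlight">\(\Phi_1(u)=u^2/4\)</span> and <span class="math notranslate nohighlight">\(\Phi_2(u)=-u^2/4\)</span>. The sketch below checks this numerically; the names phi, Phi and ka_form are hypothetical and chosen only for this example.</p>
<div class="highlight-ipython3 notranslate"><div class="highlight"><pre>import random

# Inner univariate functions phi[q][p], applied to input x_p for outer term q.
phi = [
    [lambda x: x, lambda x: x],    # q = 1: x1 + x2
    [lambda x: x, lambda x: -x],   # q = 2: x1 - x2
]
# Outer univariate functions Phi[q].
Phi = [lambda u: u * u / 4, lambda u: -u * u / 4]

def ka_form(x):
    """Evaluate sum_q Phi_q(sum_p phi_{q,p}(x_p))."""
    return sum(
        Phi_q(sum(phi_qp(x_p) for phi_qp, x_p in zip(phi_q, x)))
        for Phi_q, phi_q in zip(Phi, phi)
    )

random.seed(0)
points = [(random.random(), random.random()) for _ in range(5)]
max_error = max(abs(ka_form(p) - p[0] * p[1]) for p in points)
print(max_error)  # close to zero, up to floating-point rounding
</pre></div>
</div>
<p>Only two outer terms are needed here, fewer than the <span class="math notranslate nohighlight">\(2n+1=5\)</span> the theorem allows; the remaining terms can be taken to be zero.</p>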
</section>
<section id="kolmogorov-arnold-network-kan">
<h2>Kolmogorov-Arnold Network (KAN)<a class="headerlink" href="#kolmogorov-arnold-network-kan" title="Link to this heading"></a></h2>
<p>The Kolmogorov-Arnold representation can be written in matrix form</p>
<div class="math notranslate nohighlight">
\[f(x)={\bf \Phi}_{\rm out}\circ{\bf \Phi}_{\rm in}\circ {\bf x}\]</div>
<p>where</p>
<div class="math notranslate nohighlight">
\[\begin{split}{\bf \Phi}_{\rm in}= \begin{pmatrix} \phi_{1,1}(\cdot) &amp; \cdots &amp; \phi_{1,n}(\cdot) \\ \vdots &amp; &amp; \vdots \\ \phi_{2n+1,1}(\cdot) &amp; \cdots &amp; \phi_{2n+1,n}(\cdot) \end{pmatrix},\quad {\bf \Phi}_{\rm out}=\begin{pmatrix} \Phi_1(\cdot) &amp; \cdots &amp; \Phi_{2n+1}(\cdot)\end{pmatrix}\end{split}\]</div>
<p>We notice that both <span class="math notranslate nohighlight">\({\bf \Phi}_{\rm in}\)</span> and
<span class="math notranslate nohighlight">\({\bf \Phi}_{\rm out}\)</span> are special cases of the following function
matrix <span class="math notranslate nohighlight">\({\bf \Phi}\)</span> (with <span class="math notranslate nohighlight">\(n_{\rm in}\)</span> inputs and
<span class="math notranslate nohighlight">\(n_{\rm out}\)</span> outputs), which we call a Kolmogorov-Arnold layer:</p>
<div class="math notranslate nohighlight">
\[\begin{split}{\bf \Phi}= \begin{pmatrix} \phi_{1,1}(\cdot) &amp; \cdots &amp; \phi_{1,n_{\rm in}}(\cdot) \\ \vdots &amp; &amp; \vdots \\ \phi_{n_{\rm out},1}(\cdot) &amp; \cdots &amp; \phi_{n_{\rm out},n_{\rm in}}(\cdot) \end{pmatrix}\end{split}\]</div>
<p><span class="math notranslate nohighlight">\({\bf \Phi}_{\rm in}\)</span> corresponds to
<span class="math notranslate nohighlight">\(n_{\rm in}=n, n_{\rm out}=2n+1\)</span>, and <span class="math notranslate nohighlight">\({\bf \Phi}_{\rm out}\)</span>
corresponds to <span class="math notranslate nohighlight">\(n_{\rm in}=2n+1, n_{\rm out}=1\)</span>.</p>
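<p>To make the function-matrix notation concrete, here is a minimal plain-Python sketch, assuming a layer is stored as a nested list of callables with one row per output and one column per input. Output <span class="math notranslate nohighlight">\(j\)</span> is the sum over inputs <span class="math notranslate nohighlight">\(i\)</span> of <span class="math notranslate nohighlight">\(\phi_{j,i}(x_i)\)</span>. The helper kan_layer is a hypothetical name and is not part of the library.</p>
<div class="highlight-ipython3 notranslate"><div class="highlight"><pre>import math

def kan_layer(phi, x):
    """Apply a function matrix: output_j = sum_i phi[j][i](x[i])."""
    return [sum(f(x_i) for f, x_i in zip(row, x)) for row in phi]

# A layer with n_in = 2 inputs and n_out = 3 outputs (3 rows, 2 columns).
phi = [
    [math.sin, math.cos],
    [math.exp, abs],
    [lambda t: t ** 2, lambda t: 0.0],
]
y = kan_layer(phi, [0.0, 1.0])
print(len(y))  # 3, one value per output
</pre></div>
</div>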
<p>After defining the layer, we can construct a Kolmogorov-Arnold network
simply by stacking layers! Let’s say we have <span class="math notranslate nohighlight">\(L\)</span> layers, with the
<span class="math notranslate nohighlight">\(l^{\rm th}\)</span> layer <span class="math notranslate nohighlight">\({\bf \Phi}_l\)</span> having shape
<span class="math notranslate nohighlight">\((n_{l+1}, n_{l})\)</span>. Then the whole network is</p>
<div class="math notranslate nohighlight">
\[{\rm KAN}({\bf x})={\bf \Phi}_{L-1}\circ\cdots \circ{\bf \Phi}_1\circ{\bf \Phi}_0\circ {\bf x}\]</div>
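<p>The composition can be sketched in plain Python as a left-to-right fold over the layers. As an example, the function <span class="math notranslate nohighlight">\(\exp(\sin(\pi x_1)+x_2^2)\)</span> is exactly a two-layer network with shapes <span class="math notranslate nohighlight">\((1,2)\)</span> and <span class="math notranslate nohighlight">\((1,1)\)</span>. The helpers kan_layer and kan are hypothetical names for this illustration, not library functions.</p>
<div class="highlight-ipython3 notranslate"><div class="highlight"><pre>import math
from functools import reduce

def kan_layer(phi, x):
    """Apply a function matrix: output_j = sum_i phi[j][i](x[i])."""
    return [sum(f(x_i) for f, x_i in zip(row, x)) for row in phi]

def kan(layers, x):
    """Apply Phi_0 first, then Phi_1, and so on up to Phi_{L-1}."""
    return reduce(lambda h, phi: kan_layer(phi, h), layers, x)

phi_0 = [[lambda t: math.sin(math.pi * t), lambda t: t ** 2]]  # shape (1, 2)
phi_1 = [[math.exp]]                                           # shape (1, 1)

x = [0.5, 0.5]
out = kan([phi_0, phi_1], x)
expected = math.exp(math.sin(math.pi * 0.5) + 0.5 ** 2)
print(abs(out[0] - expected))
</pre></div>
</div>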
<p>In contrast, a Multi-Layer Perceptron interleaves linear layers
<span class="math notranslate nohighlight">\({\bf W}_l\)</span> with nonlinearities <span class="math notranslate nohighlight">\(\sigma\)</span>:</p>
<div class="math notranslate nohighlight">
\[{\rm MLP}({\bf x})={\bf W}_{L-1}\circ\sigma\circ\cdots\circ {\bf W}_1\circ\sigma\circ {\bf W}_0\circ {\bf x}\]</div>
<p>A KAN can be easily visualized: (1) a KAN is simply a stack of KAN layers;
(2) each KAN layer can be visualized as a fully-connected layer with a
1D function placed on each edge. Let’s see an example below.</p>
</section>
<section id="get-started-with-kans">
<h2>Get started with KANs<a class="headerlink" href="#get-started-with-kans" title="Link to this heading"></a></h2>
<p>Initialize KAN</p>
<div class="highlight-ipython3 notranslate"><div class="highlight"><pre><span></span>from kan import *
# create a KAN: 2D inputs, 1D output, and 5 hidden neurons. cubic spline (k=3), 5 grid intervals (grid=5).
model = KAN(width=[2,5,1], grid=5, k=3, seed=0)
</pre></div>
</div>
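<p>To get a feel for the size of this model, the sketch below counts the edges implied by width=[2,5,1], since each edge carries one learnable 1D function. The coefficient count assumes each spline uses grid + k B-spline basis functions and ignores any other per-edge parameters the library may store.</p>
<div class="highlight-ipython3 notranslate"><div class="highlight"><pre>width = [2, 5, 1]
grid, k = 5, 3

# One learnable 1D function per edge between consecutive layers.
edges_per_layer = [n_in * n_out for n_in, n_out in zip(width, width[1:])]
n_edges = sum(edges_per_layer)

# Assumption: grid + k B-spline coefficients per edge; other per-edge
# parameters the library may store are not counted here.
n_spline_coefs = n_edges * (grid + k)
print(edges_per_layer, n_edges, n_spline_coefs)  # [10, 5] 15 120
</pre></div>
</div>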
<p>Create dataset</p>
<div class="highlight-ipython3 notranslate"><div class="highlight"><pre><span></span># create dataset f(x,y) = exp(sin(pi*x)+y^2)
f = lambda x: torch.exp(torch.sin(torch.pi*x[:,[0]]) + x[:,[1]]**2)
dataset = create_dataset(f, n_var=2)
dataset[&#39;train_input&#39;].shape, dataset[&#39;train_label&#39;].shape
</pre></div>
</div>
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">Size</span><span class="p">([</span><span class="mi">1000</span><span class="p">,</span> <span class="mi">2</span><span class="p">]),</span> <span class="n">torch</span><span class="o">.</span><span class="n">Size</span><span class="p">([</span><span class="mi">1000</span><span class="p">,</span> <span class="mi">1</span><span class="p">]))</span>
</pre></div>
</div>
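<p>The shapes above mean 1000 training samples with two input variables and one label each. A plain-Python sketch of an equivalent dataset, assuming inputs are drawn uniformly from [-1, 1] in each variable (the sampling range is an assumption, not read from the output above):</p>
<div class="highlight-ipython3 notranslate"><div class="highlight"><pre>import math
import random

def f(x1, x2):
    return math.exp(math.sin(math.pi * x1) + x2 ** 2)

random.seed(0)
n_samples = 1000
# Assumption: inputs are sampled uniformly from [-1, 1] in each variable.
train_input = [[random.uniform(-1, 1), random.uniform(-1, 1)] for _ in range(n_samples)]
train_label = [[f(x1, x2)] for x1, x2 in train_input]
print(len(train_input), len(train_input[0]), len(train_label), len(train_label[0]))
# 1000 2 1000 1
</pre></div>
</div>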
<p>Plot KAN at initialization</p>
<div class="highlight-ipython3 notranslate"><div class="highlight"><pre><span></span># plot KAN at initialization
model(dataset[&#39;train_input&#39;]);
model.plot(beta=100)
</pre></div>
</div>
<img alt=".ipynb_checkpoints/intro_files/intro_15_0.png" src=".ipynb_checkpoints/intro_files/intro_15_0.png" />
<p>Train KAN with sparsity regularization</p>
<div class="highlight-ipython3 notranslate"><div class="highlight"><pre><span></span># train the model
model.train(dataset, opt=&quot;LBFGS&quot;, steps=20, lamb=0.01, lamb_entropy=10.);
</pre></div>
</div>
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span>train loss: 1.57e-01 | test loss: 1.31e-01 | reg: 2.05e+01 : 100%|██| 20/20 [00:18&lt;00:00, 1.06it/s]
</pre></div>
</div>
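<p>A minimal sketch of the kind of sparsity penalty these arguments weight, assuming the penalty is an L1 term plus an entropy term over per-edge magnitudes, with lamb scaling the whole penalty and lamb_entropy scaling the entropy part. The function layer_regularization and its lamb_l1 argument are hypothetical names for this illustration, not the library's implementation.</p>
<div class="highlight-ipython3 notranslate"><div class="highlight"><pre>import math

def layer_regularization(edge_norms, lamb_l1=1.0, lamb_entropy=10.0):
    """L1 term plus entropy term for one layer, given each edge's average magnitude."""
    l1 = sum(edge_norms)
    probs = [a / l1 for a in edge_norms if a != 0.0]
    entropy = -sum(p * math.log(p) for p in probs)
    return lamb_l1 * l1 + lamb_entropy * entropy

# Same total magnitude, spread over four edges vs. concentrated on one.
dense = layer_regularization([0.25, 0.25, 0.25, 0.25])
sparse = layer_regularization([1.0, 0.0, 0.0, 0.0])
print(dense, sparse)
</pre></div>
</div>
<p>Both layers have the same L1 term, but the concentrated one has zero entropy and therefore the smaller penalty, which is why a large entropy weight pushes the network toward using few edges.</p>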
<p>Plot trained KAN</p>
<div class="highlight-ipython3 notranslate"><div class="highlight"><pre><span></span>model.plot()
</pre></div>
</div>
<img alt=".ipynb_checkpoints/intro_files/intro_19_0.png" src=".ipynb_checkpoints/intro_files/intro_19_0.png" />
<p>Prune KAN and replot (keep the original shape)</p>
<div class="highlight-ipython3 notranslate"><div class="highlight"><pre><span></span>model.prune()
model.plot(mask=True)
</pre></div>
</div>
<img alt=".ipynb_checkpoints/intro_files/intro_21_0.png" src=".ipynb_checkpoints/intro_files/intro_21_0.png" />
<p>Prune KAN and replot (get a smaller shape)</p>
<div class="highlight-ipython3 notranslate"><div class="highlight"><pre><span></span>model = model.prune()
model(dataset[&#39;train_input&#39;])
model.plot()
</pre></div>
</div>
<img alt=".ipynb_checkpoints/intro_files/intro_23_0.png" src=".ipynb_checkpoints/intro_files/intro_23_0.png" />
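<p>One common way to decide which hidden nodes survive pruning is a threshold test on edge magnitudes: a node is kept only if both its strongest incoming edge and its strongest outgoing edge exceed a threshold. The sketch below illustrates that rule under this assumption; prune_hidden_nodes, its threshold default and the example magnitudes are hypothetical and do not describe the library's internals.</p>
<div class="highlight-ipython3 notranslate"><div class="highlight"><pre>def prune_hidden_nodes(incoming, outgoing, threshold=1e-2):
    """Return indices of hidden nodes to keep.

    incoming[j] lists the edge magnitudes into hidden node j,
    outgoing[j] lists the edge magnitudes out of hidden node j.
    A node is kept only if its strongest incoming edge and its
    strongest outgoing edge both exceed the threshold.
    """
    return [
        j for j, (ins, outs) in enumerate(zip(incoming, outgoing))
        if min(max(ins), max(outs)) &gt; threshold
    ]

# Five hidden nodes, two inputs, one output (made-up magnitudes).
incoming = [[0.8, 0.6], [0.001, 0.002], [0.5, 0.4], [0.3, 0.0], [0.0, 0.004]]
outgoing = [[0.9], [0.7], [0.003], [0.2], [0.6]]
kept = prune_hidden_nodes(incoming, outgoing)
print(kept)  # [0, 3]
</pre></div>
</div>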
<p>Continue training and replot</p>
<div class="highlight-ipython3 notranslate"><div class="highlight"><pre><span></span>model.train(dataset, opt=&quot;LBFGS&quot;, steps=50);
</pre></div>
</div>
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span>train loss: 4.74e-03 | test loss: 4.80e-03 | reg: 2.98e+00 : 100%|██| 50/50 [00:07&lt;00:00, 7.03it/s]
</pre></div>
</div>
<div class="highlight-ipython3 notranslate"><div class="highlight"><pre><span></span>model.plot()
</pre></div>
</div>
<img alt=".ipynb_checkpoints/intro_files/intro_26_0.png" src=".ipynb_checkpoints/intro_files/intro_26_0.png" />
<p>Automatically or manually set activation functions to be symbolic</p>
<div class="highlight-ipython3 notranslate"><div class="highlight"><pre><span></span>mode = &quot;auto&quot; # &quot;manual&quot;

if mode == &quot;manual&quot;:
    # manual mode
    model.fix_symbolic(0,0,0,&#39;sin&#39;);
    model.fix_symbolic(0,1,0,&#39;x^2&#39;);
    model.fix_symbolic(1,0,0,&#39;exp&#39;);
elif mode == &quot;auto&quot;:
    # automatic mode
    lib = [&#39;x&#39;,&#39;x^2&#39;,&#39;x^3&#39;,&#39;x^4&#39;,&#39;exp&#39;,&#39;log&#39;,&#39;sqrt&#39;,&#39;tanh&#39;,&#39;sin&#39;,&#39;abs&#39;]
    model.auto_symbolic(lib=lib)
</pre></div>
</div>
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="n">fixing</span> <span class="p">(</span><span class="mi">0</span><span class="p">,</span><span class="mi">0</span><span class="p">,</span><span class="mi">0</span><span class="p">)</span> <span class="k">with</span> <span class="n">sin</span><span class="p">,</span> <span class="n">r2</span><span class="o">=</span><span class="mf">0.999987252534279</span>
<span class="n">fixing</span> <span class="p">(</span><span class="mi">0</span><span class="p">,</span><span class="mi">1</span><span class="p">,</span><span class="mi">0</span><span class="p">)</span> <span class="k">with</span> <span class="n">x</span><span class="o">^</span><span class="mi">2</span><span class="p">,</span> <span class="n">r2</span><span class="o">=</span><span class="mf">0.9999996536741071</span>
<span class="n">fixing</span> <span class="p">(</span><span class="mi">1</span><span class="p">,</span><span class="mi">0</span><span class="p">,</span><span class="mi">0</span><span class="p">)</span> <span class="k">with</span> <span class="n">exp</span><span class="p">,</span> <span class="n">r2</span><span class="o">=</span><span class="mf">0.9999988529417926</span>
</pre></div>
</div>
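<p>The r2 values above score how well each symbolic candidate matches a learned 1D function. A simplified sketch of that idea: sample the learned function, fit each candidate <span class="math notranslate nohighlight">\(g\)</span> as <span class="math notranslate nohighlight">\(c\,g(x)+d\)</span> by least squares, and pick the candidate with the highest <span class="math notranslate nohighlight">\(R^2\)</span>. This is an assumption-laden illustration; it omits any fitting of an affine transform inside <span class="math notranslate nohighlight">\(g\)</span>, and r2_of_affine_fit is a hypothetical helper, not a library function.</p>
<div class="highlight-ipython3 notranslate"><div class="highlight"><pre>import math

def r2_of_affine_fit(g, xs, ys):
    """Fit ys ~ c * g(x) + d by least squares and return the R^2 of the fit."""
    zs = [g(x) for x in xs]
    n = len(xs)
    z_mean, y_mean = sum(zs) / n, sum(ys) / n
    s_zz = sum((z - z_mean) ** 2 for z in zs)
    s_zy = sum((z - z_mean) * (y - y_mean) for z, y in zip(zs, ys))
    s_yy = sum((y - y_mean) ** 2 for y in ys)
    c = s_zy / s_zz
    d = y_mean - c * z_mean
    residual = sum((y - (c * z + d)) ** 2 for z, y in zip(zs, ys))
    return 1.0 - residual / s_yy

library = {
    'x': lambda t: t,
    'x^2': lambda t: t ** 2,
    'exp': math.exp,
    'sin': math.sin,
    'tanh': math.tanh,
}

# Stand-in for samples of a learned activation: 2*x^2 - 1 on [-1, 1].
xs = [i / 50 - 1 for i in range(101)]
ys = [2 * x ** 2 - 1 for x in xs]

scores = {name: r2_of_affine_fit(g, xs, ys) for name, g in library.items()}
best = max(scores, key=scores.get)
print(best)  # x^2
</pre></div>
</div>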
<p>Continue training to almost machine precision</p>
<div class="highlight-ipython3 notranslate"><div class="highlight"><pre><span></span>model.train(dataset, opt=&quot;LBFGS&quot;, steps=50);
</pre></div>
</div>
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span>train loss: 2.02e-10 | test loss: 1.13e-10 | reg: 2.98e+00 : 100%|██| 50/50 [00:02&lt;00:00, 22.59it/s]
</pre></div>
</div>
<p>Obtain the symbolic formula</p>
<div class="highlight-ipython3 notranslate"><div class="highlight"><pre><span></span>model.symbolic_formula()[0][0]
</pre></div>
</div>
<div class="math notranslate nohighlight">
\[\displaystyle 1.0 e^{1.0 x_{2}^{2} + 1.0 \sin{\left(3.14 x_{1} \right)}}\]</div>
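<p>The printed formula shows 3.14 where the target function has <span class="math notranslate nohighlight">\(\pi\)</span>. The sketch below measures how far the formula with the literal constant 3.14 is from the target on a grid over [-1, 1] in each variable (the grid and range are assumptions). Because the resulting gap is far larger than the reported loss of about 1e-10, the printed constant is presumably a rounded display of a fitted value much closer to <span class="math notranslate nohighlight">\(\pi\)</span>.</p>
<div class="highlight-ipython3 notranslate"><div class="highlight"><pre>import math

def target(x1, x2):
    return math.exp(math.sin(math.pi * x1) + x2 ** 2)

def displayed(x1, x2):
    # Coefficients copied from the printed formula, which shows 3.14 rather than pi.
    return 1.0 * math.exp(1.0 * x2 ** 2 + 1.0 * math.sin(3.14 * x1))

grid = [i / 10 - 1 for i in range(21)]
max_gap = max(abs(target(a, b) - displayed(a, b)) for a in grid for b in grid)
print(max_gap)
</pre></div>
</div>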
</section>
</section>


</div>
</div>
<footer>

<hr/>

<div role="contentinfo">
<p>&#169; Copyright 2024, Ziming Liu.</p>
</div>

Built with <a href="https://www.sphinx-doc.org/">Sphinx</a> using a
<a href="https://github.com/readthedocs/sphinx_rtd_theme">theme</a>
provided by <a href="https://readthedocs.org">Read the Docs</a>.


</footer>
</div>
</div>
</section>
</div>
<script>
jQuery(function () {
SphinxRtdTheme.Navigation.enable(true);
});
</script>

</body>
</html>