KolmogorovArnold

Julia implementation of the Kolmogorov-Arnold network with custom gradients for fast training.

https://github.com/vpuri3/kolmogorovarnold.jl

Science Score: 64.0%

This score indicates how likely this project is to be science-related based on various indicators:

  • CITATION.cff file
    Found CITATION.cff file
  • codemeta.json file
    Found codemeta.json file
  • .zenodo.json file
    Found .zenodo.json file
  • DOI references
  • Academic publication links
    Links to: arxiv.org
  • Committers with academic emails
    1 of 5 committers (20.0%) from academic institutions
  • Institutional organization owner
  • JOSS paper metadata
  • Scientific vocabulary similarity
    Low similarity (6.7%) to scientific vocabulary

Keywords from Contributors

optim interactive optimisation unconstrained-optimisation unconstrained-optimization mesh interpretability ode sequences generic
Last synced: 6 months ago

Repository

Julia implementation of the Kolmogorov-Arnold network with custom gradients for fast training.

Basic Info
  • Host: GitHub
  • Owner: vpuri3
  • License: MIT
  • Language: Julia
  • Default Branch: master
  • Homepage:
  • Size: 81.1 KB
Statistics
  • Stars: 76
  • Watchers: 6
  • Forks: 12
  • Open Issues: 3
  • Releases: 1
Created almost 2 years ago · Last pushed about 1 year ago
Metadata Files
Readme License Citation

README.md

KolmogorovArnold.jl


Julia implementation of FourierKAN

Julia implementation of ChebyKAN

Julia implementation of the Kolmogorov-Arnold network for the Lux.jl framework. This implementation is based on efficient-kan and FastKAN, which resolve the performance issues with the original implementation. Key implementation details:

  • We fix our grid to be in [-1, 1] and normalize the input to lie in that interval with tanh or NNlib.tanh_fast.
  • We use radial basis functions in place of the spline basis, as the former are much cheaper to evaluate.
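The two details above can be sketched in a few lines of plain Julia (an illustrative Gaussian RBF on a fixed grid, not the package's exact kernel):

```julia
# Illustrative sketch: Gaussian RBF basis on a fixed grid in [-1, 1],
# with inputs squashed into that interval via tanh.
rbf_basis(x, z, h) = exp(-((x - z) / h)^2)

grid = collect(range(-1, 1; length = 8))  # fixed grid of 8 points
h    = 2 / (length(grid) - 1)             # use the grid spacing as bandwidth
x    = tanh.(randn(16))                   # 16 normalized inputs in (-1, 1)

# basis matrix: rows are inputs, columns are grid points
B = [rbf_basis(xi, zj, h) for xi in x, zj in grid]
size(B)  # (16, 8)
```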

```julia
using Random, Lux, KolmogorovArnold
rng = Random.default_rng()

indim, outdim, gridlen = 4, 4, 8
layer = KDense(indim, outdim, gridlen)
p, st = Lux.setup(rng, layer)

x = rand32(rng, indim, 10)
y = layer(x, p, st)
```

We compare the performance of different implementations of KAN with an MLP that has the same number of parameters (see examples/eg1.jl).

```julia
using Random, Lux, KolmogorovArnold
using LuxCUDA, ComponentArrays, Zygote, BenchmarkTools

CUDA.allowscalar(false)
device = Lux.gpu_device()

rng = Random.default_rng()
Random.seed!(rng, 0)

x  = rand32(rng, 1, 1000) |> device
x₀ = rand32(rng, 1000, 1) |> device

G = 10 # grid size

# define MLP, KANs
mlp = Chain(
    Dense(1, 128, tanh),
    Dense(128, 128, tanh),
    Dense(128, 1),
) # 16_897 parameters plus 0 states.

basisfunc  = rbf      # rbf, rswaf, iqf (radial basis funcs, reflection switch activation funcs, inverse quadratic funcs)
normalizer = softsign # sigmoid(_fast), tanh(_fast), softsign

kan1 = Chain(
    KDense( 1, 40, 10; usebaseact = true, basisfunc, normalizer),
    KDense(40, 40, 10; usebaseact = true, basisfunc, normalizer),
    KDense(40,  1, 10; usebaseact = true, basisfunc, normalizer),
) # 18_490 parameters plus 30 states.

kan2 = Chain(
    KDense( 1, 40, 10; usebaseact = false, basisfunc, normalizer),
    KDense(40, 40, 10; usebaseact = false, basisfunc, normalizer),
    KDense(40,  1, 10; usebaseact = false, basisfunc, normalizer),
) # 16_800 parameters plus 30 states.

kan3 = Chain(
    CDense( 1, 40, G),
    CDense(40, 40, G),
    CDense(40,  1, G),
) # 18_561 parameters plus 0 states.

kan4 = Chain(
    FDense( 1, 30, G),
    FDense(30, 30, G),
    FDense(30,  1, G),
) # 19_261 parameters plus 0 states.

# set up experiment
pM , stM  = Lux.setup(rng, mlp)
pK1, stK1 = Lux.setup(rng, kan1)
pK2, stK2 = Lux.setup(rng, kan2)
pK3, stK3 = Lux.setup(rng, kan3)
pK4, stK4 = Lux.setup(rng, kan4)

pM  = ComponentArray(pM)  |> device
pK1 = ComponentArray(pK1) |> device
pK2 = ComponentArray(pK2) |> device
pK3 = ComponentArray(pK3) |> device
pK4 = ComponentArray(pK4) |> device

stM, stK1, stK2, stK3, stK4 = device(stM), device(stK1), device(stK2), device(stK3), device(stK4)

# forward pass
@btime CUDA.@sync $mlp($x,   $pM,  $stM)  #  31.611 μs (248  allocations:  5.45 KiB)
@btime CUDA.@sync $kan1($x,  $pK1, $stK1) # 125.790 μs (1034 allocations: 21.97 KiB)
@btime CUDA.@sync $kan2($x,  $pK2, $stK2) #  87.585 μs (1335 allocations: 13.95 KiB)
@btime CUDA.@sync $kan3($x', $pK3, $stK3) # 210.785 μs (1335 allocations: 31.03 KiB)
@btime CUDA.@sync $kan4($x', $pK4, $stK4) #   2.392 ms (1642 allocations: 34.56 KiB)

# backward pass
f_mlp(p)  = mlp(x,   p, stM)[1]  |> sum
f_kan1(p) = kan1(x,  p, stK1)[1] |> sum
f_kan2(p) = kan2(x,  p, stK2)[1] |> sum
f_kan3(p) = kan3(x₀, p, stK3)[1] |> sum
f_kan4(p) = kan4(x₀, p, stK4)[1] |> sum

@btime CUDA.@sync Zygote.gradient($f_mlp,  $pM)  # 268.074 μs (1971 allocations:  57.03 KiB)
@btime CUDA.@sync Zygote.gradient($f_kan1, $pK1) # 831.888 μs (5015 allocations: 123.25 KiB)
@btime CUDA.@sync Zygote.gradient($f_kan2, $pK2) # 658.578 μs (3314 allocations:  87.16 KiB)
@btime CUDA.@sync Zygote.gradient($f_kan3, $pK3) #   1.647 ms (7138 allocations: 180.45 KiB)
@btime CUDA.@sync Zygote.gradient($f_kan4, $pK4) #   7.028 ms (8745 allocations: 199.42 KiB)
```

The performance of the KAN with radial basis functions improves significantly with `usebaseact = false`. Although KANs are currently significantly slower than an MLP with the same number of parameters, the promise of this architecture is that a small KAN can potentially do the work of a much bigger MLP. More experiments are needed to assess the validity of this claim.
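As a rough sanity check on the parameter counts above (an illustrative count, assuming each KDense(in, out, grid) stores in × out × grid RBF coefficients; `usebaseact = true` adds extra base-activation weights on top):

```julia
# Hypothetical helper, not a package function: count only the RBF coefficients.
kdense_coeffs(din, dout, grid) = din * dout * grid

total = kdense_coeffs(1, 40, 10) + kdense_coeffs(40, 40, 10) + kdense_coeffs(40, 1, 10)
total  # 16_800, matching the kan2 comment above (usebaseact = false)
```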

This package will be actively developed for the time being. Once a stable version is figured out, we can consider opening a PR on Lux.jl. Feel free to open issues or create PRs in the meantime with features, comparisons, or performance improvements.

Custom gradients

Writing custom gradients for the activation functions has led to a substantial speedup in the backward pass (see examples/eg2.jl).

```julia
N, G = 5000, 10

x = LinRange(-1, 1, N) |> Array |> device
z = LinRange(-1, 1, G) |> Array |> device
d = 2 / (G - 1)

f_rbf(z)   = rbf(  x, z', d) |> sum
f_rswaf(z) = rswaf(x, z', d) |> sum
f_iqf(z)   = iqf(  x, z', d) |> sum

# forward pass
@btime CUDA.@sync $f_rbf($z)   # 55.566 μs (294 allocations: 7.78 KiB)
@btime CUDA.@sync $f_rswaf($z) # 57.112 μs (294 allocations: 7.78 KiB)
@btime CUDA.@sync $f_iqf($z)   # 55.368 μs (294 allocations: 7.78 KiB)

# backward pass
@btime CUDA.@sync Zygote.gradient($f_rbf  , $z) # 188.456 μs (1045 allocations: 27.62 KiB)
@btime CUDA.@sync Zygote.gradient($f_rswaf, $z) # 212.419 μs (1071 allocations: 28.30 KiB)
@btime CUDA.@sync Zygote.gradient($f_iqf  , $z) # 201.568 μs (1045 allocations: 27.62 KiB)

# without custom gradients:
# RBF  : 250.313 μs (1393 allocations: 35.95 KiB)
# RSWAF: 282.864 μs (1389 allocations: 36.62 KiB)
# IQF  : 333.843 μs (1628 allocations: 42.70 KiB)
```
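The idea behind the custom gradients can be sketched without the package: the Gaussian RBF has a cheap closed-form derivative that reuses the forward value, so the backward pass avoids retracing the broadcast. A plain-Julia illustration, checked against a finite difference (the package wires this up through ChainRules-style rules):

```julia
# Gaussian RBF and its derivative with respect to the grid point z.
# rbf_grad_z reuses the already-computed forward value y = rbf_val(x, z, h).
rbf_val(x, z, h) = exp(-((x - z) / h)^2)
rbf_grad_z(x, z, h, y) = 2 * (x - z) / h^2 * y

x, z, h = 0.3, -0.1, 0.25
y = rbf_val(x, z, h)
g = rbf_grad_z(x, z, h, y)

# check against a central finite difference
ε  = 1e-6
fd = (rbf_val(x, z + ε, h) - rbf_val(x, z - ε, h)) / 2ε
abs(g - fd) < 1e-5  # true
```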

TODO

  • Grid update with a linear least-squares solve
  • Devise good initialization schemes. RBF coefficients and base activation weights are currently initialized with WeightInitializers.glorot_uniform.
  • Figure out good optimization strategies (choice of optimizer, learning rate decay, etc.)
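The first TODO item (grid update via a least-squares solve) could look something like this sketch. `refit_coeffs` is a hypothetical helper using Julia's backslash operator, not part of the package API:

```julia
using LinearAlgebra

# Gaussian RBF basis (illustrative)
rbf_basis(x, z, h) = exp(-((x - z) / h)^2)

# Refit coefficients on a new grid by linear least squares: min ‖A c - y‖₂
function refit_coeffs(xs, ys, newgrid)
    h = 2 / (length(newgrid) - 1)
    A = [rbf_basis(x, z, h) for x in xs, z in newgrid]
    A \ ys  # least-squares solution (QR under the hood)
end

xs = collect(range(-1, 1; length = 50))
ys = sin.(pi .* xs)
c  = refit_coeffs(xs, ys, collect(range(-1, 1; length = 10)))
length(c)  # 10 coefficients, one per new grid point
```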

Owner

  • Name: Vedant Puri
  • Login: vpuri3
  • Kind: user
  • Location: Pittsburgh, PA
  • Company: Carnegie Mellon University

i write PDE solvers.

Citation (CITATION.bib)

@misc{KolmogorovArnold.jl,
	author  = {Vedant Puri <vedantpuri@gmail.com> and contributors},
	title   = {KolmogorovArnold.jl},
	url     = {https://github.com/vpuri3/KolmogorovArnold.jl},
	version = {v1.0.0-DEV},
	year    = {2024},
	month   = {5}
}

GitHub Events

Total
  • Create event: 2
  • Commit comment event: 4
  • Release event: 1
  • Issues event: 7
  • Watch event: 18
  • Issue comment event: 31
  • Push event: 10
  • Pull request review comment event: 14
  • Pull request review event: 14
  • Pull request event: 8
  • Fork event: 4
Last Year
  • Create event: 2
  • Commit comment event: 4
  • Release event: 1
  • Issues event: 7
  • Watch event: 18
  • Issue comment event: 31
  • Push event: 10
  • Pull request review comment event: 14
  • Pull request review event: 14
  • Pull request event: 8
  • Fork event: 4

Committers

Last synced: 10 months ago

All Time
  • Total Commits: 58
  • Total Committers: 5
  • Avg Commits per committer: 11.6
  • Development Distribution Score (DDS): 0.31
Past Year
  • Commits: 29
  • Committers: 4
  • Avg Commits per committer: 7.25
  • Development Distribution Score (DDS): 0.552
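For reference, the DDS values are consistent with one minus the top committer's share of commits (an assumption about the scoring formula, but it reproduces the all-time value from the committer table):

```julia
# Assumed formula: DDS = 1 - (top committer's commits / total commits).
dds(top, total) = 1 - top / total
round(dds(40, 58); digits = 2)  # 0.31 (all time: 40 of 58 commits from the top committer)
```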
Top Committers
Name Email Commits
Vedant Puri v****i@g****m 40
Leon Armbruster l****9@g****m 13
dependabot[bot] 4****] 2
Avik Pal a****7@g****m 2
Martin Holters m****s@h****e 1
Committer Domains (Top 20 + Academic)

Issues and Pull Requests

Last synced: 6 months ago

All Time
  • Total issues: 4
  • Total pull requests: 7
  • Average time to close issues: 1 day
  • Average time to close pull requests: 3 days
  • Total issue authors: 3
  • Total pull request authors: 5
  • Average comments per issue: 4.25
  • Average comments per pull request: 2.0
  • Merged pull requests: 5
  • Bot issues: 0
  • Bot pull requests: 2
Past Year
  • Issues: 4
  • Pull requests: 5
  • Average time to close issues: 1 day
  • Average time to close pull requests: 4 days
  • Issue authors: 3
  • Pull request authors: 4
  • Average comments per issue: 4.25
  • Average comments per pull request: 2.8
  • Merged pull requests: 3
  • Bot issues: 0
  • Bot pull requests: 0
Top Authors
Issue Authors
  • vpuri3 (2)
  • chooron (1)
  • avik-pal (1)
  • JuliaTagBot (1)
Pull Request Authors
  • dependabot[bot] (4)
  • vpuri3 (2)
  • martinholters (2)
  • armbrusl (2)
  • avik-pal (1)
Top Labels
Issue Labels
Pull Request Labels
dependencies (4)

Packages

  • Total packages: 1
  • Total downloads:
    • julia 3 total
  • Total dependent packages: 0
  • Total dependent repositories: 0
  • Total versions: 1
juliahub.com: KolmogorovArnold

Julia implementation of the Kolmogorov-Arnold network with custom gradients for fast training.

  • Versions: 1
  • Dependent Packages: 0
  • Dependent Repositories: 0
  • Downloads: 3 Total
Rankings
Downloads: 2.4%
Dependent repos count: 3.2%
Average: 7.3%
Dependent packages count: 16.3%
Last synced: 6 months ago

Dependencies

.github/workflows/CI.yml actions
  • actions/checkout v4 composite
  • julia-actions/cache v1 composite
  • julia-actions/julia-buildpkg v1 composite
  • julia-actions/julia-runtest v1 composite
  • julia-actions/setup-julia v1 composite
.github/workflows/CompatHelper.yml actions
.github/workflows/TagBot.yml actions
  • JuliaRegistries/TagBot v1 composite