https://github.com/aaltoml/sfr-experiments

Code accompanying ICLR 2024 paper "Function-space Parameterization of Neural Networks for Sequential Learning"

Science Score: 23.0%

This score indicates how likely this project is to be science-related based on various indicators:

  • CITATION.cff file
  • codemeta.json file
    Found codemeta.json file
  • .zenodo.json file
  • DOI references
  • Academic publication links
    Links to: arxiv.org
  • Academic email domains
  • Institutional organization owner
  • JOSS paper metadata
  • Scientific vocabulary similarity
    Low similarity (10.0%) to scientific vocabulary

Keywords

bayesian-deep-learning bayesian-inference bayesian-neural-networks deep-learning gaussian-processes laplace-approximation pytorch
Last synced: 6 months ago

Repository

Code accompanying ICLR 2024 paper "Function-space Parameterization of Neural Networks for Sequential Learning"

Basic Info
Statistics
  • Stars: 0
  • Watchers: 4
  • Forks: 0
  • Open Issues: 0
  • Releases: 0
Topics
bayesian-deep-learning bayesian-inference bayesian-neural-networks deep-learning gaussian-processes laplace-approximation pytorch
Created about 2 years ago · Last pushed almost 2 years ago
Metadata Files
Readme License

README.md

Function-space Parameterization of Neural Networks for Sequential Learning

Code accompanying the ICLR 2024 paper Function-space Parameterization of Neural Networks for Sequential Learning. This repository contains code for reproducing the experiments in the paper. For a clean and minimal implementation of Sparse Function-space Representation of Neural Networks (SFR), please see this repo; we recommend using the clean and minimal repo.

Function-space Parameterization of Neural Networks for Sequential Learning
Aidan Scannell*, Riccardo Mereu*, Paul Chang, Ella Tamir, Joni Pajarinen, Arno Solin
International Conference on Learning Representations (ICLR 2024)
Sparse Function-space Representation of Neural Networks
Aidan Scannell*, Riccardo Mereu*, Paul Chang, Ella Tamir, Joni Pajarinen, Arno Solin
ICML 2023 Workshop on Duality Principles for Modern Machine Learning

Install

Install using virtual environment

Make a virtual environment:
```sh
python -m venv .venv
```
Activate it with:
```sh
source .venv/bin/activate
```
Install the dependencies with:
```sh
python -m pip install --upgrade pip
pip install laplace-torch==0.1a2
pip install -e ".[experiments]"
```
We install laplace-torch separately due to version conflicts with backpack-for-pytorch. Note that laplace-torch is only used for running the baselines.

Install using pip

Alternatively, manually install the dependencies with:
```sh
pip install laplace-torch==0.1a2
pip install -r requirements.txt
```

Reproducing experiments

See experiments for details on how to reproduce the results in the paper. This includes code for generating the tables and figures.

Usage

See the notebooks/README.md for how to use our code for both regression and classification.

Example

Here's a short example:
```python
import src
import torch

torch.set_default_dtype(torch.float64)


def func(x, noise=True):
    return torch.sin(x * 5) / x + torch.cos(x * 10)


# Toy data set
X_train = torch.rand((100, 1)) * 2
Y_train = func(X_train, noise=True)
data = (X_train, Y_train)

# Training config
width = 64
num_epochs = 1000
batch_size = 16
learning_rate = 1e-3
delta = 0.00005  # prior precision
data_loader = torch.utils.data.DataLoader(
    torch.utils.data.TensorDataset(*data), batch_size=batch_size
)

# Create a neural network
network = torch.nn.Sequential(
    torch.nn.Linear(1, width),
    torch.nn.Tanh(),
    torch.nn.Linear(width, width),
    torch.nn.Tanh(),
    torch.nn.Linear(width, 1),
)

# Instantiate SFR (handles NN training/prediction as they're coupled via the prior/likelihood)
sfr = src.sfr.SFR(
    network=network,
    prior=src.priors.Gaussian(params=network.parameters, delta=delta),
    likelihood=src.likelihoods.Gaussian(sigma_noise=2),
    output_dim=1,
    num_inducing=32,
    dual_batch_size=None,  # this reduces the memory required for computing dual parameters
    jitter=1e-4,
)

# Train the neural network
sfr.train()
optimizer = torch.optim.Adam([{"params": sfr.parameters()}], lr=learning_rate)
for epoch_idx in range(num_epochs):
    for batch_idx, batch in enumerate(data_loader):
        x, y = batch
        loss = sfr.loss(x, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

sfr.set_data(data)  # This builds the dual parameters

# Make predictions in function space
X_test = torch.linspace(-0.7, 3.5, 300, dtype=torch.float64).reshape(-1, 1)
f_mean, f_var = sfr.predict_f(X_test)

# Make predictions in output space
y_mean, y_var = sfr.predict(X_test)
```
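The predictive mean and variance returned above define a Gaussian at each test input, so a common next step is to form a credible interval around the mean. The sketch below is standalone and uses placeholder NumPy arrays in place of the `sfr.predict_f(X_test)` output (running the real model requires installing this repository); only the interval construction itself is the point.

```python
import numpy as np

# Placeholder predictive moments; in practice these would come from
# f_mean, f_var = sfr.predict_f(X_test) as in the example above.
X_test = np.linspace(-0.7, 3.5, 300)
f_mean = np.sin(5 * X_test)               # stand-in for the predictive mean
f_var = np.full_like(X_test, 0.05)        # stand-in for the predictive variance

# 95% credible interval of a Gaussian: mean +/- 1.96 standard deviations
lower = f_mean - 1.96 * np.sqrt(f_var)
upper = f_mean + 1.96 * np.sqrt(f_var)

print(lower.shape, upper.shape)  # both (300,)
```

These `lower`/`upper` bands are what is typically shaded around the mean when plotting function-space predictions.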

Citation

Please consider citing our ICLR 2024 paper.
```bibtex
@inproceedings{scannellFunction2024,
  title     = {Function-space Parameterization of Neural Networks for Sequential Learning},
  booktitle = {Proceedings of The Twelfth International Conference on Learning Representations (ICLR 2024)},
  author    = {Aidan Scannell and Riccardo Mereu and Paul Chang and Ella Tamir and Joni Pajarinen and Arno Solin},
  year      = {2024},
  month     = {5},
}
```

Owner

  • Name: AaltoML
  • Login: AaltoML
  • Kind: organization
  • Location: Finland

Machine learning group at Aalto University led by Prof. Solin


Dependencies

experiments/cl/baselines/S-FSVI/environment.yml pypi
  • chex ==0.0.7
  • dm-haiku ==0.0.4
  • dm-sonnet ==2.0.0
  • dm-tree ==0.1.6
  • gpflow ==2.2.1
  • gpustat ==0.6.0
  • jax ==0.2.11
  • jaxlib ==0.1.64
  • joypy *
  • keras ==2.6.0
  • latex ==0.7.0
  • neural-tangents ==0.3.6
  • optax ==0.0.2
  • pandas ==1.3.5
  • pyreadr *
  • pytest *
  • retry ==0.9.2
  • scikit-learn ==0.21.3
  • scipy ==1.6.2
  • seqtools ==1.1.0
  • tensorflow-addons ==0.13.0
  • tensorflow-datasets ==4.2.0
  • tensorflow-probability ==0.13.0
  • tfp-nightly ==0.12.0.dev20201117
  • uncertainty-metrics ==0.0.81
  • uncertainty_metrics *
experiments/cl/baselines/S-FSVI/setup.py pypi
requirements.txt pypi
  • Babel ==2.14.0
  • GitPython ==3.1.42
  • Jinja2 ==3.1.3
  • Markdown ==3.5.2
  • MarkupSafe ==2.1.5
  • PyOpenGL ==3.1.7
  • PyYAML ==6.0.1
  • Pygments ==2.17.2
  • QtPy ==2.4.1
  • Send2Trash ==1.8.2
  • Werkzeug ==3.0.1
  • absl-py ==2.1.0
  • antlr4-python3-runtime ==4.9.3
  • anyio ==4.3.0
  • appdirs ==1.4.4
  • argon2-cffi ==23.1.0
  • argon2-cffi-bindings ==21.2.0
  • arrow ==1.3.0
  • asdfghjkl ==0.1a2
  • asttokens ==2.4.1
  • async-lru ==2.0.4
  • attrs ==23.2.0
  • ax-platform ==0.3.3
  • backpack-for-pytorch ==1.6.0
  • beautifulsoup4 ==4.12.3
  • bleach ==6.1.0
  • botorch ==0.8.5
  • certifi ==2024.2.2
  • cffi ==1.16.0
  • charset-normalizer ==3.3.2
  • click ==8.1.7
  • cloudpickle ==3.0.0
  • cmake ==3.28.3
  • comm ==0.2.1
  • cycler ==0.12.1
  • debugpy ==1.8.1
  • decorator ==4.4.2
  • defusedxml ==0.7.1
  • dm-control ==1.0.11
  • dm-env ==1.6
  • dm-tree ==0.1.8
  • docker-pycreds ==0.4.0
  • einops ==0.7.0
  • exceptiongroup ==1.2.0
  • executing ==2.0.1
  • fastjsonschema ==2.19.1
  • filelock ==3.13.1
  • fonttools ==4.49.0
  • fqdn ==1.5.1
  • fsspec ==2024.2.0
  • gitdb ==4.0.11
  • glfw ==2.6.5
  • gpytorch ==1.10
  • grpcio ==1.60.1
  • gym ==0.26.2
  • gym-notices ==0.0.8
  • h11 ==0.14.0
  • httpcore ==1.0.3
  • httpx ==0.26.0
  • hydra-core ==1.3.2
  • hydra-submitit-launcher ==1.2.0
  • idna ==3.6
  • imageio ==2.34.0
  • imageio-ffmpeg ==0.4.9
  • importlib-metadata ==7.0.1
  • ipykernel ==6.29.2
  • ipython ==8.18.1
  • ipywidgets ==8.1.2
  • isoduration ==20.11.0
  • jedi ==0.19.1
  • joblib ==1.3.2
  • json5 ==0.9.17
  • jsonpointer ==2.4
  • jsonschema ==4.21.1
  • jsonschema-specifications ==2023.12.1
  • jupyter ==1.0.0
  • jupyter-console ==6.6.3
  • jupyter-events ==0.9.0
  • jupyter-lsp ==2.2.2
  • jupyter_client ==8.6.0
  • jupyter_core ==5.7.1
  • jupyter_server ==2.12.5
  • jupyter_server_terminals ==0.5.2
  • jupyterlab ==4.1.2
  • jupyterlab_pygments ==0.3.0
  • jupyterlab_server ==2.25.3
  • jupyterlab_widgets ==3.0.10
  • kiwisolver ==1.4.5
  • labmaze ==1.0.6
  • laplace-torch ==0.1a2
  • linear-operator ==0.4.0
  • lit ==17.0.6
  • lxml ==5.1.0
  • matplotlib ==3.5.1
  • matplotlib-inline ==0.1.6
  • mistune ==3.0.2
  • moviepy ==1.0.3
  • mpmath ==1.3.0
  • mujoco ==2.3.3
  • multipledispatch ==1.0.0
  • nbclient ==0.9.0
  • nbconvert ==7.16.1
  • nbformat ==5.9.2
  • nest-asyncio ==1.6.0
  • netcal ==1.3.5
  • networkx ==3.2.1
  • notebook ==7.1.0
  • notebook_shim ==0.2.4
  • numpy ==1.24.2
  • nvidia-cublas-cu11 ==11.10.3.66
  • nvidia-cublas-cu12 ==12.1.3.1
  • nvidia-cuda-cupti-cu11 ==11.7.101
  • nvidia-cuda-cupti-cu12 ==12.1.105
  • nvidia-cuda-nvrtc-cu11 ==11.7.99
  • nvidia-cuda-nvrtc-cu12 ==12.1.105
  • nvidia-cuda-runtime-cu11 ==11.7.99
  • nvidia-cuda-runtime-cu12 ==12.1.105
  • nvidia-cudnn-cu11 ==8.5.0.96
  • nvidia-cudnn-cu12 ==8.9.2.26
  • nvidia-cufft-cu11 ==10.9.0.58
  • nvidia-cufft-cu12 ==11.0.2.54
  • nvidia-curand-cu11 ==10.2.10.91
  • nvidia-curand-cu12 ==10.3.2.106
  • nvidia-cusolver-cu11 ==11.4.0.1
  • nvidia-cusolver-cu12 ==11.4.5.107
  • nvidia-cusparse-cu11 ==11.7.4.91
  • nvidia-cusparse-cu12 ==12.1.0.106
  • nvidia-nccl-cu11 ==2.14.3
  • nvidia-nccl-cu12 ==2.19.3
  • nvidia-nvjitlink-cu12 ==12.3.101
  • nvidia-nvtx-cu11 ==11.7.91
  • nvidia-nvtx-cu12 ==12.1.105
  • omegaconf ==2.3.0
  • opencv-python ==4.7.0.72
  • opt-einsum ==3.3.0
  • overrides ==7.7.0
  • packaging ==23.2
  • pandas ==2.2.0
  • pandocfilters ==1.5.1
  • parso ==0.8.3
  • pexpect ==4.9.0
  • pillow ==10.2.0
  • platformdirs ==4.2.0
  • plotly ==5.19.0
  • proglog ==0.1.10
  • prometheus_client ==0.20.0
  • prompt-toolkit ==3.0.43
  • protobuf ==4.25.3
  • psutil ==5.9.8
  • ptyprocess ==0.7.0
  • pure-eval ==0.2.2
  • pycparser ==2.21
  • pygame ==2.1.0
  • pyparsing ==3.1.1
  • pyro-api ==0.1.2
  • pyro-ppl ==1.9.0
  • python-dateutil ==2.8.2
  • python-json-logger ==2.0.7
  • pytz ==2024.1
  • pyzmq ==25.1.2
  • qtconsole ==5.5.1
  • referencing ==0.33.0
  • requests ==2.31.0
  • rfc3339-validator ==0.1.4
  • rfc3986-validator ==0.1.1
  • rpds-py ==0.18.0
  • scikit-learn ==1.4.1.post1
  • scipy ==1.12.0
  • seaborn ==0.13.2
  • sentry-sdk ==1.40.5
  • setproctitle ==1.3.3
  • six ==1.16.0
  • smmap ==5.0.1
  • sniffio ==1.3.0
  • soupsieve ==2.5
  • stack-data ==0.6.3
  • submitit ==1.5.1
  • sympy ==1.12
  • tenacity ==8.2.3
  • tensorboard ==2.16.2
  • tensorboard-data-server ==0.7.2
  • terminado ==0.18.0
  • threadpoolctl ==3.3.0
  • tikzplotlib ==0.9.8
  • tinycss2 ==1.2.1
  • tomli ==2.0.1
  • torch ==2.0.0
  • torchaudio ==2.0.1
  • torchtyping ==0.1.4
  • torchvision ==0.15.1
  • tornado ==6.4
  • tqdm ==4.66.2
  • traitlets ==5.14.1
  • triton ==2.0.0
  • typeguard ==2.13.3
  • types-python-dateutil ==2.8.19.20240106
  • typing_extensions ==4.9.0
  • tzdata ==2024.1
  • unfoldNd ==0.2.1
  • uri-template ==1.3.0
  • urllib3 ==2.2.1
  • wandb ==0.16.3
  • wcwidth ==0.2.13
  • webcolors ==1.13
  • webencodings ==0.5.1
  • websocket-client ==1.7.0
  • widgetsnbextension ==4.0.10
  • zipp ==3.17.0
setup.py pypi
  • ax-platform ==0.3.3
  • laplace-torch *
  • matplotlib ==3.5.1
  • numpy ==1.24.2
  • torch ==2.0.0
  • torchtyping ==0.1.4
  • torchvision ==0.15.1