https://github.com/google-deepmind/enn

Science Score: 36.0%

This score indicates how likely this project is to be science-related based on various indicators:

  • CITATION.cff file
  • codemeta.json file
    Found codemeta.json file
  • .zenodo.json file
    Found .zenodo.json file
  • DOI references
  • Academic publication links
    Links to: arxiv.org
  • Committers with academic emails
  • Institutional organization owner
  • JOSS paper metadata
  • Scientific vocabulary similarity
    Low similarity (15.0%) to scientific vocabulary

Keywords from Contributors

deep-neural-networks distributed research reinforcement-learning jax neuralgcm deep-reinforcement-learning transformer advertising distributed-training
Last synced: 6 months ago

Repository

Basic Info
  • Host: GitHub
  • Owner: google-deepmind
  • License: apache-2.0
  • Language: Python
  • Default Branch: master
  • Size: 1.41 MB
Statistics
  • Stars: 312
  • Watchers: 13
  • Forks: 61
  • Open Issues: 16
  • Releases: 0
Created almost 5 years ago · Last pushed about 1 year ago
Metadata Files
Readme Contributing License

README.md

Epistemic Neural Networks

A library for neural networks that know what they don't know.

For background information, please see the accompanying paper (arXiv:2107.08924).

Introduction

Conventional neural networks generate marginal predictions: given one input, they predict one label. If a neural network outputs a 50:50 probability, it remains unclear whether that reflects genuine ambiguity in the input or merely insufficient training data. These two possibilities can be distinguished by joint predictions: given multiple inputs, predict multiple labels jointly.
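This distinction can be made concrete with a toy example. Below is an illustrative sketch (the numbers are invented for this example, not taken from the library) of two models whose marginal predictions agree but whose joint predictions differ:

```python
import numpy as np

# Two models that both output a 50:50 marginal for each of two inputs,
# yet encode very different beliefs.

# Model A: genuine ambiguity -- the two labels are independent coin flips,
# so the joint over (y1, y2) is just the product of marginals.
joint_ambiguity = np.array([[0.25, 0.25],
                            [0.25, 0.25]])

# Model B: epistemic uncertainty -- the model is unsure which of two
# deterministic hypotheses holds, so the labels are perfectly correlated.
joint_uncertainty = np.array([[0.5, 0.0],
                              [0.0, 0.5]])

# The marginals are identical...
print(joint_ambiguity.sum(axis=1))    # [0.5 0.5]
print(joint_uncertainty.sum(axis=1))  # [0.5 0.5]

# ...but the joints differ, which is exactly what joint predictions reveal.
print(np.allclose(joint_ambiguity, joint_uncertainty))  # False
```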

[Figure: rabbit or duck illusion]

An epistemic neural network (ENN) makes predictions given a single input x, but also an epistemic index z. The ENN controls the index z and uses it to produce joint predictions over multiple inputs x_1,..,x_t which may be different from just the product of marginals.
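To illustrate, here is a toy "ensemble as ENN" in plain NumPy, where the epistemic index z selects an ensemble member. This is a hypothetical toy model for intuition only, not the library's API:

```python
import numpy as np

rng = np.random.default_rng(0)

# A toy ENN: an ensemble of K linear "particles", one scalar weight each.
K = 8
weights = rng.normal(size=(K,))

def indexer(rng):
    # Sample an epistemic index z from the reference distribution.
    return rng.integers(K)

def apply(x, z):
    # f(x, z): apply the ensemble member selected by z.
    return weights[z] * x

# Joint prediction over x_1..x_t: sample ONE z and apply it to every input.
# This couples the outputs, unlike a product of per-input marginals.
xs = np.array([1.0, 2.0, 3.0])
z = indexer(rng)
joint_outputs = apply(xs, z)

# Marginal prediction: average over many index samples per input.
marginal = np.mean([apply(xs, indexer(rng)) for _ in range(1000)], axis=0)
print(joint_outputs.shape, marginal.shape)  # (3,) (3,)
```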

[Figure: network diagrams]

An ENN provides a general interface for thinking about uncertainty estimation in deep learning. Note that all existing approaches to uncertainty modeling, such as Bayesian neural networks (BNNs), can be expressed as ENNs. However, some ENN architectures are not natural to express as BNNs. This library provides interfaces and tools for the design and training of ENNs.

Technical overview

The enn library provides a lightweight interface for ENNs, implemented on top of JAX and Haiku. If you want to use the enn library, we highly recommend first familiarizing yourself with those libraries.

We outline the key high-level interfaces for our code in base.py:

  • EpistemicNetwork: a convenient pairing of a Haiku-transformed function and an index sampler.
    • apply: Haiku-style apply function taking params, x, z -> f_params(x, z)
    • init: Haiku-style init function taking key, x, z -> params_init
    • indexer: generates a sample from the reference index distribution, taking key -> z
  • LossFn: Given an ENN, parameters, and data: how to compute a loss.
    • Takes: enn, params, batch, key
    • Outputs: loss, metrics
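A pared-down transcription of these interfaces might look as follows (the names follow the description above; the exact types and signatures in base.py may differ):

```python
from typing import Any, Callable, Dict, NamedTuple, Tuple

# Loose type aliases for illustration only.
Array = Any
Params = Any
Index = Any
RngKey = Any

class EpistemicNetwork(NamedTuple):
    """A pairing of Haiku-style functions with an index sampler."""
    apply: Callable[[Params, Array, Index], Array]   # params, x, z -> f_params(x, z)
    init: Callable[[RngKey, Array, Index], Params]   # key, x, z -> params_init
    indexer: Callable[[RngKey], Index]               # key -> z

# LossFn: given an ENN, parameters, a batch, and a key, compute (loss, metrics).
LossFn = Callable[[EpistemicNetwork, Params, Any, RngKey],
                  Tuple[float, Dict[str, float]]]
```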

We then use these high-level concepts to build and train ENNs.

Getting started

You can get started in our colab tutorial without installing anything on your machine.

Installation

We have tested ENN on Python 3.7. To install the dependencies:

  1. Optional: We recommend using a Python virtual environment to manage your dependencies, so as not to clobber your system installation:

    ```bash
    python3 -m venv enn
    source enn/bin/activate
    pip install --upgrade pip setuptools
    ```

  2. Install ENN directly from github:

    ```bash
    pip install git+https://github.com/deepmind/enn
    ```

  3. Test that you can load ENN by training a simple ensemble ENN.

    ```python
    from enn.loggers import TerminalLogger

    from enn import losses
    from enn import networks
    from enn import supervised
    from enn.supervised import regression_data
    import optax

    # A small dummy dataset
    dataset = regression_data.make_dataset()

    # Logger
    logger = TerminalLogger('supervised_regression')

    # ENN
    enn = networks.MLPEnsembleMatchedPrior(
        output_sizes=[50, 50, 1],
        num_ensemble=10,
    )

    # Loss
    loss_fn = losses.average_single_index_loss(
        single_loss=losses.L2LossWithBootstrap(),
        num_index_samples=10,
    )

    # Optimizer
    optimizer = optax.adam(1e-3)

    # Train the experiment
    experiment = supervised.Experiment(
        enn, loss_fn, optimizer, dataset, seed=0, logger=logger)
    experiment.train(FLAGS.num_batch)
    ```

  4. Optional: run the tests by executing ./test.sh from the ENN root directory.

Epinet

One of the key contributions of our paper is the epinet: a new ENN architecture that can supplement any conventional NN and be trained to estimate uncertainty.

An epinet is a neural network with privileged access to inputs and outputs of activation units in the base network. A subset of these inputs and outputs, denoted by $\phi_\zeta(x)$, are taken as input to the epinet along with an epistemic index $z$. For epinet parameters $\eta$, the epinet outputs $\sigma_\eta(\phi_\zeta(x), z)$. To produce an ENN, the output of the epinet is added to that of the base network, though with a "stop gradient" written $[[\cdot]]$:

$$ f_\theta(x, z) = \mu_\zeta(x) + \sigma_\eta([[\phi_\zeta(x)]], z). $$
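For intuition, here is a toy numerical sketch of this combination. The shapes and functions below are illustrative stand-ins; in the library the base network and epinet are real networks (see networks/epinet), and during training the stop gradient would be applied with jax.lax.stop_gradient:

```python
import numpy as np

rng = np.random.default_rng(0)

def base_features(x):
    # phi_zeta(x): features drawn from the base network (toy stand-in).
    return np.tanh(x)

def base_output(x):
    # mu_zeta(x): the base network's prediction (toy stand-in).
    return 2.0 * x

# eta: toy epinet parameters.
W = rng.normal(size=(3,))

def epinet(phi, z):
    # sigma_eta(phi, z): a small network taking features and the index z.
    return W[0] * phi + W[1] * z + W[2] * phi * z

def enn_forward(x, z):
    phi = base_features(x)
    # In training, gradients would be stopped here -- the [[.]] in the equation:
    # phi = jax.lax.stop_gradient(phi)
    return base_output(x) + epinet(phi, z)

print(enn_forward(0.5, 1.0))
```

The stop gradient means the epinet's loss does not update the base network's features, so the epinet can be trained on top of a frozen, pre-trained base network.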

We can visualize this network architecture:

[Figure: epinet diagram]

As part of our release, we include an epinet colab that loads a pre-trained base network and epinet on ImageNet. The core network logic for the epinet is available in networks/epinet.

Citing

If you use ENN in your work, please cite the accompanying paper:

```bibtex
@article{osband2022epistemic,
  title={Epistemic neural networks},
  author={Osband, Ian and Wen, Zheng and Asghari, Seyed Mohammad and Dwaracherla, Vikranth and Ibrahimi, Morteza and Lu, Xiuyuan and Van Roy, Benjamin},
  journal={arXiv preprint arXiv:2107.08924},
  year={2022}
}
```

Owner

  • Name: Google DeepMind
  • Login: google-deepmind
  • Kind: organization

GitHub Events

Total
  • Watch event: 18
  • Push event: 6
  • Pull request event: 1
  • Fork event: 5
  • Create event: 1
Last Year
  • Watch event: 18
  • Push event: 6
  • Pull request event: 1
  • Fork event: 5
  • Create event: 1

Committers

Last synced: 10 months ago

All Time
  • Total Commits: 200
  • Total Committers: 16
  • Avg Commits per committer: 12.5
  • Development Distribution Score (DDS): 0.58
Past Year
  • Commits: 1
  • Committers: 1
  • Avg Commits per committer: 1.0
  • Development Distribution Score (DDS): 0.0
Top Committers
Name Email Commits
Mohammad Asghari s****i@g****m 84
Ian Osband i****d@g****m 37
DeepMind n****y@g****m 20
Lucy Lu l****u@g****m 15
Morteza Ibrahimi m****i@g****m 12
Mehdi Jafarnia j****a@g****m 10
Jake VanderPlas v****s@g****m 5
Vikranth Dwaracherla v****d@g****m 4
Peter Hawkins p****s@g****m 4
Roman Novak r****n@g****m 2
Rebecca Chen r****n@g****m 2
Yilei Yang y****g@g****m 1
Sergei Lebedev s****v@g****m 1
Saran Tunyasuvunakool s****a@g****m 1
Zheng Wen z****n@g****m 1
Grace Lam g****m@g****m 1
Committer Domains (Top 20 + Academic)

Issues and Pull Requests

Last synced: 10 months ago

All Time
  • Total issues: 10
  • Total pull requests: 19
  • Average time to close issues: about 2 months
  • Average time to close pull requests: 11 days
  • Total issue authors: 8
  • Total pull request authors: 5
  • Average comments per issue: 0.1
  • Average comments per pull request: 0.05
  • Merged pull requests: 4
  • Bot issues: 0
  • Bot pull requests: 15
Past Year
  • Issues: 0
  • Pull requests: 2
  • Average time to close issues: N/A
  • Average time to close pull requests: 2 days
  • Issue authors: 0
  • Pull request authors: 1
  • Average comments per issue: 0
  • Average comments per pull request: 0.0
  • Merged pull requests: 1
  • Bot issues: 0
  • Bot pull requests: 2
Top Authors
Issue Authors
  • orbulon (3)
  • aaprasad (1)
  • hstojic (1)
  • rudibhargava (1)
  • MaxGhenis (1)
  • monsebo (1)
  • themantalope (1)
  • fazaghifari (1)
Pull Request Authors
  • copybara-service[bot] (21)
  • MaxGhenis (1)
  • amrzv (1)
  • shyams2 (1)
  • anukaal (1)

Dependencies

setup.py pypi
  • absl-py *
  • chex *
  • dataclasses *
  • dm-acme ==0.4.0
  • dm-haiku *
  • jax *
  • jaxlib *
  • jaxline *
  • matplotlib *
  • neural-tangents *
  • numpy *
  • optax *
  • pandas *
  • plotnine *
  • rlax *
  • scikit-image *
  • scikit-learn *
  • scipy *
  • six *
  • tensorflow ==2.8.0
  • tensorflow-datasets ==4.4.0
  • tensorflow_probability ==0.15.0
  • termcolor *
  • typing-extensions *
.github/workflows/ci.yml actions
  • actions/checkout v2 composite
  • actions/setup-python v1 composite