Science Score: 54.0%

This score indicates how likely this project is to be science-related based on various indicators:

  • CITATION.cff file
    Found CITATION.cff file
  • codemeta.json file
    Found codemeta.json file
  • .zenodo.json file
    Found .zenodo.json file
  • DOI references
  • Academic publication links
    Links to: arxiv.org
  • Committers with academic emails
  • Institutional organization owner
  • JOSS paper metadata
  • Scientific vocabulary similarity
    Low similarity (10.3%) to scientific vocabulary
Last synced: 7 months ago

Repository

Basic Info
  • Host: GitHub
  • Owner: wearepal
  • License: apache-2.0
  • Language: Jupyter Notebook
  • Default Branch: main
  • Size: 3.68 MB
Statistics
  • Stars: 3
  • Watchers: 0
  • Forks: 0
  • Open Issues: 0
  • Releases: 0
Created 10 months ago · Last pushed 10 months ago
Metadata Files
Readme License Citation

README.md

performativeGYM

performativeGYM is a library for simulating performative prediction, a machine learning setting introduced by Perdomo et al. (2020). In performative prediction, the predictions themselves affect the world, so the distribution of the data encountered after deployment no longer matches the training distribution. An example is a classifier used by a bank to make lending decisions: bank customers then try to "game" the classifier in order to improve their chances of success.
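A toy version of this feedback loop fits in a few lines of plain Python. This is a sketch, not code from performativeGYM: the squared loss (z - θ)², the shift z ~ N(0.5·θ + 1, 1) and the function names are all assumptions. Each round deploys θ, samples the data it induces, and refits. Up to sampling noise the iterates follow θ ← 0.5·θ + 1, so they converge to θ = 2, a performatively stable point in the sense of Perdomo et al. (2020). In this toy case that point also minimizes the performative risk (1 - 0.5·θ)² + 1, but the two notions differ in general.

```python
import random


def sample_shifted(theta, n, a0=1.0, a1=0.5, std=1.0, seed=0):
    """Hypothetical shift: n draws from N(a1 * theta + a0, std**2)."""
    rng = random.Random(seed)
    return [a1 * theta + a0 + std * rng.gauss(0.0, 1.0) for _ in range(n)]


def repeated_retraining(theta0, rounds=20, n=20_000):
    """Deploy theta, observe the data it induces, refit, and repeat.

    Under the squared loss (z - theta)**2 the refit is simply the sample mean.
    """
    theta = theta0
    history = [theta]
    for t in range(rounds):
        z = sample_shifted(theta, n, seed=t)  # data reacts to the deployed model
        theta = sum(z) / len(z)  # retrain on the shifted data
        history.append(theta)
    return history


history = repeated_retraining(theta0=0.0)
print(f"after {len(history) - 1} rounds: theta = {history[-1]:.3f}")
```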

The code in this project is split into two parts: the library itself, in the directory performative_gym/, which contains the core definitions and the implementations of the proposed methods; and the examples/ directory, which contains implementations of concrete performative prediction scenarios.

The library

The library is written in JAX. It contains implementations of many algorithms that have been proposed in the literature:

All these methods are implemented as subclasses of the following abstract base class:

```python
from abc import ABC, abstractmethod
from typing import Callable

from jax import Array


class Optimizer(ABC):
    def __init__(
        self,
        params: Array,
        lr: float,
        loss_fn: Callable[[Array, Array, Array], Array],
        projfn_default_doc := None,
    ):
        pass
```

On every call to the .step() method, a method must update the given parameters using the features x and labels y, and return the new parameters. Each method is constructed with an initial set of parameters, a learning rate, a loss function, and a projection function (which maps parameter values back into the allowed range). The parameters, the loss function and the projection function are specific to the concrete setting in which the experiment is run. Some methods need more information than that; for example, many need the distribution shift function in differentiable form.
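The contract of .step() can be illustrated without JAX. The sketch below uses a scalar stand-in for the Optimizer base class (plain floats, per-sample loss) and a hypothetical FiniteDiffGD method that takes one projected gradient step on the mean loss over the batch, with a central finite difference replacing JAX autodiff. None of these names come from performative_gym, and the library's own methods, such as RGD, may be implemented differently.

```python
from abc import ABC, abstractmethod
from typing import Callable, List, Optional, Sequence

LossFn = Callable[[float, float, Optional[float]], float]


class Optimizer(ABC):
    """Scalar stand-in for the Optimizer base class (plain floats, no JAX)."""

    def __init__(
        self,
        params: float,
        lr: float,
        loss_fn: LossFn,
        proj_fn: Callable[[float], float] = (lambda params: params),
    ):
        self.current_params = params
        self.lr = lr
        self.loss_fn = loss_fn
        self.proj_fn = proj_fn
        self.params_history: List[float] = [params]
        self.i = 0

    @abstractmethod
    def step(self, params: float, x: Sequence[float], y: Optional[float]) -> float:
        pass


class FiniteDiffGD(Optimizer):
    """Hypothetical method: one projected gradient step on the batch-mean loss."""

    def step(self, params, x, y=None, eps=1e-5):
        def mean_loss(p):
            # Average the per-sample loss over the batch at parameter value p
            return sum(self.loss_fn(p, xi, y) for xi in x) / len(x)

        # Central finite difference stands in for jax.grad
        grad = (mean_loss(params + eps) - mean_loss(params - eps)) / (2 * eps)
        # Gradient step, then project back into the allowed range
        new_params = self.proj_fn(params - self.lr * grad)
        self.current_params = new_params
        self.params_history.append(new_params)
        self.i += 1
        return new_params


def clip_unit(p: float) -> float:
    # Projection onto [-1, 1], like jnp.clip in the minimal example
    return min(max(p, -1.0), 1.0)


# Squared loss pulls the parameter toward the batch mean (0.5)
opt = FiniteDiffGD(0.0, lr=0.1, loss_fn=lambda p, x, y: (x - p) ** 2, proj_fn=clip_unit)
new_params = opt.step(0.0, x=[0.4, 0.6])  # gradient at 0 is -1.0, so new_params ≈ 0.1
```

Keeping the projection inside .step() means every returned parameter is already feasible, which matches the role the projection function plays in the interface above.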

The examples

In the examples directory, several concrete scenarios are implemented, which can be run with any of the methods defined in the library. As mentioned above, each scenario needs to define, at minimum, the initial parameters, the loss function and the projection function. In addition, many methods also need a differentiable distribution shift function.
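Why some methods need the shift in differentiable form can be seen in the setting of the minimal example that follows (loss θ·z, data drawn from a Gaussian with mean A1·θ + A0 and standard deviation STD). Writing a draw as z = A1·θ + A0 + STD·ε makes the data an explicit function of θ. The expected loss on the data induced by θ, the performative risk, is then A1·θ² + A0·θ, with gradient 2·A1·θ + A0; a gradient that holds the data fixed sees only A1·θ + A0 and misses the A1·θ term. The sketch below checks this numerically with the standard library, using central finite differences in place of JAX autodiff; the function names are hypothetical and not part of performative_gym.

```python
import math
import random


def make_noise(n, seed=0):
    # Fixed standard-normal draws (common random numbers for both gradients)
    rng = random.Random(seed)
    return [rng.gauss(0.0, 1.0) for _ in range(n)]


def shift(theta, eps, a0=5.0, a1=1.0, std=1.0):
    # Reparameterized shift: z = a1 * theta + a0 + std * eps, differentiable in theta
    return [a1 * theta + a0 + std * e for e in eps]


def mean_loss(theta_model, theta_data, eps):
    # Linear loss theta * z, averaged over data drawn at theta_data
    zs = shift(theta_data, eps)
    return math.fsum(theta_model * z for z in zs) / len(zs)


def central_diff(f, x, h=1e-4):
    return (f(x + h) - f(x - h)) / (2 * h)


eps = make_noise(100_000)
theta = 0.9
# Data held fixed at the current theta: only the loss depends on t
repeated_grad = central_diff(lambda t: mean_loss(t, theta, eps), theta)
# Gradient of the performative risk: t also moves the data distribution
performative_grad = central_diff(lambda t: mean_loss(t, t, eps), theta)
print(f"repeated: {repeated_grad:.3f}, performative: {performative_grad:.3f}")
```

The two estimates differ by A1·θ = 0.9 regardless of the sampled noise, because the same draws ε are reused for both.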

Here is a minimal example in which the model is a simple linear model with a 1D weight vector and the data is sampled from a Gaussian whose mean depends linearly on the weight vector:

```python
from dataclasses import dataclass

import jax
import jax.numpy as jnp
from jax import Array

from performative_gym import RGD
from performative_gym.utils import initialize_params


@dataclass
class Minimal:
    A0: float = 5
    A1: float = 1
    STD: float = 1
    n: int = 10000
    iterations: int = 30
    seed: int = 0
    lr: float = 0.1

    def loss_fn(self, params: Array, x: Array, y: None) -> Array:
        # Simple linear loss function
        return params * x

    def proj_fn(self, params: Array) -> Array:
        # Keep the weight inside [-1, 1]
        return jnp.clip(params, -1.0, 1.0)

    def shift_data_distribution(self, params: Array, n: int) -> Array:
        # Normal distribution with mean A1 * params + A0 and std STD
        z = jax.random.normal(jax.random.PRNGKey(self.seed), (n,))
        return jnp.expand_dims((self.A1 * params + self.A0) + z * self.STD, axis=1)

    def initial_params(self) -> Array:
        return 0.85 + initialize_params((1,), self.seed) * 0.1

    def train(self) -> RGD:
        params = self.initial_params()
        method = RGD(params, lr=self.lr, loss_fn=self.loss_fn, proj_fn=self.proj_fn)

        for i in range(self.iterations):
            z = self.shift_data_distribution(params, self.n)
            # Perform gradient descent step
            params = method.step(params, x=z, y=None)
            # Compute current loss
            current_loss = jnp.mean(self.loss_fn(params, x=z, y=None))
            print(f"Iteration {i + 1}/{self.iterations}, Loss: {current_loss:.4f}")
        return method


if __name__ == "__main__":
    Minimal().train()
```

It is not necessary to use dataclasses for this, but it is convenient. The model and the data can be anything, as long as the loss function, the projection function and the shift function can handle them.
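The behaviour of this example can be anticipated without JAX. Because the loss θ·z is linear and, with the default A0 = 5 and A1 = 1, the data mean A1·θ + A0 stays between 4 and 6 for θ in [-1, 1], the gradient with the data held fixed (the mean of z) is always positive, so each step lowers θ until the projection pins it at -1. The sketch below re-creates the scenario with scalar parameters, assuming that RGD takes plain projected gradient steps with the data held fixed (this README does not confirm it); MinimalSim and the starting value 0.9 are illustrative stand-ins for the initialization above.

```python
import random
from dataclasses import dataclass
from typing import List


@dataclass
class MinimalSim:
    """Hypothetical stdlib re-creation of the Minimal scenario (scalar params, no JAX)."""

    A0: float = 5
    A1: float = 1
    STD: float = 1
    n: int = 10_000
    iterations: int = 30
    seed: int = 0
    lr: float = 0.1

    def proj_fn(self, params: float) -> float:
        return min(max(params, -1.0), 1.0)

    def shift_data_distribution(self, params: float) -> List[float]:
        # Same seed on every call, mirroring the fixed PRNGKey in the example
        rng = random.Random(self.seed)
        return [self.A1 * params + self.A0 + self.STD * rng.gauss(0.0, 1.0)
                for _ in range(self.n)]

    def train(self, params: float = 0.9) -> List[float]:
        history = [params]
        for _ in range(self.iterations):
            z = self.shift_data_distribution(params)
            # d/dparams of mean(params * z) with z held fixed is mean(z)
            grad = sum(z) / len(z)
            params = self.proj_fn(params - self.lr * grad)
            history.append(params)
        return history


history = MinimalSim().train()
print(history[:6])
```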

The existing examples are:

  • credit.py: the GiveMeSomeCredit credit-scoring dataset
  • linear.py: 1D Gaussian with linear dependency on the model weights
  • nonlinear.py: 1D Gaussian with non-linear dependency on the model weights
  • mixture.py: a mixture of Gaussians
  • pricing.py: multivariate Gaussian
  • cosine.py: 1D Gaussian with a cosine loss function

If you want to run these examples, see below for the instructions.

Usage

Install dependencies

With uv:

```sh
uv sync
```

With pip:

```sh
pip install -e .
```

Run examples

With uv:

```sh
uv run python examples/linear.py
```

With pip:

```sh
python examples/linear.py
```

If you pass the --help flag, the script prints a help message describing the available command-line arguments.
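The repository lists tyro among its dependencies, which suggests that the example scripts derive their command-line flags from configuration dataclasses such as Minimal above (an assumption; the scripts themselves are authoritative). The idea can be sketched with the standard library alone: one typed flag per dataclass field, defaulting to the field's default. MinimalConfig and parse_config are hypothetical names, and argparse stands in for tyro here.

```python
import argparse
from dataclasses import dataclass, fields


@dataclass
class MinimalConfig:
    """Hypothetical config mirroring the fields of the Minimal example."""

    A0: float = 5
    A1: float = 1
    STD: float = 1
    n: int = 10000
    iterations: int = 30
    seed: int = 0
    lr: float = 0.1


def parse_config(cls, argv=None):
    parser = argparse.ArgumentParser(description=cls.__doc__)
    for f in fields(cls):
        # One --flag per field, typed and defaulted from the dataclass
        parser.add_argument(f"--{f.name}", type=f.type, default=f.default)
    return cls(**vars(parser.parse_args(argv)))


cfg = parse_config(MinimalConfig, ["--lr", "0.05", "--iterations", "5"])
print(cfg)
```

Running such a parser with --help lists every field as a flag with its type, which is the kind of help message described above.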

License

This project is licensed under the Apache License 2.0. See the LICENSE file for details.

Owner

  • Name: Predictive Analytics Lab
  • Login: wearepal
  • Kind: organization
  • Location: University of Sussex, Brighton, UK

Citation (CITATION.bib)

@misc{sanguino2025decoupled,
      title={The Decoupled Risk Landscape in Performative Prediction}, 
      author={Javier Sanguino and Thomas Kehrenberg and Jose A. Lozano and Novi Quadrianto},
      year={2025},
      eprint={2506.09044},
      archivePrefix={arXiv},
      primaryClass={cs.LG},
      url={https://arxiv.org/abs/2506.09044}, 
}

GitHub Events

Total
  • Watch event: 1
  • Push event: 12
  • Public event: 1
  • Pull request review event: 1
  • Pull request review comment event: 1
Last Year
  • Watch event: 1
  • Push event: 12
  • Public event: 1
  • Pull request review event: 1
  • Pull request review comment event: 1

Committers

Last synced: 9 months ago

All Time
  • Total Commits: 19
  • Total Committers: 3
  • Avg Commits per committer: 6.333
  • Development Distribution Score (DDS): 0.105
Past Year
  • Commits: 19
  • Committers: 3
  • Avg Commits per committer: 6.333
  • Development Distribution Score (DDS): 0.105
Top Committers
Name Email Commits
Thomas M Kehrenberg t****8@p****t 17
fjsanguino j****8@g****m 1
Ng Yeat Jeng n****g@o****m 1

Issues and Pull Requests

Last synced: 9 months ago

All Time
  • Total issues: 0
  • Total pull requests: 1
  • Average time to close issues: N/A
  • Average time to close pull requests: about 2 hours
  • Total issue authors: 0
  • Total pull request authors: 1
  • Average comments per issue: 0
  • Average comments per pull request: 2.0
  • Merged pull requests: 1
  • Bot issues: 0
  • Bot pull requests: 0
Past Year
  • Issues: 0
  • Pull requests: 1
  • Average time to close issues: N/A
  • Average time to close pull requests: about 2 hours
  • Issue authors: 0
  • Pull request authors: 1
  • Average comments per issue: 0
  • Average comments per pull request: 2.0
  • Merged pull requests: 1
  • Bot issues: 0
  • Bot pull requests: 0
Top Authors
Issue Authors
Pull Request Authors
  • yjng (1)

Dependencies

pyproject.toml pypi
  • jax >=0.6.1
  • neural-tangents >=0.5.0
  • optax >=0.2.5
  • pandas >=2.3.0
  • scikit-learn >=1.7.0
  • tqdm >=4.67.1
  • tyro >=0.9.24
uv.lock pypi
  • absl-py 2.3.0
  • annotated-types 0.7.0
  • astunparse 1.6.3
  • attrs 25.3.0
  • certifi 2025.4.26
  • charset-normalizer 3.4.2
  • chex 0.1.89
  • click 8.2.1
  • colorama 0.4.6
  • dm-tree 0.1.9
  • docstring-parser 0.16
  • flatbuffers 25.2.10
  • frozendict 2.4.6
  • gast 0.6.0
  • gitdb 4.0.12
  • gitpython 3.1.44
  • google-pasta 0.2.0
  • grpcio 1.73.0
  • h5py 3.14.0
  • idna 3.10
  • jax 0.6.1
  • jaxlib 0.6.1
  • joblib 1.5.1
  • keras 3.10.0
  • libclang 18.1.1
  • markdown 3.8
  • markdown-it-py 3.0.0
  • markupsafe 3.0.2
  • mdurl 0.1.2
  • ml-dtypes 0.5.1
  • namex 0.1.0
  • neural-tangents 0.5.0
  • neural-tangents 0.6.5
  • numpy 2.1.3
  • numpy 2.3.0
  • opt-einsum 3.4.0
  • optax 0.2.5
  • optree 0.16.0
  • packaging 25.0
  • pandas 2.3.0
  • pandas-stubs 2.2.3.250527
  • performative-gym 0.1.0
  • platformdirs 4.3.8
  • protobuf 5.29.5
  • protobuf 6.31.1
  • psutil 7.0.0
  • pydantic 2.11.5
  • pydantic-core 2.33.2
  • pygments 2.19.1
  • python-dateutil 2.9.0.post0
  • python-type-stubs 0.1.6.dev0
  • pytz 2025.2
  • pyyaml 6.0.2
  • requests 2.32.4
  • rich 14.0.0
  • scikit-learn 1.7.0
  • scipy 1.15.3
  • sentry-sdk 2.29.1
  • setproctitle 1.3.6
  • setuptools 80.9.0
  • shtab 1.7.2
  • six 1.17.0
  • smmap 5.0.2
  • tensorboard 2.19.0
  • tensorboard-data-server 0.7.2
  • tensorflow 2.19.0
  • tensorflow-io-gcs-filesystem 0.37.1
  • termcolor 3.1.0
  • tf2jax 0.3.7
  • threadpoolctl 3.6.0
  • toolz 1.0.0
  • tqdm 4.67.1
  • typeguard 4.4.3
  • types-pytz 2025.2.0.20250516
  • typing-extensions 4.14.0
  • typing-inspection 0.4.1
  • tyro 0.9.24
  • tzdata 2025.2
  • urllib3 2.4.0
  • wandb 0.20.1
  • werkzeug 3.1.3
  • wheel 0.45.1
  • wrapt 1.17.2