jax-privacy
Algorithms for Privacy-Preserving Machine Learning in JAX
Science Score: 54.0%
This score indicates how likely this project is to be science-related based on various indicators:
- ✓ CITATION.cff file: found
- ✓ codemeta.json file: found
- ✓ .zenodo.json file: found
- ○ DOI references
- ✓ Academic publication links: links to arxiv.org
- ○ Committers with academic emails
- ○ Institutional organization owner
- ○ JOSS paper metadata
- ○ Scientific vocabulary similarity: low similarity (15.6%) to scientific vocabulary
Repository
Algorithms for Privacy-Preserving Machine Learning in JAX
Basic Info
Statistics
- Stars: 96
- Watchers: 9
- Forks: 14
- Open Issues: 2
- Releases: 1
Metadata Files
README.md
JAX-Privacy: Algorithms for Privacy-Preserving Machine Learning in JAX
This repository contains:
- A production-focused API for differentially private (DP) training of ML models in JAX and Keras.
- A library of core components for implementing differentially private machine learning algorithms in JAX.
- A JAX-based DP machine learning pipeline, built from the library's components, for experimenting with image classification models.
This code is open-sourced with the main objective of transparency and reproducibility for research purposes, and includes production-focused APIs for differentially private machine learning. Some rough edges should be expected, especially in the research components.
New: Production-Focused JAX Privacy Library
We are excited to introduce a more production-focused JAX Privacy API designed to simplify the development of differentially private (DP) machine learning models, including large language models (LLMs).
Key Features:
- Algorithm Support: Currently supports the DP-SGD (Differentially Private Stochastic Gradient Descent) algorithm. We are actively working on incorporating more DP algorithms (e.g. DP-FTRL) in the near future.
- Framework Integration: The library provides APIs tailored for different JAX-based development experiences:
- Keras: A high-level API, excellent for common tasks like fine-tuning LLMs. See Keras API simple example and Gemma fine-tuning notebook to get started.
- Flax Linen: Offers greater flexibility for custom model architectures and training loops, at the cost of some additional boilerplate. See MNIST notebook to get started.
- Raw JAX: Provides the lowest-level control. Recommended for researchers who want to test new ideas, for users of a framework not listed above (e.g. Equinox), or for anyone who wants a more NumPy-like experience.
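The DP-SGD algorithm listed above can be illustrated in raw JAX. The sketch below is not the jax_privacy API: the toy linear-regression loss, the function name `dp_sgd_step` and the default hyperparameter values are hypothetical. It shows the core steps: per-example gradients, per-example L2 clipping to a norm bound C, Gaussian noise with standard deviation `noise_multiplier * C` added to the summed gradient, then averaging and an SGD update.

```python
import jax
import jax.numpy as jnp


def loss_fn(params, x, y):
    # Hypothetical toy loss: squared error of a linear model on ONE example.
    pred = jnp.dot(x, params["w"]) + params["b"]
    return 0.5 * (pred - y) ** 2


def dp_sgd_step(params, batch_x, batch_y, key,
                l2_clip=1.0, noise_multiplier=1.1, learning_rate=0.1):
    # 1. Per-example gradients: vmap the gradient over the batch axis.
    per_example_grads = jax.vmap(jax.grad(loss_fn), in_axes=(None, 0, 0))(
        params, batch_x, batch_y)

    # 2. Global L2 norm of each example's gradient across all parameter leaves.
    leaves = jax.tree_util.tree_leaves(per_example_grads)
    sq_norms = sum(jnp.sum(g.reshape((g.shape[0], -1)) ** 2, axis=1)
                   for g in leaves)
    norms = jnp.sqrt(sq_norms)
    # Scale by min(1, C / ||g_i||) so every example has norm at most C.
    scale = jnp.minimum(1.0, l2_clip / (norms + 1e-12))

    def clip_and_sum(g):
        # Broadcast the per-example scale over parameter dims, sum over batch.
        s = scale.reshape((-1,) + (1,) * (g.ndim - 1))
        return jnp.sum(g * s, axis=0)

    summed = jax.tree_util.tree_map(clip_and_sum, per_example_grads)

    # 3. Add Gaussian noise with std = noise_multiplier * C to each leaf.
    flat, treedef = jax.tree_util.tree_flatten(summed)
    keys = jax.random.split(key, len(flat))
    noisy = [leaf + noise_multiplier * l2_clip * jax.random.normal(k, leaf.shape)
             for leaf, k in zip(flat, keys)]

    # 4. Average over the batch and take a plain SGD step.
    batch_size = batch_x.shape[0]
    noisy_mean = jax.tree_util.tree_unflatten(
        treedef, [n / batch_size for n in noisy])
    return jax.tree_util.tree_map(
        lambda p, g: p - learning_rate * g, params, noisy_mean)


if __name__ == "__main__":
    params = {"w": jnp.zeros(3), "b": jnp.array(0.0)}
    x, y = jnp.ones((4, 3)), jnp.ones(4)
    new_params = dp_sgd_step(params, x, y, jax.random.PRNGKey(0))
    print(new_params)
```

A real DP training run also depends on how batches are sampled and on a privacy accountant that turns the noise multiplier, sampling rate and number of steps into an (ε, δ) guarantee; this sketch does neither.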
This new JAX Privacy API aims to provide a more streamlined and robust experience for building DP ML models, complementing the existing research-focused components.
We believe this new API will significantly lower the barrier to implementing DP in your machine learning projects.
Installation
Note: to ensure that your installation is compatible with your local accelerators (such as a GPU), we recommend first following the corresponding JAX installation instructions for your platform.
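Once JAX is installed, one quick way to check that it sees your accelerator is to query the default backend and the visible devices. This uses the standard JAX functions `jax.default_backend` and `jax.devices`; the helper name `describe_backend` is hypothetical.

```python
import jax


def describe_backend():
    """Return JAX's default backend name and the platforms of visible devices."""
    backend = jax.default_backend()  # e.g. "cpu", "gpu" or "tpu"
    platforms = [device.platform for device in jax.devices()]
    return backend, platforms


if __name__ == "__main__":
    backend, platforms = describe_backend()
    print(f"default backend: {backend}; devices: {platforms}")
```

If only `cpu` is reported on a machine that has a GPU, the installed JAX build most likely lacks accelerator support.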
Option 1: Static Installation
This option is preferred if you want to reuse the library's functionality without modifying it. Install the package by running the following command:
pip install git+https://github.com/google-deepmind/jax_privacy
This will not install the training pipeline.
Option 2: Local Installation
This option is preferred if you want to build on top of our codebase or reproduce our results using the training pipeline.
- The first step is to clone the repository:
git clone https://github.com/google-deepmind/jax_privacy
- Then install the code. We recommend an editable install (pip install -e) so that modifications to the code are reflected when the package is imported:
cd jax_privacy
pip install -e .
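To confirm that an editable install resolves to your clone rather than to another copy, you can ask Python where the package would be imported from. This is a sketch using the standard-library `importlib.util.find_spec`; it assumes the import name is `jax_privacy` (matching the repository directory), and the helper name `package_location` is hypothetical.

```python
import importlib.util


def package_location(name="jax_privacy"):
    """Return the file a top-level package would be imported from, or None."""
    spec = importlib.util.find_spec(name)
    return None if spec is None else spec.origin


if __name__ == "__main__":
    # After `pip install -e .`, this path should point inside your clone.
    print(package_location())
```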
Reproducing Results
Unlocking High-Accuracy Differentially Private Image Classification through Scale
- Instructions: experiments/image_classification.
- arXiv link: https://arxiv.org/abs/2204.13650.
- BibTeX reference: link.
Unlocking Accuracy and Fairness in Differentially Private Image Classification
- Instructions: experiments/image_classification.
- arXiv link: https://arxiv.org/abs/2308.10888.
- BibTeX reference: link.
How to Cite This Repository
If you use code from this repository, please cite the following reference:
@software{jax-privacy2022github,
author = {Balle, Borja and Berrada, Leonard and Charles, Zachary and
Choquette-Choo, Christopher A and De, Soham and Doroshenko, Vadym and Dvijotham,
Dj and Galen, Andrew and Ganesh, Arun and Ghalebikesabi, Sahra and Hayes, Jamie
and Kairouz, Peter and McKenna, Ryan and McMahan, Brendan and Pappu, Aneesh and
Ponomareva, Natalia and Pravilov, Mikhail and Rush, Keith and Smith, Samuel L
and Stanforth, Robert},
title = {{JAX}-{P}rivacy: Algorithms for Privacy-Preserving Machine Learning in JAX},
url = {http://github.com/google-deepmind/jax_privacy},
version = {0.4.0},
year = {2025},
}
Contact
If you have any questions or feedback, you can contact us via email: jax-privacy-open-source@google.com.
Acknowledgements
License
All code is made available under the Apache 2.0 License. Model parameters are made available under the Creative Commons Attribution 4.0 International (CC BY 4.0) License.
See https://creativecommons.org/licenses/by/4.0/legalcode for more details.
Disclaimer
This is not an official Google product.
Owner
- Name: Google DeepMind
- Login: google-deepmind
- Kind: organization
- Website: https://www.deepmind.com/
- Repositories: 245
- Profile: https://github.com/google-deepmind
Citation (CITATION.cff)
cff-version: 1.2.0
message: "If you use this software, please cite it as below."
authors:
  - family-names: "Balle"
    given-names: "Borja"
  - family-names: "Berrada"
    given-names: "Leonard"
  - family-names: "Charles"
    given-names: "Zachary"
  - family-names: "Choquette-Choo"
    given-names: "Christopher A"
  - family-names: "De"
    given-names: "Soham"
  - family-names: "Doroshenko"
    given-names: "Vadym"
  - family-names: "Dvijotham"
    given-names: "Dj"
  - family-names: "Galen"
    given-names: "Andrew"
  - family-names: "Ganesh"
    given-names: "Arun"
  - family-names: "Ghalebikesabi"
    given-names: "Sahra"
  - family-names: "Hayes"
    given-names: "Jamie"
  - family-names: "Kairouz"
    given-names: "Peter"
  - family-names: "McKenna"
    given-names: "Ryan"
  - family-names: "McMahan"
    given-names: "Brendan"
  - family-names: "Pappu"
    given-names: "Aneesh"
  - family-names: "Ponomareva"
    given-names: "Natalia"
  - family-names: "Pravilov"
    given-names: "Mikhail"
  - family-names: "Rush"
    given-names: "Keith"
  - family-names: "Smith"
    given-names: "Samuel L"
  - family-names: "Stanforth"
    given-names: "Robert"
title: "JAX-Privacy"
version: 0.4.0
date-released: 2025-04-21
url: "https://github.com/google-deepmind/jax_privacy"
GitHub Events
Total
- Issues event: 4
- Watch event: 9
- Delete event: 3
- Member event: 3
- Issue comment event: 2
- Push event: 3
- Pull request review event: 3
- Pull request event: 5
- Fork event: 1
- Create event: 4
Last Year
- Issues event: 4
- Watch event: 9
- Delete event: 3
- Member event: 3
- Issue comment event: 2
- Push event: 3
- Pull request review event: 3
- Pull request event: 5
- Fork event: 1
- Create event: 4
Committers
Last synced: 8 months ago
Top Committers
| Name | Email | Commits |
|---|---|---|
| Leonard Berrada | l****a | 5 |
| dependabot[bot] | 4****] | 1 |
| Borja de Balle Pigem | b****e@g****m | 1 |
Issues and Pull Requests
Last synced: 6 months ago
All Time
- Total issues: 17
- Total pull requests: 12
- Average time to close issues: 3 months
- Average time to close pull requests: 6 months
- Total issue authors: 13
- Total pull request authors: 6
- Average comments per issue: 2.35
- Average comments per pull request: 0.25
- Merged pull requests: 8
- Bot issues: 0
- Bot pull requests: 2
Past Year
- Issues: 3
- Pull requests: 4
- Average time to close issues: 11 days
- Average time to close pull requests: about 9 hours
- Issue authors: 2
- Pull request authors: 3
- Average comments per issue: 0.67
- Average comments per pull request: 0.0
- Merged pull requests: 3
- Bot issues: 0
- Bot pull requests: 0
Top Authors
Issue Authors
- kamadforge (2)
- Solosneros (2)
- RamSaw (2)
- RoyRin (2)
- heilrahc (1)
- ahasanpour (1)
- shs037 (1)
- steverab (1)
- SabrinaMokhtari (1)
- terranceliu (1)
- Ryan0v0 (1)
- CHAOS-Yang (1)
- mmirmahdi (1)
Pull Request Authors
- lberrada (4)
- bogdan-kulynych (2)
- dependabot[bot] (2)
- BorjaBalle (2)
- RamSaw (1)
- ryan112358 (1)
Packages
- Total packages: 1
- Total downloads (PyPI): 23 last month
- Total dependent packages: 0
- Total dependent repositories: 0
- Total versions: 1
- Total maintainers: 3
pypi.org: jax-privacy
Algorithms for Privacy-Preserving Machine Learning in JAX.
- Homepage: https://github.com/google-deepmind/jax_privacy
- Documentation: https://jax-privacy.readthedocs.io/
- License: Apache 2.0
- Latest release: 1.0.0 (published 7 months ago)
Dependencies
- absl-py *
- dill *
- dm-haiku *
- jax *
- ml_collections *
- numpy *
- optax *
- scipy *
- tensorflow *
- tensorflow_datasets *
- actions/checkout v2 composite
- actions/setup-python v1 composite