jax-privacy

Algorithms for Privacy-Preserving Machine Learning in JAX

https://github.com/google-deepmind/jax_privacy

Science Score: 54.0%

This score indicates how likely this project is to be science-related based on various indicators:

  • CITATION.cff file
    Found CITATION.cff file
  • codemeta.json file
    Found codemeta.json file
  • .zenodo.json file
    Found .zenodo.json file
  • DOI references
  • Academic publication links
    Links to: arxiv.org
  • Committers with academic emails
  • Institutional organization owner
  • JOSS paper metadata
  • Scientific vocabulary similarity
    Low similarity (15.6%) to scientific vocabulary

Keywords from Contributors

interactive ecosystem-modeling mesh interpretability profiles distribution sequences generic projection standardization
Last synced: 6 months ago

Repository

Algorithms for Privacy-Preserving Machine Learning in JAX

Basic Info
  • Host: GitHub
  • Owner: google-deepmind
  • License: apache-2.0
  • Language: Python
  • Default Branch: main
  • Homepage:
  • Size: 754 KB
Statistics
  • Stars: 96
  • Watchers: 9
  • Forks: 14
  • Open Issues: 2
  • Releases: 1
Created almost 4 years ago · Last pushed 6 months ago
Metadata Files
Readme Contributing License Citation

README.md

JAX-Privacy: Algorithms for Privacy-Preserving Machine Learning in JAX


This repository contains:

  • A production-focused API for differentially-private (DP) training of ML models in JAX and Keras.
  • A library of core components for implementing differentially private machine learning algorithms in JAX.
  • A JAX-based machine learning DP pipeline using components from the library to experiment with image classification models.

This code is open-sourced with the main objective of transparency and reproducibility for research purposes, and includes production-focused APIs for differentially private machine learning. Some rough edges should be expected, especially in the research components.

New: Production-Focused JAX Privacy Library

We are excited to introduce a more production-focused JAX Privacy API designed to simplify the development of differentially-private (DP) machine learning models, including Large Language Models (LLMs).

Key Features:

  • Algorithm Support: Currently supports the DP-SGD (Differentially Private Stochastic Gradient Descent) algorithm. We are actively working on incorporating more DP algorithms (e.g. DP-FTRL) in the near future.
  • Framework Integration: The library provides APIs tailored for different JAX-based development experiences:
    • Keras: A high-level API, excellent for common tasks like fine-tuning LLMs. See the Keras API simple example and the Gemma fine-tuning notebook to get started.
    • Flax Linen: Offers greater flexibility for custom model architectures and training loops, at the cost of some additional boilerplate. See the MNIST notebook to get started.
    • Raw JAX: Provides the lowest-level control. Recommended for researchers who want to test out new ideas, for users of a framework not listed above (e.g. equinox), or for anyone who wants a more NumPy-like experience.
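To make the DP-SGD algorithm mentioned above concrete, here is a minimal sketch of a single DP-SGD step written in raw JAX. It follows the standard recipe (per-example gradients, per-example L2 clipping, summing, adding Gaussian noise with standard deviation `noise_multiplier * clip_norm`, then averaging) and is not JAX-Privacy's own API. The toy linear model and the names `loss_fn`, `clip_example_grad`, `dp_sgd_step`, `clip_norm`, and `noise_multiplier` are hypothetical, chosen for illustration. The sketch performs no privacy accounting, so on its own it says nothing about the resulting (ε, δ) guarantee.

```python
import jax
import jax.numpy as jnp


def loss_fn(params, x, y):
    # Hypothetical toy model (not part of JAX-Privacy): squared error of a
    # linear model on a single example.
    pred = jnp.dot(x, params["w"]) + params["b"]
    return 0.5 * (pred - y) ** 2


def clip_example_grad(grads, clip_norm):
    # Scale one example's gradient pytree so its global L2 norm is <= clip_norm.
    leaves = jax.tree_util.tree_leaves(grads)
    norm = jnp.sqrt(sum(jnp.sum(g ** 2) for g in leaves))
    scale = jnp.minimum(1.0, clip_norm / (norm + 1e-12))
    return jax.tree_util.tree_map(lambda g: g * scale, grads)


def dp_sgd_step(params, xs, ys, key, lr=0.1, clip_norm=1.0, noise_multiplier=1.0):
    # Per-example gradients: vmap over the batch axis of xs and ys only.
    per_example_grads = jax.vmap(jax.grad(loss_fn), in_axes=(None, 0, 0))(params, xs, ys)
    # Clip each example's gradient independently, bounding its contribution.
    clipped = jax.vmap(lambda g: clip_example_grad(g, clip_norm))(per_example_grads)
    # Sum the clipped gradients over the batch.
    summed = jax.tree_util.tree_map(lambda g: jnp.sum(g, axis=0), clipped)
    # Add Gaussian noise with std noise_multiplier * clip_norm to every leaf.
    leaves, treedef = jax.tree_util.tree_flatten(summed)
    keys = jax.random.split(key, len(leaves))
    noisy_leaves = [
        g + noise_multiplier * clip_norm * jax.random.normal(k, g.shape)
        for g, k in zip(leaves, keys)
    ]
    noisy = jax.tree_util.tree_unflatten(treedef, noisy_leaves)
    # Average over the batch and take a plain SGD step.
    batch_size = xs.shape[0]
    return jax.tree_util.tree_map(lambda p, g: p - lr * g / batch_size, params, noisy)


# Example usage on random data (hypothetical).
key = jax.random.PRNGKey(0)
data_key, noise_key = jax.random.split(key)
xs = jax.random.normal(data_key, (8, 3))
ys = xs @ jnp.array([1.0, -2.0, 0.5])
params = {"w": jnp.zeros(3), "b": jnp.array(0.0)}
params = dp_sgd_step(params, xs, ys, noise_key)
```

Clipping bounds how much any single example can move the parameters, which is what lets the added Gaussian noise mask individual contributions. In a real training loop you would draw a fresh PRNG key for every step and would typically wrap the step in `jax.jit`.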

This new JAX Privacy API aims to provide a more streamlined and robust experience for building DP ML models, complementing the existing research-focused components.

We believe this new API will significantly lower the barrier to implementing DP in your machine learning projects.

Installation

Note: to ensure that your installation is compatible with your local accelerators such as a GPU, we recommend first following the corresponding instructions to install JAX.
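After installing JAX, one quick sanity check is to ask which backend and devices it can see; `jax.default_backend()` and `jax.devices()` are standard JAX functions. This is only a suggested check, not a step from the official instructions: if a GPU machine reports only CPU devices, the accelerator-specific JAX installation probably did not take effect.

```python
import jax

# Name of the default platform JAX will run on, e.g. "cpu", "gpu" or "tpu".
backend = jax.default_backend()
# All devices visible to JAX on this host.
devices = jax.devices()

print(backend)
print(devices)
```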

Option 1: Static Installation

This option is preferred if you want to reuse the functionality of our library without modifying it. The library package can be installed by running the following command:

pip install git+https://github.com/google-deepmind/jax_privacy

This will not install the training pipeline.

Option 2: Local Installation

This option is preferred if you want to build on top of our codebase or to reproduce our results using the training pipeline.

  • The first step is to clone the repository:

git clone https://github.com/google-deepmind/jax_privacy

  • Then install the code. We recommend an editable installation so that modifications to the code are reflected when the package is imported:

cd jax_privacy
pip install -e .

Reproducing Results

Unlocking High-Accuracy Differentially Private Image Classification through Scale

Unlocking Accuracy and Fairness in Differentially Private Image Classification

How to Cite This Repository

If you use code from this repository, please cite the following reference:

@software{jax-privacy2022github,
  author = {Balle, Borja and Berrada, Leonard and Charles, Zachary and Choquette-Choo, Christopher A and De, Soham and Doroshenko, Vadym and Dvijotham, Dj and Galen, Andrew and Ganesh, Arun and Ghalebikesabi, Sahra and Hayes, Jamie and Kairouz, Peter and McKenna, Ryan and McMahan, Brendan and Pappu, Aneesh and Ponomareva, Natalia and Pravilov, Mikhail and Rush, Keith and Smith, Samuel L and Stanforth, Robert},
  title = {{JAX}-{P}rivacy: Algorithms for Privacy-Preserving Machine Learning in JAX},
  url = {http://github.com/google-deepmind/jax_privacy},
  version = {0.4.0},
  year = {2025},
}

Contact

If you have any questions or feedback, you can contact us via email: jax-privacy-open-source@google.com.

Acknowledgements

License

All code is made available under the Apache 2.0 License. Model parameters are made available under the Creative Commons Attribution 4.0 International (CC BY 4.0) License.

See https://creativecommons.org/licenses/by/4.0/legalcode for more details.

Disclaimer

This is not an official Google product.

Owner

  • Name: Google DeepMind
  • Login: google-deepmind
  • Kind: organization

Citation (CITATION.cff)

cff-version: 1.2.0
message: "If you use this software, please cite it as below."
authors:
- family-names: "Balle"
  given-names: "Borja"
- family-names: "Berrada"
  given-names: "Leonard"
- family-names: "Charles"
  given-names: "Zachary"
- family-names: "Choquette-Choo"
  given-names: "Christopher A"
- family-names: "De"
  given-names: "Soham"
- family-names: "Doroshenko"
  given-names: "Vadym"
- family-names: "Dvijotham"
  given-names: "Dj"
- family-names: "Galen"
  given-names: "Andrew"
- family-names: "Ganesh"
  given-names: "Arun"
- family-names: "Ghalebikesabi"
  given-names: "Sahra"
- family-names: "Hayes"
  given-names: "Jamie"
- family-names: "Kairouz"
  given-names: "Peter"
- family-names: "McKenna"
  given-names: "Ryan"
- family-names: "McMahan"
  given-names: "Brendan"
- family-names: "Pappu"
  given-names: "Aneesh"
- family-names: "Ponomareva"
  given-names: "Natalia"
- family-names: "Pravilov"
  given-names: "Mikhail"
- family-names: "Rush"
  given-names: "Keith"
- family-names: "Smith"
  given-names: "Samuel L"
- family-names: "Stanforth"
  given-names: "Robert"
title: "JAX-Privacy"
version: 0.4.0
date-released: 2025-04-21
url: "https://github.com/google-deepmind/jax_privacy"

GitHub Events

Total
  • Issues event: 4
  • Watch event: 9
  • Delete event: 3
  • Member event: 3
  • Issue comment event: 2
  • Push event: 3
  • Pull request review event: 3
  • Pull request event: 5
  • Fork event: 1
  • Create event: 4
Last Year
  • Issues event: 4
  • Watch event: 9
  • Delete event: 3
  • Member event: 3
  • Issue comment event: 2
  • Push event: 3
  • Pull request review event: 3
  • Pull request event: 5
  • Fork event: 1
  • Create event: 4

Committers

Last synced: 8 months ago

All Time
  • Total Commits: 7
  • Total Committers: 3
  • Avg Commits per committer: 2.333
  • Development Distribution Score (DDS): 0.286
Past Year
  • Commits: 2
  • Committers: 2
  • Avg Commits per committer: 1.0
  • Development Distribution Score (DDS): 0.5
Top Committers
  • Leonard Berrada (l****a): 5 commits
  • dependabot[bot] (4****]): 1 commit
  • Borja de Balle Pigem (b****e@g****m): 1 commit

Issues and Pull Requests

Last synced: 6 months ago

All Time
  • Total issues: 17
  • Total pull requests: 12
  • Average time to close issues: 3 months
  • Average time to close pull requests: 6 months
  • Total issue authors: 13
  • Total pull request authors: 6
  • Average comments per issue: 2.35
  • Average comments per pull request: 0.25
  • Merged pull requests: 8
  • Bot issues: 0
  • Bot pull requests: 2
Past Year
  • Issues: 3
  • Pull requests: 4
  • Average time to close issues: 11 days
  • Average time to close pull requests: about 9 hours
  • Issue authors: 2
  • Pull request authors: 3
  • Average comments per issue: 0.67
  • Average comments per pull request: 0.0
  • Merged pull requests: 3
  • Bot issues: 0
  • Bot pull requests: 0
Top Authors
Issue Authors
  • kamadforge (2)
  • Solosneros (2)
  • RamSaw (2)
  • RoyRin (2)
  • heilrahc (1)
  • ahasanpour (1)
  • shs037 (1)
  • steverab (1)
  • SabrinaMokhtari (1)
  • terranceliu (1)
  • Ryan0v0 (1)
  • CHAOS-Yang (1)
  • mmirmahdi (1)
Pull Request Authors
  • lberrada (4)
  • bogdan-kulynych (2)
  • dependabot[bot] (2)
  • BorjaBalle (2)
  • RamSaw (1)
  • ryan112358 (1)
Top Labels
Issue Labels
enhancement (1) good first issue (1)
Pull Request Labels
dependencies (2)

Packages

  • Total packages: 1
  • Total downloads:
    • pypi 23 last-month
  • Total dependent packages: 0
  • Total dependent repositories: 0
  • Total versions: 1
  • Total maintainers: 3
pypi.org: jax-privacy

Algorithms for Privacy-Preserving Machine Learning in JAX.

  • Versions: 1
  • Dependent Packages: 0
  • Dependent Repositories: 0
  • Downloads: 23 Last month
Rankings
Dependent packages count: 8.8%
Average: 29.3%
Dependent repos count: 49.8%
Last synced: 6 months ago

Dependencies

requirements.txt pypi
  • absl-py *
  • dill *
  • dm-haiku *
  • jax *
  • ml_collections *
  • numpy *
  • optax *
  • scipy *
  • tensorflow *
  • tensorflow_datasets *
.github/workflows/ci.yml actions
  • actions/checkout v2 composite
  • actions/setup-python v1 composite
setup.py pypi