DiRe - JAX

DiRe - JAX: A JAX based Dimensionality Reduction Algorithm for Large-scale Data - Published in JOSS (2025)

https://github.com/sashakolpakov/dire-jax

Science Score: 93.0%

This score indicates how likely this project is to be science-related based on various indicators:

  • CITATION.cff file
  • codemeta.json file
    Found codemeta.json file
  • .zenodo.json file
    Found .zenodo.json file
  • DOI references
    Found 4 DOI reference(s) in README and JOSS metadata
  • Academic publication links
    Links to: arxiv.org, joss.theoj.org
  • Academic email domains
  • Institutional organization owner
  • JOSS paper metadata
    Published in Journal of Open Source Software

Keywords

cpu data-science data-visualization dimensionality-reduction embeddings gpu jax machine-learning pca random-projection tpu tsne umap vector-embeddings

Scientific Fields

Mathematics Computer Science - 32% confidence
Last synced: 4 months ago

Repository

DImensionality REduction in JAX

Statistics
  • Stars: 21
  • Watchers: 3
  • Forks: 2
  • Open Issues: 2
  • Releases: 4
Created over 1 year ago · Last pushed 4 months ago

README.md

A high-performance DImensionality REduction package with JAX

DiRe offers fast dimensionality reduction that preserves the global structure of the dataset, with benchmarks showing performance competitive with UMAP and t-SNE. It is built with JAX for efficient computation on CPUs and GPUs.

Quick start

Basic installation (JAX backend only): `pip install dire-jax`

With utilities for benchmarking: `pip install dire-jax[utils]`

Complete installation with utilities: `pip install dire-jax[all]`

Note: For GPU or TPU acceleration, JAX needs to be specifically installed with hardware support. See the JAX documentation for more details on enabling GPU/TPU support.
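To confirm which hardware JAX will actually use before running DiRe, a quick check along the following lines may help. This is a minimal sketch that uses only the standard JAX calls `jax.default_backend()` and `jax.devices()`; the helper name `describe_jax_backend` is hypothetical and not part of DiRe.

```python
import importlib.util


def describe_jax_backend():
    """Summarize the backend JAX would use, or return None if JAX is not installed."""
    if importlib.util.find_spec("jax") is None:
        return None
    import jax

    return {
        "backend": jax.default_backend(),            # platform name, e.g. "cpu"
        "devices": [str(d) for d in jax.devices()],  # one entry per visible device
    }


info = describe_jax_backend()
print(info if info is not None else "JAX is not installed")
```

If the reported backend is the CPU on a machine that has an accelerator, JAX was most likely installed without the matching hardware support.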

Example usage:

```python
from dire_jax import DiRe
from sklearn.datasets import make_blobs

# Synthetic dataset: 100k points in 1k dimensions, drawn from 12 blobs
n_samples = 100000
n_features = 1000
n_centers = 12
features_blobs, labels_blobs = make_blobs(
    n_samples=n_samples,
    n_features=n_features,
    centers=n_centers,
    random_state=42,
)

# Configure the reducer for a 2-dimensional embedding
reducer_blobs = DiRe(
    n_components=2,
    n_neighbors=16,
    init='pca',
    max_iter_layout=32,
    min_dist=1e-4,
    spread=1.0,
    cutoff=4.0,
    n_sample_dirs=8,
    sample_size=16,
    neg_ratio=32,
    verbose=False,
)

# Compute the embedding, then plot it with the blob labels
_ = reducer_blobs.fit_transform(features_blobs)
reducer_blobs.visualize(labels=labels_blobs, point_size=4)
```

The output should look similar to

12 blobs with 100k points in 1k dimensions embedded in dimension 2

Documentation

Please refer to the DiRe API documentation for more instructions.

Project documentation structure:

- /docs/ - API documentation and architecture details
- /benchmarking/ - Performance benchmarks and scaling results
- /examples/ - Example usage and demos
- /tests/ - Test suite and benchmarking notebooks

Working paper

Our working paper is available on arXiv.

Also, check out the Jupyter notebook with benchmarking results.

Performance Characteristics

DiRe-JAX is optimized for small to medium-sized datasets (<50K points), with excellent CPU performance and GPU acceleration via JAX. It features fully vectorized computation with JIT compilation.
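For readers unfamiliar with the terms, the sketch below shows what vectorized, JIT-compiled computation generally looks like in JAX. It is a generic illustration built from `jax.vmap` and `jax.jit`, not DiRe's internal code; the function names are hypothetical.

```python
import jax
import jax.numpy as jnp


def squared_distances_to_all(point, points):
    """Squared Euclidean distance from one point to every row of `points`."""
    diff = points - point
    return jnp.sum(diff * diff, axis=1)


# vmap applies the single-point function to every row of the first argument;
# jit compiles the resulting batched function with XLA on first call.
pairwise_squared_distances = jax.jit(
    jax.vmap(squared_distances_to_all, in_axes=(0, None))
)

key = jax.random.PRNGKey(0)
points = jax.random.normal(key, (256, 16))
dists = pairwise_squared_distances(points, points)
print(dists.shape)
```

The same function runs unchanged on CPU, GPU, or TPU, because JAX dispatches the compiled computation to whichever backend is available.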

Benchmarking and utilities

For benchmarking utilities and quality metrics: `pip install dire-jax[utils]`

This provides access to dimensionality reduction quality metrics and benchmarking routines. Some utilities rely on external packages for persistent homology computations, which may increase runtime.
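As an illustration of what an embedding quality metric measures, the sketch below scores a 2-dimensional embedding with scikit-learn's `trustworthiness`. This is an assumption-laden stand-in: it uses scikit-learn rather than DiRe's own utilities (whose API is not shown here), and PCA stands in for any reducer, including DiRe.

```python
from sklearn.datasets import make_blobs
from sklearn.decomposition import PCA
from sklearn.manifold import trustworthiness

# Small synthetic dataset so the example runs quickly
features, _ = make_blobs(n_samples=500, n_features=20, centers=4, random_state=0)

# Stand-in embedding; the output of any reducer's fit_transform could be used instead
embedding = PCA(n_components=2, random_state=0).fit_transform(features)

# Trustworthiness lies in [0, 1]; higher means local neighborhoods are better preserved
score = trustworthiness(features, embedding, n_neighbors=10)
print(f"trustworthiness: {score:.3f}")
```

Trustworthiness captures local neighborhood preservation only, so it complements rather than replaces metrics aimed at global structure.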

Contributing

Please follow the contributing guide. Thanks!

Acknowledgement

This work is supported by the Google Cloud Research Award number GCP19980904.

Owner

  • Name: Sasha Kolpakov
  • Login: sashakolpakov
  • Kind: user

JOSS Publication

DiRe - JAX: A JAX based Dimensionality Reduction Algorithm for Large-scale Data
Published
June 09, 2025
Volume 10, Issue 110, Page 8264
Authors
Alexander Kolpakov ORCID
University of Austin, Austin TX, USA; akolpakov@uaustin.org
Igor Rivin ORCID
Temple University, Philadelphia PA, USA; rivin@temple.edu
Editor
Neea Rusch ORCID
Tags
JAX dimensionality reduction machine learning persistence homology data visualization

GitHub Events

Total
  • Create event: 18
  • Release event: 2
  • Issues event: 26
  • Watch event: 13
  • Delete event: 8
  • Issue comment event: 26
  • Public event: 1
  • Push event: 148
  • Pull request review event: 1
  • Pull request event: 20
  • Fork event: 1
Last Year
  • Create event: 18
  • Release event: 2
  • Issues event: 26
  • Watch event: 13
  • Delete event: 8
  • Issue comment event: 26
  • Public event: 1
  • Push event: 148
  • Pull request review event: 1
  • Pull request event: 20
  • Fork event: 1

Issues and Pull Requests

All Time
  • Total issues: 12
  • Total pull requests: 12
  • Average time to close issues: 4 days
  • Average time to close pull requests: 18 minutes
  • Total issue authors: 4
  • Total pull request authors: 3
  • Average comments per issue: 2.0
  • Average comments per pull request: 0.08
  • Merged pull requests: 7
  • Bot issues: 0
  • Bot pull requests: 0
Past Year
  • Issues: 12
  • Pull requests: 12
  • Average time to close issues: 4 days
  • Average time to close pull requests: 18 minutes
  • Issue authors: 4
  • Pull request authors: 3
  • Average comments per issue: 2.0
  • Average comments per pull request: 0.08
  • Merged pull requests: 7
  • Bot issues: 0
  • Bot pull requests: 0
Top Authors
Issue Authors
  • sashakolpakov (6)
  • crhea93 (4)
  • do-me (1)
  • x-tabdeveloping (1)
Pull Request Authors
  • sashakolpakov (9)
  • igorrivin (2)
  • x-tabdeveloping (1)
Top Labels
Issue Labels
enhancement (4) help wanted (1) good first issue (1) invalid (1) documentation (1)
Pull Request Labels

Packages

  • Total packages: 1
  • Total downloads:
    • pypi 154 last-month
  • Total dependent packages: 0
  • Total dependent repositories: 0
  • Total versions: 4
  • Total maintainers: 1
pypi.org: dire-jax

A JAX-based Dimension Reducer

  • Versions: 4
  • Dependent Packages: 0
  • Dependent Repositories: 0
  • Downloads: 154 Last month
Rankings
Dependent packages count: 9.5%
Average: 31.7%
Dependent repos count: 53.8%

Dependencies

dire_jax.egg-info/requires.txt pypi
  • faiss-cpu *
  • fastdtw *
  • jax *
  • kaleido *
  • loguru *
  • numpy *
  • pandas *
  • plotly *
  • pot *
  • pytwed *
  • ripser *
  • scikit-learn *
  • scipy *
  • tqdm *
setup.py pypi
  • fastdtw *
  • jax *
  • kaleido *
  • loguru *
  • numpy *
  • pandas *
  • plotly *
  • pot *
  • pytwed *
  • ripser *
  • scikit-learn *
  • scipy *
  • tqdm *
.github/workflows/pylint.yml actions
  • actions/checkout v4 composite
  • actions/setup-python v3 composite
.github/workflows/deploy_docs.yml actions
  • actions/checkout v3 composite
  • actions/setup-python v4 composite
  • peaceiris/actions-gh-pages v3 composite
.github/workflows/pypi.yml actions
  • actions/checkout v4 composite
  • actions/setup-python v4 composite
docs/requirements.txt pypi
  • fast-twed *
  • fastdtw *
  • jax *
  • loguru *
  • numpy *
  • pandas *
  • persim *
  • plotly *
  • pot *
  • ripser *
  • scikit-learn *
  • scipy *
  • sphinx >=4.0.0
  • sphinx_rtd_theme >=1.0.0
  • tqdm *