https://github.com/birkhoffg/jax-dataloader
Pytorch-like dataloaders for JAX.
Science Score: 26.0%
This score indicates how likely this project is to be science-related based on various indicators:
- ○ CITATION.cff file
- ✓ codemeta.json file: Found codemeta.json file
- ✓ .zenodo.json file: Found .zenodo.json file
- ○ DOI references
- ○ Academic publication links
- ○ Academic email domains
- ○ Institutional organization owner
- ○ JOSS paper metadata
- ○ Scientific vocabulary similarity: Low similarity (12.4%) to scientific vocabulary
Repository
Pytorch-like dataloaders for JAX.
Basic Info
- Host: GitHub
- Owner: BirkhoffG
- License: apache-2.0
- Language: Jupyter Notebook
- Default Branch: main
- Homepage: https://birkhoffg.github.io/jax-dataloader/
- Size: 1.02 MB
Statistics
- Stars: 94
- Watchers: 3
- Forks: 3
- Open Issues: 7
- Releases: 8
Metadata Files
README.md
Dataloader for JAX
Overview | Installation | Usage | Documentation
Overview
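jax_dataloader brings a pytorch-like dataloader API to jax. It supports
- 4 dataset types to download and pre-process data: `jdl.Dataset`, the pytorch `Dataset`, huggingface `datasets`, and tensorflow datasets.
- 3 backends to iteratively load batches: `"jax"`, `"pytorch"`, and `"tensorflow"`.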
A minimum jax-dataloader example:
``` python
import jax_dataloader as jdl

jdl.manual_seed(1234) # Set the global seed to 1234 for reproducibility

dataloader = jdl.DataLoader(
    dataset, # Can be a jdl.Dataset or pytorch or huggingface or tensorflow dataset
    backend='jax', # Use 'jax' backend for loading data
    batch_size=32, # Batch size
    shuffle=True, # Shuffle the dataloader every iteration or not
    drop_last=False, # Drop the last batch or not
    generator=jdl.Generator() # Control the randomness of this dataloader
)

batch = next(iter(dataloader)) # iterate next batch
```
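The role of the seed in the example above can be illustrated without the library. The following is a minimal sketch using NumPy's `default_rng`, not jax-dataloader's internals; `shuffled_indices` is a hypothetical helper introduced only for illustration. It shows why fixing a seed makes the shuffle order repeatable.

```python
import numpy as np

def shuffled_indices(n: int, seed: int) -> np.ndarray:
    """Return a permutation of range(n) determined entirely by `seed`."""
    rng = np.random.default_rng(seed)
    return rng.permutation(n)

first = shuffled_indices(10, seed=1234)
second = shuffled_indices(10, seed=1234)

# Same seed -> identical order; the result is still a full permutation.
assert np.array_equal(first, second)
assert sorted(first.tolist()) == list(range(10))
```

How `jdl.manual_seed` and `jdl.Generator` thread the seed through each backend is not shown here; the sketch only covers the general idea of seeded shuffling.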
Installation
The latest jax-dataloader release can directly be installed from PyPI:
``` sh
pip install jax-dataloader
```
or install directly from the repository:
``` sh
pip install git+https://github.com/BirkhoffG/jax-dataloader.git
```
> [!NOTE]
>
> We keep `jax-dataloader`'s dependencies minimal: installing it only pulls in `jax` and `plum-dispatch` (for backend dispatching). If you wish to use the `pytorch`, huggingface `datasets`, or `tensorflow` integrations, we highly recommend installing those dependencies manually.
>
> You can also run `pip install jax-dataloader[all]` to install everything (not recommended).
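Since the optional integrations are installed separately, it can help to check which ones are present before choosing a backend. The snippet below is a small sketch using only the standard library; `available_integrations` is a hypothetical helper, and the import names `torch`, `datasets`, and `tensorflow` are assumed to be the modules behind the three integrations.

```python
import importlib.util

# Import names assumed for the optional integrations.
OPTIONAL_MODULES = ("torch", "datasets", "tensorflow")

def available_integrations() -> dict:
    """Map each optional module name to True if it can be located."""
    found = {}
    for name in OPTIONAL_MODULES:
        try:
            found[name] = importlib.util.find_spec(name) is not None
        except (ImportError, ValueError):
            # find_spec can raise for broken or partially initialised modules.
            found[name] = False
    return found

print(available_integrations())
```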
Usage
jax_dataloader.core.DataLoader follows a similar API to the pytorch dataloader.
- The `dataset` should be an instance of a subclass of `jax_dataloader.core.Dataset`, `torch.utils.data.Dataset`, (the huggingface) `datasets.Dataset`, or `tf.data.Dataset`.
- The `backend` should be one of `"jax"`, `"pytorch"`, or `"tensorflow"`. This argument specifies which backend dataloader is used to load batches.
Note that not every dataset is compatible with every backend. See the compatibility table below:
| | jdl.Dataset | torch_data.Dataset | tf.data.Dataset | datasets.Dataset |
|:---------------|:--------------|:---------------------|:------------------|:-------------------|
| "jax" | ✅ | ❌ | ❌ | ✅ |
| "pytorch" | ✅ | ✅ | ❌ | ✅ |
| "tensorflow" | ✅ | ❌ | ✅ | ✅ |
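The table can also be written down as a lookup, which is handy for validating a backend choice up front. This is an illustrative sketch only: `COMPATIBILITY`, `is_compatible`, and `backends_for` are hypothetical names and not part of jax-dataloader; the entries are copied from the table above.

```python
# Backend -> dataset types it can load, copied from the compatibility table.
COMPATIBILITY = {
    "jax": {"jdl.Dataset", "datasets.Dataset"},
    "pytorch": {"jdl.Dataset", "torch_data.Dataset", "datasets.Dataset"},
    "tensorflow": {"jdl.Dataset", "tf.data.Dataset", "datasets.Dataset"},
}

def is_compatible(backend: str, dataset_type: str) -> bool:
    """Return True if the table marks this backend/dataset pair as supported."""
    if backend not in COMPATIBILITY:
        raise ValueError(f"unknown backend: {backend!r}")
    return dataset_type in COMPATIBILITY[backend]

def backends_for(dataset_type: str) -> list:
    """List the backends that can load the given dataset type."""
    return sorted(b for b, kinds in COMPATIBILITY.items() if dataset_type in kinds)

print(backends_for("torch_data.Dataset"))  # ['pytorch']
print(backends_for("jdl.Dataset"))  # ['jax', 'pytorch', 'tensorflow']
```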
Using ArrayDataset
The jax_dataloader.core.ArrayDataset is an easy way to wrap multiple
jax.numpy.array into one Dataset. For example, we can create an
ArrayDataset as follows:
``` python
import jax.numpy as jnp
import jax_dataloader as jdl

# Create features `X` and labels `y`
X = jnp.arange(100).reshape(10, 10)
y = jnp.arange(10)
# Create an `ArrayDataset`
arr_ds = jdl.ArrayDataset(X, y)
```
This arr_ds can be loaded by every backend.
``` python
# Create a `DataLoader` from the `ArrayDataset` via jax backend
dataloader = jdl.DataLoader(arr_ds, 'jax', batch_size=5, shuffle=True)
# Or we can use the pytorch backend
dataloader = jdl.DataLoader(arr_ds, 'pytorch', batch_size=5, shuffle=True)
# Or we can use the tensorflow backend
dataloader = jdl.DataLoader(arr_ds, 'tensorflow', batch_size=5, shuffle=True)
```
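To make the effect of `batch_size`, `shuffle`, and `drop_last` concrete, here is a conceptual sketch written with plain NumPy. It is not the library's implementation: `iter_batches` is a hypothetical function, and yielding each batch as a tuple of arrays is an assumption made for illustration. It reuses the shapes of the `X` and `y` arrays from the example above.

```python
import numpy as np

def iter_batches(arrays, batch_size, shuffle=False, drop_last=False, seed=0):
    """Yield tuples of array slices, one tuple per batch."""
    n = len(arrays[0])
    indices = np.arange(n)
    if shuffle:
        indices = np.random.default_rng(seed).permutation(n)
    # Stop early when drop_last is set so no short batch is produced.
    end = n - n % batch_size if drop_last else n
    for start in range(0, end, batch_size):
        idx = indices[start:start + batch_size]
        yield tuple(a[idx] for a in arrays)

X = np.arange(100).reshape(10, 10)
y = np.arange(10)

shapes = [(xb.shape, yb.shape) for xb, yb in iter_batches((X, y), batch_size=4)]
print(shapes)  # [((4, 10), (4,)), ((4, 10), (4,)), ((2, 10), (2,))]
```

With 10 rows and a batch size of 4, the last batch holds only 2 rows; passing `drop_last=True` to this sketch would skip it and yield 2 batches instead of 3.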
Using Huggingface Datasets
The huggingface datasets library is a
modern library for downloading, pre-processing, and sharing datasets.
jax_dataloader supports directly passing huggingface datasets.
``` python
from datasets import load_dataset
```
For example, we load the "squad" dataset from datasets:
``` python
hf_ds = load_dataset("squad")
```
Then, we can use jax_dataloader to load batches of hf_ds.
``` python
# Create a `DataLoader` from the `datasets.Dataset` via jax backend
dataloader = jdl.DataLoader(hf_ds['train'], 'jax', batch_size=5, shuffle=True)
# Or we can use the pytorch backend
dataloader = jdl.DataLoader(hf_ds['train'], 'pytorch', batch_size=5, shuffle=True)
# Or we can use the tensorflow backend
dataloader = jdl.DataLoader(hf_ds['train'], 'tensorflow', batch_size=5, shuffle=True)
```
Using Pytorch Datasets
The pytorch Dataset and its ecosystem (e.g., torchvision, torchtext,
torchaudio) support many built-in datasets. jax_dataloader supports
directly passing the pytorch Dataset.
> [!NOTE]
>
> Unfortunately, the pytorch Dataset only works with
> `backend='pytorch'`. See the example below.
``` python
from torchvision.datasets import MNIST
import numpy as np
```
We load the MNIST dataset from torchvision. The transform passed below
converts each image to a numpy.array.
``` python
pt_ds = MNIST('/tmp/mnist/', download=True, transform=lambda x: np.array(x, dtype=float), train=False)
```
This pt_ds can only be loaded via "pytorch" dataloaders.
``` python
dataloader = jdl.DataLoader(pt_ds, 'pytorch', batch_size=5, shuffle=True)
```
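Because each MNIST sample is an (image, label) pair, a loader has to stack individual samples into batch arrays. The sketch below shows one way this could look with NumPy; `numpy_collate` is a hypothetical function, the fake samples stand in for MNIST, and whether jax-dataloader collates in exactly this way is an assumption.

```python
import numpy as np

def numpy_collate(samples):
    """Stack a list of (image, label) pairs into two batch arrays."""
    images, labels = zip(*samples)
    return np.stack(images), np.asarray(labels)

# Five fake 28x28 "images" with integer labels stand in for MNIST samples.
samples = [(np.full((28, 28), i, dtype=float), i) for i in range(5)]
images, labels = numpy_collate(samples)
print(images.shape, labels.shape)  # (5, 28, 28) (5,)
```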
Using Tensorflow Datasets
jax_dataloader supports directly passing tensorflow datasets.
``` python
import tensorflow_datasets as tfds
import tensorflow as tf
```
For instance, we can load the MNIST dataset from tensorflow_datasets:
``` python
tf_ds = tfds.load('mnist', split='test', as_supervised=True)
```
and use jax_dataloader to iterate over the dataset:
``` python
dataloader = jdl.DataLoader(tf_ds, 'tensorflow', batch_size=5, shuffle=True)
```
Owner
- Name: Hangzhi Guo
- Login: BirkhoffG
- Kind: user
- Company: Penn State University
- Website: https://birkhoffg.github.io
- Twitter: BirkhoffGuo
- Repositories: 4
- Profile: https://github.com/BirkhoffG
Ph.D. Student at Penn State University
GitHub Events
Total
- Issues event: 16
- Watch event: 36
- Delete event: 4
- Issue comment event: 15
- Push event: 36
- Pull request review event: 2
- Pull request event: 9
- Create event: 5
Last Year
- Issues event: 16
- Watch event: 36
- Delete event: 4
- Issue comment event: 15
- Push event: 36
- Pull request review event: 2
- Pull request event: 9
- Create event: 5
Issues and Pull Requests
Last synced: 7 months ago
All Time
- Total issues: 21
- Total pull requests: 17
- Average time to close issues: about 1 month
- Average time to close pull requests: 6 days
- Total issue authors: 9
- Total pull request authors: 2
- Average comments per issue: 0.95
- Average comments per pull request: 1.0
- Merged pull requests: 16
- Bot issues: 0
- Bot pull requests: 0
Past Year
- Issues: 4
- Pull requests: 2
- Average time to close issues: 4 days
- Average time to close pull requests: 7 minutes
- Issue authors: 3
- Pull request authors: 1
- Average comments per issue: 1.0
- Average comments per pull request: 1.0
- Merged pull requests: 2
- Bot issues: 0
- Bot pull requests: 0
Top Authors
Issue Authors
- BirkhoffG (15)
- murphyk (2)
- DSilva27 (2)
- aspannaus (1)
- Eliacus (1)
- noah-lowry (1)
- allen-adastra (1)
- pluiez (1)
- Impure-King (1)
- Devan-Kerman (1)
Pull Request Authors
- BirkhoffG (26)
- Devan-Kerman (1)
Packages
- Total packages: 1
- Total downloads: 8,829 last month (pypi)
- Total dependent packages: 2
- Total dependent repositories: 1
- Total versions: 9
- Total maintainers: 1
pypi.org: jax-dataloader
Dataloader for jax
- Homepage: https://github.com/birkhoffg/jax-dataloader
- Documentation: https://jax-dataloader.readthedocs.io/
- License: Apache Software License 2.0
- Latest release: 0.1.3 (published over 1 year ago)
Maintainers (1)
Dependencies
- actions/checkout v3 composite
- actions/setup-python v4 composite
- peaceiris/actions-gh-pages v3 composite
- actions/checkout v3 composite
- actions/setup-python v4 composite