https://github.com/birkhoffg/jax-dataloader

Pytorch-like dataloaders for JAX.


Science Score: 26.0%

This score indicates how likely this project is to be science-related based on various indicators:

  • CITATION.cff file
  • codemeta.json file
    Found codemeta.json file
  • .zenodo.json file
    Found .zenodo.json file
  • DOI references
  • Academic publication links
  • Academic email domains
  • Institutional organization owner
  • JOSS paper metadata
  • Scientific vocabulary similarity
    Low similarity (12.4%) to scientific vocabulary

Keywords

dataloader dataset datasets deep-learning huggingface-datasets jax jax-dataloader pytorch tensorflow
Last synced: 5 months ago

Repository


Basic Info
Statistics
  • Stars: 94
  • Watchers: 3
  • Forks: 3
  • Open Issues: 7
  • Releases: 8
Created about 3 years ago · Last pushed 9 months ago
Metadata Files
Readme License

README.md

Dataloader for JAX



Overview

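jax_dataloader brings a pytorch-like dataloader API to jax. It supports four kinds of datasets (jdl.Dataset, pytorch torch.utils.data.Dataset, huggingface datasets.Dataset, and tensorflow tf.data.Dataset) and three backends ("jax", "pytorch", and "tensorflow") for loading batches.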

A minimal jax-dataloader example:

```python
import jax.numpy as jnp
import jax_dataloader as jdl

jdl.manual_seed(1234)  # Set the global seed to 1234 for reproducibility

# Any supported dataset works here; an ArrayDataset keeps the example self-contained
dataset = jdl.ArrayDataset(jnp.arange(100).reshape(10, 10), jnp.arange(10))

dataloader = jdl.DataLoader(
    dataset,                   # Can be a jdl.Dataset or pytorch or huggingface or tensorflow dataset
    backend='jax',             # Use 'jax' backend for loading data
    batch_size=32,             # Batch size
    shuffle=True,              # Shuffle the dataloader every iteration or not
    drop_last=False,           # Drop the last batch or not
    generator=jdl.Generator()  # Control the randomness of this dataloader
)

batch = next(iter(dataloader))  # iterate next batch
```
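To make the `batch_size`, `shuffle`, and `drop_last` arguments concrete, here is a minimal NumPy sketch of how a loader could turn a dataset of `n` samples into batches of indices. It is an illustration under our own assumptions, not jax_dataloader's implementation: `batch_indices` is a hypothetical helper, and a seeded `np.random.Generator` stands in for `jdl.manual_seed` / `jdl.Generator`.

```python
import numpy as np

def batch_indices(n, batch_size, rng, shuffle=True, drop_last=False):
    """Yield one array of sample indices per batch (hypothetical helper)."""
    # A new permutation is drawn on every call, i.e. once per epoch
    order = rng.permutation(n) if shuffle else np.arange(n)
    for start in range(0, n, batch_size):
        batch = order[start:start + batch_size]
        if drop_last and len(batch) < batch_size:
            break  # skip the final, incomplete batch
        yield batch

rng = np.random.default_rng(1234)  # seeded, similar in spirit to jdl.manual_seed(1234)

print([len(b) for b in batch_indices(10, 4, rng)])                  # [4, 4, 2]
print([len(b) for b in batch_indices(10, 4, rng, drop_last=True)])  # [4, 4]
```

Because the same generator is reused across epochs, each epoch gets a fresh permutation, while re-creating the generator with the same seed reproduces the whole sequence.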

Installation

The latest jax-dataloader release can directly be installed from PyPI:

```sh
pip install jax-dataloader
```

or install directly from the repository:

```sh
pip install git+https://github.com/BirkhoffG/jax-dataloader.git
```

[!NOTE]

We keep jax-dataloader’s dependencies minimal: installing it only pulls in jax and plum-dispatch (for backend dispatching). If you wish to use the pytorch, huggingface datasets, or tensorflow integrations, we highly recommend installing those dependencies manually.

You can also run pip install jax-dataloader[all] to install everything (not recommended).

Usage

jax_dataloader.core.DataLoader follows an API similar to the pytorch DataLoader.

  • The dataset should be an instance of a subclass of jax_dataloader.core.Dataset, torch.utils.data.Dataset, (the huggingface) datasets.Dataset, or tf.data.Dataset.
  • The backend should be one of "jax", "pytorch", or "tensorflow". This argument specifies which backend dataloader is used to load batches.

Note that not every dataset is compatible with every backend. See the compatibility table below:

| backend | jdl.Dataset | torch.utils.data.Dataset | tf.data.Dataset | datasets.Dataset |
|:---------------|:--------------|:---------------------|:------------------|:-------------------|
| "jax" | ✅ | ❌ | ❌ | ✅ |
| "pytorch" | ✅ | ✅ | ❌ | ✅ |
| "tensorflow" | ✅ | ❌ | ✅ | ✅ |
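The compatibility table can also be read as a lookup from each backend to the dataset types it accepts. The sketch below encodes it that way so an unsupported pair fails early with a clear message. `COMPATIBILITY` and `check_compatible` are hypothetical names for illustration, not part of jax_dataloader, and dataset types are plain strings rather than the real classes.

```python
# Backend -> dataset types it can load, transcribed from the compatibility table
COMPATIBILITY = {
    "jax": {"jdl.Dataset", "datasets.Dataset"},
    "pytorch": {"jdl.Dataset", "torch.utils.data.Dataset", "datasets.Dataset"},
    "tensorflow": {"jdl.Dataset", "tf.data.Dataset", "datasets.Dataset"},
}

def check_compatible(dataset_type: str, backend: str) -> None:
    """Raise ValueError if the (dataset_type, backend) pair is not supported."""
    if backend not in COMPATIBILITY:
        raise ValueError(f"unknown backend {backend!r}; expected one of {sorted(COMPATIBILITY)}")
    if dataset_type not in COMPATIBILITY[backend]:
        raise ValueError(f"{dataset_type} cannot be loaded with backend={backend!r}")

check_compatible("datasets.Dataset", "jax")  # passes: huggingface datasets work with every backend

# Which backends can load a pytorch Dataset?
print([b for b, kinds in COMPATIBILITY.items() if "torch.utils.data.Dataset" in kinds])  # ['pytorch']
```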

Using ArrayDataset

The jax_dataloader.core.ArrayDataset is an easy way to wrap multiple jax.numpy arrays into one Dataset. For example, we can create an ArrayDataset as follows:

```python
import jax.numpy as jnp
import jax_dataloader as jdl

# Create features X and labels y
X = jnp.arange(100).reshape(10, 10)
y = jnp.arange(10)

# Create an ArrayDataset
arr_ds = jdl.ArrayDataset(X, y)
```

This arr_ds can be loaded by every backend.

```python
# Create a DataLoader from the ArrayDataset via jax backend
dataloader = jdl.DataLoader(arr_ds, 'jax', batch_size=5, shuffle=True)
# Or we can use the pytorch backend
dataloader = jdl.DataLoader(arr_ds, 'pytorch', batch_size=5, shuffle=True)
# Or we can use the tensorflow backend
dataloader = jdl.DataLoader(arr_ds, 'tensorflow', batch_size=5, shuffle=True)
```
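Every backend in the compatibility table accepts a jdl.Dataset, and an array-backed dataset only has to hand back slices of the arrays it wraps. The NumPy sketch below shows one way such a container can work, assuming a loader needs nothing more than `len()` and indexing by an integer or by an array of indices. `SimpleArrayDataset` is a hypothetical stand-in, not the source of the library's `ArrayDataset`.

```python
import numpy as np

class SimpleArrayDataset:
    """Hypothetical array-backed dataset, for illustration only."""

    def __init__(self, *arrays):
        # Every array must share the same leading (sample) dimension
        n = len(arrays[0])
        if any(len(a) != n for a in arrays):
            raise ValueError("all arrays must have the same first dimension")
        self.arrays = arrays

    def __len__(self):
        return len(self.arrays[0])

    def __getitem__(self, idx):
        # idx may be a single int or an array of indices (a whole batch at once)
        return tuple(a[idx] for a in self.arrays)

X = np.arange(100).reshape(10, 10)
y = np.arange(10)
ds = SimpleArrayDataset(X, y)

xb, yb = ds[np.array([0, 3, 7])]  # one batch of three samples
print(xb.shape, yb)  # (3, 10) [0 3 7]
```

Indexing with a whole index array lets a loader fetch a batch in a single call instead of stacking samples one at a time.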

Using Huggingface Datasets

The huggingface datasets library is a modern library for downloading, pre-processing, and sharing datasets. jax_dataloader supports passing huggingface datasets directly.

```python
from datasets import load_dataset
```

For example, we load the "squad" dataset from datasets:

```python
hf_ds = load_dataset("squad")
```

Then, we can use jax_dataloader to load batches of hf_ds.

```python
# Create a DataLoader from the datasets.Dataset via jax backend
dataloader = jdl.DataLoader(hf_ds['train'], 'jax', batch_size=5, shuffle=True)
# Or we can use the pytorch backend
dataloader = jdl.DataLoader(hf_ds['train'], 'pytorch', batch_size=5, shuffle=True)
# Or we can use the tensorflow backend
dataloader = jdl.DataLoader(hf_ds['train'], 'tensorflow', batch_size=5, shuffle=True)
```

Using Pytorch Datasets

The pytorch Dataset and its ecosystem (e.g., torchvision, torchtext, torchaudio) provide many built-in datasets. jax_dataloader supports passing a pytorch Dataset directly.

[!NOTE]

Unfortunately, a pytorch Dataset only works with backend='pytorch'. See the example below.

```python
from torchvision.datasets import MNIST
import numpy as np
```

We load the MNIST dataset from torchvision. The transform converts each image to a numpy.array.

```python
pt_ds = MNIST('/tmp/mnist/', download=True, transform=lambda x: np.array(x, dtype=float), train=False)
```

This pt_ds can only be loaded via "pytorch" dataloaders.

```python
dataloader = jdl.DataLoader(pt_ds, 'pytorch', batch_size=5, shuffle=True)
```
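The lambda transform in the MNIST example can also be written as a named, module-level function. Unlike a lambda, such a function can be pickled, which PyTorch needs if data-loading worker processes are started with the spawn method; whether that applies depends on how the pytorch backend is configured, so treat this as an optional refinement. `to_numpy` is a hypothetical name, and switching to `float32` (JAX's default precision) is our own choice rather than something the original example does.

```python
import numpy as np

def to_numpy(img):
    """Convert a PIL image (or any array-like) into a float32 NumPy array."""
    # float32 matches JAX's default dtype unless 64-bit mode is enabled
    return np.asarray(img, dtype=np.float32)

# With torchvision installed, this would replace the lambda in the MNIST example:
# pt_ds = MNIST('/tmp/mnist/', download=True, transform=to_numpy, train=False)

sample = to_numpy([[0, 255], [128, 64]])
print(sample.dtype, sample.shape)  # float32 (2, 2)
```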

Using Tensorflow Datasets

jax_dataloader supports passing tensorflow datasets directly.

```python
import tensorflow_datasets as tfds
import tensorflow as tf
```

For instance, we can load the MNIST dataset from tensorflow_datasets

```python
tf_ds = tfds.load('mnist', split='test', as_supervised=True)
```

and use jax_dataloader to iterate over the dataset.

```python
dataloader = jdl.DataLoader(tf_ds, 'tensorflow', batch_size=5, shuffle=True)
```

Owner

  • Name: Hangzhi Guo
  • Login: BirkhoffG
  • Kind: user
  • Company: Penn State University

Ph.D. Student at Penn State University

GitHub Events

Total
  • Issues event: 16
  • Watch event: 36
  • Delete event: 4
  • Issue comment event: 15
  • Push event: 36
  • Pull request review event: 2
  • Pull request event: 9
  • Create event: 5
Last Year
  • Issues event: 16
  • Watch event: 36
  • Delete event: 4
  • Issue comment event: 15
  • Push event: 36
  • Pull request review event: 2
  • Pull request event: 9
  • Create event: 5

Issues and Pull Requests

Last synced: 7 months ago

All Time
  • Total issues: 21
  • Total pull requests: 17
  • Average time to close issues: about 1 month
  • Average time to close pull requests: 6 days
  • Total issue authors: 9
  • Total pull request authors: 2
  • Average comments per issue: 0.95
  • Average comments per pull request: 1.0
  • Merged pull requests: 16
  • Bot issues: 0
  • Bot pull requests: 0
Past Year
  • Issues: 4
  • Pull requests: 2
  • Average time to close issues: 4 days
  • Average time to close pull requests: 7 minutes
  • Issue authors: 3
  • Pull request authors: 1
  • Average comments per issue: 1.0
  • Average comments per pull request: 1.0
  • Merged pull requests: 2
  • Bot issues: 0
  • Bot pull requests: 0
Top Authors
Issue Authors
  • BirkhoffG (15)
  • murphyk (2)
  • DSilva27 (2)
  • aspannaus (1)
  • Eliacus (1)
  • noah-lowry (1)
  • allen-adastra (1)
  • pluiez (1)
  • Impure-King (1)
  • Devan-Kerman (1)
Pull Request Authors
  • BirkhoffG (26)
  • Devan-Kerman (1)
Top Labels
Issue Labels
enhancement (7) bug (2) v0.1 (2) documentation (1)
Pull Request Labels
enhancement (17) documentation (4) bug (3)

Packages

  • Total packages: 1
  • Total downloads:
    • pypi 8,829 last-month
  • Total dependent packages: 2
  • Total dependent repositories: 1
  • Total versions: 9
  • Total maintainers: 1
pypi.org: jax-dataloader

Dataloader for jax

  • Versions: 9
  • Dependent Packages: 2
  • Dependent Repositories: 1
  • Downloads: 8,829 Last month
Rankings
Dependent packages count: 10.1%
Stargazers count: 13.4%
Average: 16.9%
Forks count: 19.1%
Downloads: 20.2%
Dependent repos count: 21.5%
Maintainers (1)
Last synced: 6 months ago

Dependencies

.github/workflows/deploy.yaml actions
  • actions/checkout v3 composite
  • actions/setup-python v4 composite
  • peaceiris/actions-gh-pages v3 composite
.github/workflows/nbdev.yaml actions
  • actions/checkout v3 composite
  • actions/setup-python v4 composite
setup.py pypi