https://github.com/google-research/jestimator
Amos optimizer with JEstimator lib.
Science Score: 36.0%
This score indicates how likely this project is to be science-related based on various indicators:
- ○ CITATION.cff file
- ✓ codemeta.json file: Found codemeta.json file
- ✓ .zenodo.json file: Found .zenodo.json file
- ○ DOI references
- ✓ Academic publication links: Links to arxiv.org
- ○ Committers with academic emails
- ○ Institutional organization owner
- ○ JOSS paper metadata
- ○ Scientific vocabulary similarity: Low similarity (12.3%) to scientific vocabulary
Keywords
Keywords from Contributors
Repository
Amos optimizer with JEstimator lib.
Basic Info
Statistics
- Stars: 82
- Watchers: 4
- Forks: 6
- Open Issues: 1
- Releases: 1
Topics
Metadata Files
README.md
Amos and JEstimator
This is not an officially supported Google product.
This is the source code for the paper "Amos: An Adam-style Optimizer with Adaptive Weight Decay towards Model-Oriented Scale".
It implements Amos, an optimizer compatible with the
optax library, and JEstimator, a
light-weight library with a tf.Estimator-like interface to manage
T5X-compatible checkpoints for machine
learning programs in JAX, which we use to run
experiments in the paper.
Quickstart
pip install jestimator
This installs the Amos optimizer implemented in the jestimator lib.
Usage of Amos
This implementation of Amos is used with JAX, a high-performance numerical computing library with automatic differentiation, for machine learning research. The API of Amos is compatible with optax, a library of JAX optimizers (hopefully Amos will be integrated into optax in the near future).
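As a quick orientation before the full example, here is a minimal sketch of the optax-style API. The constructor arguments mirror the MNIST example below, but the specific values, the placeholder eta_fn, and the params/grads pytrees are stand-ins for illustration only:

```
import math
import optax
from jestimator import amos

optimizer = amos.amos(
    learning_rate=0.01,                  # global learning rate (illustrative value)
    eta_fn=lambda name, shape: 1.0,      # placeholder; see the discussion of eta_fn below
    clip_value=1.0,                      # gradient clipping (illustrative value)
)
opt_state = optimizer.init(params)       # params: a pytree of model parameters
updates, opt_state = optimizer.update(grads, opt_state, params)
params = optax.apply_updates(params, updates)
```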
In order to demonstrate the usage, we will apply Amos to MNIST. It is based on Flax's official MNIST example, and you can find the code in a Jupyter notebook here.
1. Imports
```
import jax
import jax.numpy as jnp  # JAX NumPy
from jestimator import amos  # The Amos optimizer implementation
from jestimator import amos_helper  # Helper module for Amos

from flax import linen as nn  # The Linen API
from flax.training import train_state  # Useful dataclass to keep train state

import math
import tensorflow_datasets as tfds  # TFDS for MNIST
from sklearn.metrics import accuracy_score
```
2. Load data
```
def get_datasets():
  """Load MNIST train and test datasets into memory."""
  ds_builder = tfds.builder('mnist')
  ds_builder.download_and_prepare()
  train_ds = tfds.as_numpy(ds_builder.as_dataset(split='train', batch_size=-1))
  test_ds = tfds.as_numpy(ds_builder.as_dataset(split='test', batch_size=-1))
  train_ds['image'] = jnp.float32(train_ds['image']) / 255.
  test_ds['image'] = jnp.float32(test_ds['image']) / 255.
  return train_ds, test_ds
```
3. Build model
```
class CNN(nn.Module):
  """A simple CNN model."""

  @nn.compact
  def __call__(self, x):
    x = nn.Conv(features=32, kernel_size=(3, 3))(x)
    x = nn.relu(x)
    x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = nn.Conv(features=64, kernel_size=(3, 3))(x)
    x = nn.relu(x)
    x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = x.reshape((x.shape[0], -1))  # flatten
    x = nn.Dense(features=256)(x)
    x = nn.relu(x)
    x = nn.Dense(features=10)(x)
    return x

  def classify_xe_loss(self, x, labels):
    # Labels read from the tfds MNIST are integers from 0 to 9.
    # Logits are arrays of size 10.
    logits = self(x)
    logits = jax.nn.log_softmax(logits)
    labels = jnp.expand_dims(labels, -1)
    llh = jnp.take_along_axis(logits, labels, axis=-1)
    loss = -jnp.sum(llh)
    return loss
```
4. Create train state
A TrainState object keeps the model parameters and optimizer states, and can
be checkpointed into files.
We create the model and optimizer in this function.
For the optimizer, we use Amos here. The following hyper-parameters are set:
- learning_rate: The global learning rate.
- eta_fn: The model-specific 'eta'.
- shape_fn: Memory reduction setting.
- beta: Rate for running average of gradient squares.
- clip_value: Gradient clipping for stable training.
The global learning rate is usually set to 1/sqrt(N), where N is the number of batches in the training data. For MNIST, we have 60k training examples and the batch size is 32, so learning_rate=1/sqrt(60000/32).
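For example, a quick back-of-the-envelope check of that number (not part of the library code):

```
import math

num_batches = 60000 / 32                  # batches per epoch for MNIST at batch size 32
learning_rate = 1 / math.sqrt(num_batches)
print(learning_rate)                      # approximately 0.0231
```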
The model-specific 'eta_fn' requires a function that, given a variable name and shape, returns a float indicating the expected scale of that variable. Hopefully in the near future we will have libraries that can automatically calculate this 'eta_fn' from the modeling code; but for now we have to specify it manually.
One can use the amos_helper.params_fn_from_assign_map() helper function to create 'eta_fn' from an assign_map. An assign_map is a dict which maps regex rules to a value or simple Python expression. It will find the first regex rule which matches the name of a variable, and evaluate the Python expression if necessary to return the value. See our example below.
The 'shape_fn' similarly requires a function that, given a variable name and shape, returns a reduced shape for the corresponding slot variables. We can use the amos_helper.params_fn_from_assign_map() helper function to create 'shape_fn' from an assign_map as well.
'beta' is the exponential decay rate for running average of gradient squares. We set it to 0.98 here.
'clip_value' is the gradient clipping value, which should match the magnitude of the loss function. If the loss function is a sum of cross-entropy, then we should set 'clip_value' to the sqrt of the number of labels.
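For example, with the sum-of-cross-entropy loss above and a batch size of 32 (one label per example), the corresponding value works out to:

```
import math

batch_size = 32                     # labels per batch
clip_value = math.sqrt(batch_size)  # approximately 5.66
```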
Please refer to our paper for more details of the hyper-parameters.
```
def get_train_state(rng):
  model = CNN()
  dummy_x = jnp.ones([1, 28, 28, 1])
  params = model.init(rng, dummy_x)

  eta_fn = amos_helper.params_fn_from_assign_map(
      {
          '.*/bias': 0.5,
          '.*Conv_0/kernel': 'sqrt(8/prod(SHAPE[:-1]))',
          '.*Conv_1/kernel': 'sqrt(2/prod(SHAPE[:-1]))',
          '.*Dense_0/kernel': 'sqrt(2/SHAPE[0])',
          '.*Dense_1/kernel': 'sqrt(1/SHAPE[0])',
      },
      eval_str_value=True,
  )
  shape_fn = amos_helper.params_fn_from_assign_map(
      {
          '.*Conv_[01]/kernel': '(1, 1, 1, SHAPE[-1])',
          '.*Dense_0/kernel': '(1, SHAPE[1])',
          '.*': (),
      },
      eval_str_value=True,
  )
  optimizer = amos.amos(
      learning_rate=1/math.sqrt(60000/32),
      eta_fn=eta_fn,
      shape_fn=shape_fn,
      beta=0.98,
      clip_value=math.sqrt(32),
  )
  return train_state.TrainState.create(
      apply_fn=model.apply, params=params, tx=optimizer)
```
5. Train step
Use JAX’s @jit decorator to just-in-time compile the function for better performance.
```
@jax.jit
def train_step(batch, state):
  grad_fn = jax.grad(state.apply_fn)
  grads = grad_fn(
      state.params,
      batch['image'],
      batch['label'],
      method=CNN.classify_xe_loss)
  return state.apply_gradients(grads=grads)
```
6. Infer step
Use JAX’s @jit decorator to just-in-time compile the function for better performance.
```
@jax.jit
def infer_step(batch, state):
  logits = state.apply_fn(state.params, batch['image'])
  return jnp.argmax(logits, -1)
```
7. Main
Run the training loop and evaluate on test set.
```
train_ds, test_ds = get_datasets()

rng = jax.random.PRNGKey(0)
rng, init_rng = jax.random.split(rng)
state = get_train_state(init_rng)
del init_rng  # Must not be used anymore.

num_epochs = 9
for epoch in range(1, num_epochs + 1):
  # Use a separate PRNG key to permute image data during shuffling
  rng, input_rng = jax.random.split(rng)
  perms = jax.random.permutation(input_rng, 60000)
  del input_rng
  perms = perms.reshape((60000 // 32, 32))
  for perm in perms:
    batch = {k: v[perm, ...] for k, v in train_ds.items()}
    state = train_step(batch, state)

  pred = jax.device_get(infer_step(test_ds, state))
  accuracy = accuracy_score(test_ds['label'], pred)
  print('epoch: %d, test accuracy: %.2f' % (epoch, accuracy * 100))
```
After 9 epochs, we should get 99.26% test accuracy. If you made it, congrats!
JEstimator
With JEstimator, you can build your model much like the MNIST example above, but without writing code for the "Main" section; JEstimator will serve as the entry point for your model, automatically handle checkpointing in train, eval-once, eval-while-training-and-save-the-best, and predict modes, and set up profiling, TensorBoard, and logging.
In addition, JEstimator supports model partitioning which is required for training very large models across multiple TPU pods. It supports a T5X-compatible checkpoint format that saves and restores checkpoints in a distributed manner, which is suitable for large multi-pod models.
In order to run models with JEstimator, we need to install T5X and FlaxFormer:
```
git clone --branch=main https://github.com/google-research/t5x
cd t5x
python3 -m pip install -e .
cd ..

git clone --branch=main https://github.com/google/flaxformer
cd flaxformer
pip3 install .
cd ..
```
Then, clone this repo to get the JEstimator code:
git clone --branch=main https://github.com/google-research/jestimator
cd jestimator
Now, we can test a toy linear regression model:
PYTHONPATH=. python3 jestimator/models/linear_regression/linear_regression_test.py
MNIST Example in JEstimator
We provide this MNIST Example to demonstrate how to write modeling code with JEstimator. It is much like the example above, but with a big advantage: a config object is passed around to collect information from global flags and the dataset, in order to dynamically set up the modeling. This makes it easier to apply the model to different datasets; for example, one can immediately try the emnist or eurosat datasets simply by changing a command-line argument, without modifying the code (see the example after the training command below).
With the following command, we can start a job to train on MNIST, log every 100 steps, and save the checkpoints to $HOME/experiments/mnist/models:
PYTHONPATH=. python3 jestimator/estimator.py \
--module_imp="jestimator.models.mnist.mnist" \
--module_config="jestimator/models/mnist/mnist.py" \
--train_pattern="tfds://mnist/split=train" \
--model_dir="$HOME/experiments/mnist/models" \
--train_batch_size=32 \
--train_shuffle_buf=4096 \
--train_epochs=9 \
--check_every_steps=100 \
--max_ckpt=20 \
--save_every_steps=1000 \
--module_config.warmup=2000 \
--module_config.amos_beta=0.98
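For example, to train on EMNIST instead of MNIST, only the dataset flag should need to change; the exact pattern string below is an assumption for illustration, and the EMNIST dataset may also require a tfds config name such as emnist/digits:

--train_pattern="tfds://emnist/split=train"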
Meanwhile, we can start a job to monitor the $HOME/experiments/mnist/models folder, evaluate on the MNIST test set, and save the model with the highest accuracy:
PYTHONPATH=. python3 jestimator/estimator.py \
--module_imp="jestimator.models.mnist.mnist" \
--module_config="jestimator/models/mnist/mnist.py" \
--eval_pattern="tfds://mnist/split=test" \
--model_dir="$HOME/experiments/mnist/models" \
--eval_batch_size=32 \
--mode="eval_wait" \
--check_ckpt_every_secs=1 \
--save_high="test_accuracy"
At the same time, we can start TensorBoard to monitor the process:
tensorboard --logdir $HOME/experiments/mnist/models
LSTM on PTB
We can use the following command to train a single layer LSTM on PTB:
PYTHONPATH=. python3 jestimator/estimator.py \
--module_imp="jestimator.models.lstm.lm" \
--module_config="jestimator/models/lstm/lm.py" \
--module_config.vocab_path="jestimator/models/lstm/ptb/vocab.txt" \
--train_pattern="jestimator/models/lstm/ptb/ptb.train.txt" \
--model_dir="$HOME/models/ptb_lstm" \
--train_batch_size=64 \
--train_consecutive=113 \
--train_shuffle_buf=4096 \
--max_train_steps=200000 \
--check_every_steps=1000 \
--max_ckpt=20 \
--module_config.opt_config.optimizer="amos" \
--module_config.opt_config.learning_rate=0.01 \
--module_config.opt_config.beta=0.98 \
--module_config.opt_config.momentum=0.0
and evaluate:
PYTHONPATH=. python3 jestimator/estimator.py \
--module_imp="jestimator.models.lstm.lm" \
--module_config="jestimator/models/lstm/lm.py" \
--module_config.vocab_path="jestimator/models/lstm/ptb/vocab.txt" \
--eval_pattern="jestimator/models/lstm/ptb/ptb.valid.txt" \
--model_dir="$HOME/models/ptb_lstm" \
--eval_batch_size=1
This setup is suitable for running on a single-GPU machine.
More JEstimator Models
Here are some simple guides to pre-train and fine-tune BERT-like models, using TPUs on Google Cloud Platform (GCP). One can start in a web browser with zero setup, by connecting to a virtual machine via the Google Cloud console, without installing anything locally. If this is your first time, you are covered by enough free credits to try the commands.
Owner
- Name: Google Research
- Login: google-research
- Kind: organization
- Location: Earth
- Website: https://research.google
- Repositories: 226
- Profile: https://github.com/google-research
GitHub Events
Total
- Watch event: 2
Last Year
- Watch event: 2
Committers
Last synced: 5 months ago
Top Committers
| Name | Email | Commits |
|---|---|---|
| Ran Tian | t****n@g****m | 27 |
| Peter Hawkins | p****s@g****m | 5 |
| The jestimator Authors | n****y@g****m | 5 |
| Jake VanderPlas | v****s@g****m | 2 |
| Marcus Chiam | m****m@g****m | 1 |
| Yash Katariya | y****a@g****m | 1 |
Committer Domains (Top 20 + Academic)
Issues and Pull Requests
Last synced: 6 months ago
All Time
- Total issues: 3
- Total pull requests: 32
- Average time to close issues: 3 days
- Average time to close pull requests: 6 days
- Total issue authors: 3
- Total pull request authors: 1
- Average comments per issue: 1.67
- Average comments per pull request: 0.16
- Merged pull requests: 9
- Bot issues: 0
- Bot pull requests: 32
Past Year
- Issues: 0
- Pull requests: 0
- Average time to close issues: N/A
- Average time to close pull requests: N/A
- Issue authors: 0
- Pull request authors: 0
- Average comments per issue: 0
- Average comments per pull request: 0
- Merged pull requests: 0
- Bot issues: 0
- Bot pull requests: 0
Top Authors
Issue Authors
- pczzy (1)
- yhtang (1)
- rpeloff-id (1)
Pull Request Authors
- copybara-service[bot] (33)
Top Labels
Issue Labels
Pull Request Labels
Packages
- Total packages: 1
- Total downloads: 94 last month (pypi)
- Total dependent packages: 0
- Total dependent repositories: 4
- Total versions: 5
- Total maintainers: 1
pypi.org: jestimator
Implementation of the Amos optimizer from the JEstimator lib.
- Documentation: https://jestimator.readthedocs.io/
- License: Apache Software License
- Latest release: 0.3.3 (published about 3 years ago)
Rankings
Maintainers (1)
Dependencies
- actions/checkout v3 composite
- actions/setup-python v4 composite
- etils-actions/pypi-auto-publish v1 composite
- flax *