https://github.com/google/trax

Trax — Deep Learning with Clear Code and Speed

Science Score: 46.0%

This score indicates how likely this project is to be science-related based on various indicators:

  • CITATION.cff file
  • codemeta.json file
    Found codemeta.json file
  • .zenodo.json file
    Found .zenodo.json file
  • DOI references
  • Academic publication links
    Links to: arxiv.org
  • Committers with academic emails
    4 of 82 committers (4.9%) from academic institutions
  • Institutional organization owner
  • JOSS paper metadata
  • Scientific vocabulary similarity
    Low similarity (13.6%) to scientific vocabulary

Keywords

deep-learning deep-reinforcement-learning jax machine-learning numpy reinforcement-learning transformer

Keywords from Contributors

distributed deep-neural-networks research tpu pypi weather neuralgcm climate distributed-computing tensor
Last synced: 5 months ago

Repository

Trax — Deep Learning with Clear Code and Speed

Basic Info
  • Host: GitHub
  • Owner: google
  • License: apache-2.0
  • Language: Python
  • Default Branch: master
  • Homepage:
  • Size: 162 MB
Statistics
  • Stars: 8,269
  • Watchers: 138
  • Forks: 823
  • Open Issues: 125
  • Releases: 18
Topics
deep-learning deep-reinforcement-learning jax machine-learning numpy reinforcement-learning transformer
Created over 6 years ago · Last pushed 6 months ago
Metadata Files
Readme Contributing License Authors

README.md

Trax — Deep Learning with Clear Code and Speed

Trax is an end-to-end library for deep learning that focuses on clear code and speed. It is actively used and maintained in the Google Brain team. This notebook (run it in colab) shows how to use Trax and where you can find more information.

  1. Run a pre-trained Transformer: create a translator in a few lines of code
  2. Features and resources: API docs, where to talk to us, how to open an issue and more
  3. Walkthrough: how Trax works, how to make new models and train on your own data

We welcome contributions to Trax! We welcome PRs with code for new models and layers as well as improvements to our code and documentation. We especially love notebooks that explain how models work and show how to use them to solve problems!

General Setup

Execute the following cell (once) before running any of the code samples.

```python
import os
import numpy as np

!pip install -q -U trax
import trax
```

1. Run a pre-trained Transformer

Here is how you create an English-German translator in a few lines of code:

```python
# Create a Transformer model.
# Pre-trained model config in gs://trax-ml/models/translation/ende_wmt32k.gin
model = trax.models.Transformer(
    input_vocab_size=33300,
    d_model=512, d_ff=2048,
    n_heads=8, n_encoder_layers=6, n_decoder_layers=6,
    max_len=2048, mode='predict')

# Initialize using pre-trained weights.
model.init_from_file('gs://trax-ml/models/translation/ende_wmt32k.pkl.gz',
                     weights_only=True)

# Tokenize a sentence.
sentence = 'It is nice to learn new things today!'
tokenized = list(trax.data.tokenize(iter([sentence]),  # Operates on streams.
                                    vocab_dir='gs://trax-ml/vocabs/',
                                    vocab_file='ende_32k.subword'))[0]

# Decode from the Transformer.
tokenized = tokenized[None, :]  # Add batch dimension.
tokenized_translation = trax.supervised.decoding.autoregressive_sample(
    model, tokenized, temperature=0.0)  # Higher temperature: more diverse results.

# De-tokenize.
tokenized_translation = tokenized_translation[0][:-1]  # Remove batch and EOS.
translation = trax.data.detokenize(tokenized_translation,
                                   vocab_dir='gs://trax-ml/vocabs/',
                                   vocab_file='ende_32k.subword')
print(translation)
```

Es ist schön, heute neue Dinge zu lernen!
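
The `temperature` argument in the decoding call above controls how the next token is drawn from the model's output distribution. The following is a minimal sketch in plain Python, independent of Trax, assuming the common convention that log-probabilities are divided by the temperature before sampling and that a temperature of 0 means greedy (argmax) decoding. `sample_token` is a hypothetical helper for illustration, not a Trax function.

```python
import math
import random


def sample_token(log_probs, temperature=1.0, rng=random):
  """Draws one token ID from (unnormalized) log-probabilities.

  Hypothetical helper for illustration; not part of Trax.
  """
  if temperature == 0.0:
    # Greedy decoding: always pick the most likely token.
    return max(range(len(log_probs)), key=lambda i: log_probs[i])
  # Rescale by the temperature, then apply a numerically stable softmax.
  scaled = [lp / temperature for lp in log_probs]
  top = max(scaled)
  weights = [math.exp(s - top) for s in scaled]
  # random.choices normalizes the weights internally.
  return rng.choices(range(len(log_probs)), weights=weights, k=1)[0]


log_probs = [math.log(0.7), math.log(0.2), math.log(0.1)]
greedy = sample_token(log_probs, temperature=0.0)
sampled = sample_token(log_probs, temperature=1.0, rng=random.Random(0))
print(greedy, sampled)
```

With temperature 0 the result is deterministic, which is why the translation example above always yields the same sentence. Larger temperatures flatten the distribution and make less likely tokens more probable.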

2. Features and resources

Trax includes basic models (like ResNet, LSTM, Transformer) and RL algorithms (like REINFORCE, A2C, PPO). It is also actively used for research and includes new models like the Reformer and new RL algorithms like AWR. Trax has bindings to a large number of deep learning datasets, including Tensor2Tensor and TensorFlow datasets.

You can use Trax either as a library from your own Python scripts and notebooks or as a binary from the shell, which can be more convenient for training large models. It runs without any changes on CPUs, GPUs and TPUs.

3. Walkthrough

You can learn here how Trax works, how to create new models and how to train them on your own data.

Tensors and Fast Math

The basic units flowing through Trax models are tensors: multi-dimensional arrays, sometimes also called numpy arrays after numpy, the most widely used package for tensor operations. If you don't know how to operate on tensors, take a look at the numpy guide, because Trax uses the numpy API for that as well.

In Trax we want numpy operations to run very fast, making use of GPUs and TPUs to accelerate them. We also want to automatically compute gradients of functions on tensors. This is done in the trax.fastmath package thanks to its backends -- JAX and TensorFlow numpy.

```python
from trax.fastmath import numpy as fastnp
trax.fastmath.use_backend('jax')  # Can be 'jax' or 'tensorflow-numpy'.

matrix = fastnp.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
print(f'matrix = \n{matrix}')
vector = fastnp.ones(3)
print(f'vector = {vector}')
product = fastnp.dot(vector, matrix)
print(f'product = {product}')
tanh = fastnp.tanh(product)
print(f'tanh(product) = {tanh}')
```

matrix = 
[[1 2 3]
 [4 5 6]
 [7 8 9]]
vector = [1. 1. 1.]
product = [12. 15. 18.]
tanh(product) = [0.99999994 0.99999994 0.99999994]

Gradients can be calculated using trax.fastmath.grad.

```python
def f(x):
  return 2.0 * x * x

grad_f = trax.fastmath.grad(f)

print(f'grad(2x^2) at 1 = {grad_f(1.0)}')
```

grad(2x^2) at 1 = 4.0
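
One way to sanity-check a gradient like the one above is to compare it with a finite-difference estimate. The sketch below uses plain Python only. `numeric_grad` is a hypothetical helper, not part of `trax.fastmath`, and it approximates the derivative numerically rather than computing it by automatic differentiation.

```python
def f(x):
  return 2.0 * x * x


def numeric_grad(fn, x, eps=1e-6):
  """Central finite-difference approximation of d fn / d x."""
  return (fn(x + eps) - fn(x - eps)) / (2.0 * eps)


approx = numeric_grad(f, 1.0)
print(f'finite-difference grad(2x^2) at 1 ~ {approx:.4f}')
```

Analytically, the derivative of 2x^2 is 4x, so both approaches should agree on a value close to 4 at x = 1.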

Layers

Layers are basic building blocks of Trax models. You will learn all about them in the layers intro but for now, just take a look at the implementation of one core Trax layer, Embedding:

```python
class Embedding(base.Layer):
  """Trainable layer that maps discrete tokens/IDs to vectors."""

  def __init__(self,
               vocab_size,
               d_feature,
               kernel_initializer=init.RandomNormalInitializer(1.0)):
    """Returns an embedding layer with given vocabulary size and vector size.

    Args:
      vocab_size: Size of the input vocabulary. The layer will assign a unique
          vector to each ID in `range(vocab_size)`.
      d_feature: Dimensionality/depth of the output vectors.
      kernel_initializer: Function that creates (random) initial vectors for
          the embedding.
    """
    super().__init__(name=f'Embedding_{vocab_size}_{d_feature}')
    self._d_feature = d_feature  # feature dimensionality
    self._vocab_size = vocab_size
    self._kernel_initializer = kernel_initializer

  def forward(self, x):
    """Returns embedding vectors corresponding to input token IDs.

    Args:
      x: Tensor of token IDs.

    Returns:
      Tensor of embedding vectors.
    """
    return jnp.take(self.weights, x, axis=0, mode='clip')

  def init_weights_and_state(self, input_signature):
    """Returns tensor of newly initialized embedding vectors."""
    del input_signature
    shape_w = (self._vocab_size, self._d_feature)
    w = self._kernel_initializer(shape_w, self.rng)
    self.weights = w
```

Layers with trainable weights like Embedding need to be initialized with the signature (shape and dtype) of the input, and then can be run by calling them.

```python
from trax import layers as tl

# Create an input tensor x.
x = np.arange(15)
print(f'x = {x}')

# Create the embedding layer.
embedding = tl.Embedding(vocab_size=20, d_feature=32)
embedding.init(trax.shapes.signature(x))

# Run the layer -- y = embedding(x).
y = embedding(x)
print(f'shape of y = {y.shape}')
```

x = [ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14]
shape of y = (15, 32)
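
Conceptually, the embedding forward pass is a table lookup: each token ID selects one row of the weight matrix. The sketch below mimics that in plain Python, including the `mode='clip'` behaviour seen in the `Embedding.forward` code above. `init_embedding` and `embed` are hypothetical helpers for illustration and are not how Trax stores or initializes weights.

```python
import random


def init_embedding(vocab_size, d_feature, seed=0):
  """Builds a vocab_size x d_feature table of random normal values."""
  rng = random.Random(seed)
  return [[rng.gauss(0.0, 1.0) for _ in range(d_feature)]
          for _ in range(vocab_size)]


def embed(weights, token_ids):
  """Looks up one row per token ID, clipping out-of-range IDs."""
  last = len(weights) - 1
  return [weights[min(max(t, 0), last)] for t in token_ids]


weights = init_embedding(vocab_size=20, d_feature=32)
y = embed(weights, range(15))
print(f'shape of y = ({len(y)}, {len(y[0])})')  # shape of y = (15, 32)
```

Because of the clipping, an ID beyond the vocabulary (for example 25 with a vocabulary of 20) maps to the last row instead of raising an error.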

Models

Models in Trax are built from layers most often using the Serial and Branch combinators. You can read more about those combinators in the layers intro and see the code for many models in trax/models/, e.g., this is how the Transformer Language Model is implemented. Below is an example of how to build a sentiment classification model.

```python
model = tl.Serial(
    tl.Embedding(vocab_size=8192, d_feature=256),
    tl.Mean(axis=1),  # Average on axis 1 (length of sentence).
    tl.Dense(2),      # Classify 2 classes.
    tl.LogSoftmax()   # Produce log-probabilities.
)

# You can print model structure.
print(model)
```

Serial[
  Embedding_8192_256
  Mean
  Dense_2
  LogSoftmax
]
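
The idea behind a serial combinator is function composition: the output of each layer becomes the input of the next. Here is a rough sketch of that idea in plain Python. The names `serial`, `mean_over_length` and `log_softmax` are hypothetical stand-ins, and the embedding and dense layers are left out to keep it short, so this is not how `tl.Serial` is implemented.

```python
import math


def serial(*layers):
  """Returns a function that applies the given layers in order."""
  def run(x):
    for layer in layers:
      x = layer(x)
    return x
  return run


def mean_over_length(batch):
  """Averages feature vectors over the sequence axis (axis 1)."""
  return [[sum(col) / len(seq) for col in zip(*seq)] for seq in batch]


def log_softmax(batch):
  """Converts each row of scores into log-probabilities."""
  out = []
  for row in batch:
    top = max(row)
    log_z = top + math.log(sum(math.exp(v - top) for v in row))
    out.append([v - log_z for v in row])
  return out


toy_model = serial(mean_over_length, log_softmax)
# One "sentence" of three tokens, each with a 2-dimensional feature vector.
batch = [[[1.0, 0.0], [3.0, 2.0], [2.0, 1.0]]]
log_probs = toy_model(batch)
probs = [math.exp(v) for v in log_probs[0]]
print(probs, sum(probs))
```

Exponentiating the log-probabilities gives values that sum to 1, which is what makes the final layer's output usable as class probabilities.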

Data

To train your model, you need data. In Trax, data streams are represented as Python iterators, so you can call next(data_stream) and get a tuple, e.g., (inputs, targets). Trax allows you to use TensorFlow Datasets easily and you can also get an iterator from your own text file using the standard open('my_file.txt').

```python
train_stream = trax.data.TFDS('imdb_reviews', keys=('text', 'label'), train=True)()
eval_stream = trax.data.TFDS('imdb_reviews', keys=('text', 'label'), train=False)()
print(next(train_stream))  # See one example.
```

(b"This was an absolutely terrible movie. Don't be lured in by Christopher Walken or Michael Ironside. Both are great actors, but this must simply be their worst role in history. Even their great acting could not redeem this movie's ridiculous storyline. This movie is an early nineties US propaganda piece. The most pathetic scenes were those when the Columbian rebels were making their cases for revolutions. Maria Conchita Alonso appeared phony, and her pseudo-love affair with Walken was nothing but a pathetic emotional plug in a movie that was devoid of any real meaning. I am disappointed that there are movies like this, ruining actor's like Christopher Walken's good name. I could barely sit through it.", 0)
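
Since a data stream is just a Python iterator that yields tuples, you can write one yourself as a generator. The sketch below uses a hypothetical in-memory dataset (`toy_examples`). Cycling over the examples forever is a choice made for this sketch so that `next()` never raises `StopIteration`.

```python
import itertools


def labeled_stream(examples):
  """Yields (text, label) tuples forever by cycling over the examples."""
  for text, label in itertools.cycle(examples):
    yield text, label


# Hypothetical toy data standing in for a real dataset.
toy_examples = [('great movie', 1), ('terrible plot', 0)]
stream = labeled_stream(toy_examples)
first = next(stream)
second = next(stream)
third = next(stream)  # Wraps around to the first example again.
print(first, second, third)
```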

Using the trax.data module you can create input processing pipelines, e.g., to tokenize and shuffle your data. You build a pipeline with trax.data.Serial; the result is a function that you apply to a stream to get a processed stream.

```python
data_pipeline = trax.data.Serial(
    trax.data.Tokenize(vocab_file='en_8k.subword', keys=[0]),
    trax.data.Shuffle(),
    trax.data.FilterByLength(max_length=2048, length_keys=[0]),
    trax.data.BucketByLength(boundaries=[  32, 128, 512, 2048],
                             batch_sizes=[256,  64,  16,    4, 1],
                             length_keys=[0]),
    trax.data.AddLossWeights()
)
train_batches_stream = data_pipeline(train_stream)
eval_batches_stream = data_pipeline(eval_stream)
example_batch = next(train_batches_stream)
print(f'shapes = {[x.shape for x in example_batch]}')  # Check the shapes.
```

shapes = [(4, 1024), (4,), (4,)]
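
Bucketing by length groups examples of similar length and gives longer examples smaller batches. Note that `batch_sizes` in the pipeline above has one more entry than `boundaries`; the last entry covers anything longer than the final boundary. The following sketch shows one way to pick the batch size for an example, assuming each boundary is an inclusive upper limit for its bucket. Both that assumption and the helper `batch_size_for_length` are illustrative and not taken from the Trax implementation.

```python
import bisect

BOUNDARIES = [32, 128, 512, 2048]
BATCH_SIZES = [256, 64, 16, 4, 1]  # One more entry than BOUNDARIES.


def batch_size_for_length(length):
  """Returns the batch size of the bucket that an example length falls in.

  Assumes each boundary is an inclusive upper limit for its bucket.
  """
  bucket = bisect.bisect_left(BOUNDARIES, length)
  return BATCH_SIZES[bucket]


for length in (20, 100, 600, 3000):
  print(length, batch_size_for_length(length))
```

Under this reading, a batch whose first dimension is 4, as in the output above, would come from the bucket for examples longer than 512 and at most 2048 tokens.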

Supervised training

When you have the model and the data, use trax.supervised.training to define training and eval tasks and create a training loop. The Trax training loop optimizes training and will create TensorBoard logs and model checkpoints for you.

```python
from trax.supervised import training

# Training task.
train_task = training.TrainTask(
    labeled_data=train_batches_stream,
    loss_layer=tl.WeightedCategoryCrossEntropy(),
    optimizer=trax.optimizers.Adam(0.01),
    n_steps_per_checkpoint=500,
)

# Evaluation task.
eval_task = training.EvalTask(
    labeled_data=eval_batches_stream,
    metrics=[tl.WeightedCategoryCrossEntropy(), tl.WeightedCategoryAccuracy()],
    n_eval_batches=20  # For less variance in eval numbers.
)

# Training loop saves checkpoints to output_dir.
output_dir = os.path.expanduser('~/output_dir/')
!rm -rf {output_dir}
training_loop = training.Loop(model,
                              train_task,
                              eval_tasks=[eval_task],
                              output_dir=output_dir)

# Run 2000 steps (batches).
training_loop.run(2000)
```

Step      1: Ran 1 train steps in 0.78 secs
Step      1: train WeightedCategoryCrossEntropy |  1.33800304
Step      1: eval  WeightedCategoryCrossEntropy |  0.71843582
Step      1: eval      WeightedCategoryAccuracy |  0.56562500

Step    500: Ran 499 train steps in 5.77 secs
Step    500: train WeightedCategoryCrossEntropy |  0.62914723
Step    500: eval  WeightedCategoryCrossEntropy |  0.49253047
Step    500: eval      WeightedCategoryAccuracy |  0.74062500

Step   1000: Ran 500 train steps in 5.03 secs
Step   1000: train WeightedCategoryCrossEntropy |  0.42949259
Step   1000: eval  WeightedCategoryCrossEntropy |  0.35451687
Step   1000: eval      WeightedCategoryAccuracy |  0.83750000

Step   1500: Ran 500 train steps in 4.80 secs
Step   1500: train WeightedCategoryCrossEntropy |  0.41843575
Step   1500: eval  WeightedCategoryCrossEntropy |  0.35207348
Step   1500: eval      WeightedCategoryAccuracy |  0.82109375

Step   2000: Ran 500 train steps in 5.35 secs
Step   2000: train WeightedCategoryCrossEntropy |  0.38129005
Step   2000: eval  WeightedCategoryCrossEntropy |  0.33760912
Step   2000: eval      WeightedCategoryAccuracy |  0.85312500
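
The two metrics in the log are weighted: each example contributes in proportion to its loss weight, so padding examples with weight 0 are ignored. The sketch below shows what such metrics could look like in plain Python, assuming both are averages normalized by the total weight. The functions are hypothetical illustrations, and the actual Trax layers may differ in details.

```python
import math


def weighted_cross_entropy(log_probs, targets, weights):
  """Weight-normalized average of -log p(target)."""
  total = sum(-lp[t] * w for lp, t, w in zip(log_probs, targets, weights))
  return total / sum(weights)


def weighted_accuracy(log_probs, targets, weights):
  """Weight-normalized fraction of examples whose argmax equals the target."""
  hits = sum(w for lp, t, w in zip(log_probs, targets, weights)
             if max(range(len(lp)), key=lambda i: lp[i]) == t)
  return hits / sum(weights)


# Two classes, three examples; the last example is padding (weight 0).
log_probs = [[math.log(0.9), math.log(0.1)],
             [math.log(0.4), math.log(0.6)],
             [math.log(0.5), math.log(0.5)]]
targets = [0, 0, 1]
weights = [1.0, 1.0, 0.0]
xent = weighted_cross_entropy(log_probs, targets, weights)
accuracy = weighted_accuracy(log_probs, targets, weights)
print(xent, accuracy)
```

Here only the first two examples count: one is classified correctly and one is not, so the accuracy is 0.5 regardless of what the padded example predicts.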

After training the model, run it like any layer to get results.

```python
example_input = next(eval_batches_stream)[0][0]
example_input_str = trax.data.detokenize(example_input, vocab_file='en_8k.subword')
print(f'example input_str: {example_input_str}')
sentiment_log_probs = model(example_input[None, :])  # Add batch dimension.
print(f'Model returned sentiment probabilities: {np.exp(sentiment_log_probs)}')
```

example input_str: I first saw this when I was a teen in my last year of Junior High. I was riveted to it! I loved the special effects, the fantastic places and the trial-aspect and flashback method of telling the story.<br /><br />Several years later I read the book and while it was interesting and I could definitely see what Swift was trying to say, I think that while it's not as perfect as the book for social commentary, as a story the movie is better. It makes more sense to have it be one long adventure than having Gulliver return after each voyage and making a profit by selling the tiny Lilliput sheep or whatever.<br /><br />It's much more arresting when everyone thinks he's crazy and the sheep DO make a cameo anyway. As a side note, when I saw Laputa I was stunned. It looks very much like the Kingdom of Zeal from the Chrono Trigger video game (1995) that also made me like this mini-series even more.<br /><br />I saw it again about 4 years ago, and realized that I still enjoyed it just as much. Really high quality stuff and began an excellent run of Sweeps mini-series for NBC who followed it up with the solid Merlin and interesting Alice in Wonderland.<pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad>
Model returned sentiment probabilities: [[3.984500e-04 9.996014e-01]]

Owner

  • Name: Google
  • Login: google
  • Kind: organization
  • Email: opensource@google.com
  • Location: United States of America

Google ❤️ Open Source

GitHub Events

Total
  • Watch event: 217
  • Delete event: 5
  • Issue comment event: 3
  • Push event: 26
  • Pull request event: 14
  • Fork event: 23
  • Create event: 9
Last Year
  • Watch event: 217
  • Delete event: 5
  • Issue comment event: 3
  • Push event: 26
  • Pull request event: 14
  • Fork event: 23
  • Create event: 9

Committers

Last synced: 6 months ago

All Time
  • Total Commits: 1,544
  • Total Committers: 82
  • Avg Commits per committer: 18.829
  • Development Distribution Score (DDS): 0.808
Past Year
  • Commits: 9
  • Committers: 4
  • Avg Commits per committer: 2.25
  • Development Distribution Score (DDS): 0.444
Top Committers
Name Email Commits
Jonni Kanerva j****i@g****m 296
Lukasz Kaiser l****r@g****m 283
Afroz Mohiuddin a****m@g****m 236
Trax Team t****t@g****m 120
Henryk Michalewski h****m@g****m 113
Piotr Kozakowski p****i@g****m 77
Sebastian Jaszczur j****r@g****m 56
Peng Wang w****g@g****m 50
Akshay Modi n****i@g****m 48
Saurav Maheshkar s****r@g****m 28
DarrenZhang01 1****6@1****m 17
Peter Hawkins p****s@g****m 17
Szymon Tworkowski 4****n@u****m 16
syzymon s****i@s****l 16
Jake VanderPlas v****s@g****m 12
Katarzyna Kańska k****a@g****m 9
piotrek p****s@g****m 9
jalammar a****r@g****m 9
Andrei Nesterov a****v@g****m 8
Yilei Yang y****g@g****m 8
Paweł Kołodziej p****j@g****m 7
Lukasz Kaiser l****r@g****m 6
Piotr Nawrot p****3@s****l 6
Adam Roberts a****b@g****m 5
Omar Alsaqa o****a@g****m 5
Michał Tyrolski m****i@g****m 5
Piotr Kozakowski p****i@m****l 4
Trax Team n****y@g****m 4
Christian Clauss c****s@m****m 3
Dawid Jamka 5****y@u****m 3
and 52 more...
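
The all-time committer figures above can be reproduced with a little arithmetic. The sketch below assumes that the Development Distribution Score is defined as 1 minus the share of commits made by the most active committer. That definition is an assumption, although it matches the reported all-time values.

```python
def development_distribution_score(top_committer_commits, total_commits):
  """Returns 1 - (share of commits made by the most active committer)."""
  return 1.0 - top_committer_commits / total_commits


# All-time figures from the committer statistics above.
total_commits = 1544
total_committers = 82
top_commits = 296  # Jonni Kanerva

avg_commits = round(total_commits / total_committers, 3)
dds = round(development_distribution_score(top_commits, total_commits), 3)
print(avg_commits)  # 18.829
print(dds)          # 0.808
```

A score close to 1 means commits are spread across many people; a score close to 0 means one committer dominates.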

Issues and Pull Requests

Last synced: 6 months ago

All Time
  • Total issues: 55
  • Total pull requests: 89
  • Average time to close issues: 8 months
  • Average time to close pull requests: 14 days
  • Total issue authors: 51
  • Total pull request authors: 18
  • Average comments per issue: 1.75
  • Average comments per pull request: 0.21
  • Merged pull requests: 36
  • Bot issues: 0
  • Bot pull requests: 64
Past Year
  • Issues: 0
  • Pull requests: 17
  • Average time to close issues: N/A
  • Average time to close pull requests: 2 days
  • Issue authors: 0
  • Pull request authors: 1
  • Average comments per issue: 0
  • Average comments per pull request: 0.0
  • Merged pull requests: 10
  • Bot issues: 0
  • Bot pull requests: 17
Top Authors
Issue Authors
  • mihalt (3)
  • alexm-gc (2)
  • ngoquanghuy99 (2)
  • Nuna7 (1)
  • Elkia-Federation (1)
  • AndriCcos (1)
  • oendnsk675 (1)
  • jonatasgrosman (1)
  • tvjoseph (1)
  • jsearcy1 (1)
  • topekekere (1)
  • ricottablue (1)
  • ras44 (1)
  • LiuZhenshun (1)
  • cmosguy (1)
Pull Request Authors
  • copybara-service[bot] (74)
  • sunvod (5)
  • syzymon (4)
  • SauravMaheshkar (2)
  • lukaszkaiser (1)
  • rug (1)
  • d0rc (1)
  • vsnupoudel (1)
  • mmarcinmichal (1)
  • manifest (1)
  • thoo (1)
  • 0o001 (1)
  • NickDory (1)
  • arvyzukai (1)
  • MaanasArora (1)
Top Labels
Issue Labels
enhancement (2) good first issue (1)
Pull Request Labels
cla: yes (23) ready to pull (11)

Packages

  • Total packages: 2
  • Total downloads:
    • pypi 1,403 last-month
  • Total docker downloads: 117
  • Total dependent packages: 3
    (may contain duplicates)
  • Total dependent repositories: 65
    (may contain duplicates)
  • Total versions: 42
  • Total maintainers: 2
pypi.org: trax

Trax

  • Versions: 24
  • Dependent Packages: 3
  • Dependent Repositories: 64
  • Downloads: 1,403 Last month
  • Docker Downloads: 117
Rankings
Stargazers count: 0.3%
Forks count: 1.5%
Dependent repos count: 1.9%
Average: 2.3%
Dependent packages count: 2.4%
Docker downloads count: 3.5%
Downloads: 4.0%
Maintainers (2)
Last synced: 6 months ago
proxy.golang.org: github.com/google/trax
  • Versions: 18
  • Dependent Packages: 0
  • Dependent Repositories: 1
Rankings
Stargazers count: 0.8%
Forks count: 1.0%
Average: 4.0%
Dependent repos count: 4.7%
Dependent packages count: 9.6%
Last synced: 6 months ago

Dependencies

.github/workflows/build.yaml actions
  • actions/checkout v2 composite
  • actions/setup-python v2 composite
docs/requirements.txt pypi
  • nbsphinx *
setup.py pypi
  • absl-py *
  • funcsigs *
  • gin-config *
  • gym *
  • jax *
  • jaxlib *
  • matplotlib *
  • numpy *
  • psutil *
  • scipy *
  • six *
  • tensorflow-datasets *
  • tensorflow-text *