https://github.com/chijames/gst

Science Score: 10.0%

This score indicates how likely this project is to be science-related based on various indicators:

  • CITATION.cff file
  • codemeta.json file
  • .zenodo.json file
  • DOI references
  • Academic publication links
    Links to: arxiv.org
  • Academic email domains
  • Institutional organization owner
  • JOSS paper metadata
  • Scientific vocabulary similarity
    Low similarity (6.5%) to scientific vocabulary
Last synced: 7 months ago

Repository

Basic Info
  • Host: GitHub
  • Owner: chijames
  • Language: Python
  • Default Branch: master
  • Size: 4.01 MB
Statistics
  • Stars: 8
  • Watchers: 1
  • Forks: 0
  • Open Issues: 0
  • Releases: 0
Created almost 4 years ago · Last pushed almost 4 years ago
Metadata Files
Readme

README.md

Gapped Straight-Through Estimator

PyTorch implementation of Gapped Straight-Through Estimator (GST) with experiments on MNIST-VAE and ListOps. We compare our proposed GST estimator with several discrete random variable estimators including Straight-Through Gumbel-Softmax (STGS) and Rao-Blackwellized Straight-Through Gumbel-Softmax (rao_gumbel).
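To give a rough picture of what the rao_gumbel baseline computes, here is a minimal PyTorch sketch of a Rao-Blackwellized Straight-Through Gumbel-Softmax estimator. It is not taken from this repository. The function names `conditional_perturbed_logits` and `rao_blackwellized_st` and the default `k` are hypothetical.

The sketch works as follows:

  • The forward pass returns a one-hot sample D drawn from the categorical distribution.
  • The backward pass averages `k` tempered softmax relaxations. Each relaxation's Gumbel noise is sampled conditionally on the perturbed argmax being D.
  • Because of this conditioning, the gradient's variance cannot exceed that of a single STGS relaxation.

```python
import torch
import torch.nn.functional as F


def conditional_perturbed_logits(logits: torch.Tensor, one_hot: torch.Tensor, k: int) -> torch.Tensor:
    """Sample k draws of (logits + Gumbel noise) conditioned on argmax == one_hot.

    logits, one_hot: (..., n). Returns a tensor of shape (k, ..., n).
    """
    # Independent Exp(1) variables; a Gumbel(mu) variable can be written as mu - log(E).
    e = torch.empty((k,) + tuple(logits.shape), dtype=logits.dtype, device=logits.device).exponential_()
    # The maximum of logits + Gumbel noise is Gumbel-distributed with location logsumexp(logits).
    log_z = torch.logsumexp(logits, dim=-1, keepdim=True)
    top = log_z - torch.log((e * one_hot).sum(dim=-1, keepdim=True))
    # The other coordinates are Gumbel(logits) truncated above at that maximum:
    # -log(exp(-top) + exp(-logits) * E), computed stably with logaddexp.
    others = -torch.logaddexp(-top, -logits + torch.log(e))
    return torch.where(one_hot.bool(), top, others)


def rao_blackwellized_st(logits: torch.Tensor, tau: float = 1.0, k: int = 10) -> torch.Tensor:
    """One-hot sample in the forward pass, averaged conditional relaxation in the backward pass."""
    # Exact categorical sample D (one-hot), drawn without tracking gradients.
    d = torch.distributions.OneHotCategorical(logits=logits.detach()).sample()
    # Treat the conditional Gumbel perturbation as fixed noise, as in plain STGS.
    with torch.no_grad():
        noise = conditional_perturbed_logits(logits, d, k) - logits
    # Average k tempered softmax relaxations whose perturbed argmax is D.
    y_soft = F.softmax((logits + noise) / tau, dim=-1).mean(dim=0)
    # Straight-through: the forward value equals D, the gradient flows through y_soft.
    return y_soft + (d - y_soft).detach()
```

Setting `k=1` recovers a single conditional relaxation. Larger `k` trades compute for a lower-variance gradient.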

Installation

We recommend using Anaconda with the following commands:

    conda create -n GST python=3.8
    conda activate GST
    conda install pytorch==1.11.0 torchvision==0.12.0 cudatoolkit=10.2 -c pytorch
    pip install -r requirements.txt
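After installation, a small standard-library script can confirm that the environment matches the versions pinned by the conda command and by requirements.txt. The helper `check_pins` is hypothetical and not part of the repository. It relies on `importlib.metadata` (available from Python 3.8). Packages installed through conda usually, but not always, register the metadata it reads.

```python
from importlib.metadata import PackageNotFoundError, version

# Versions pinned by the conda install command and by requirements.txt.
PINNED = {
    "torch": "1.11.0",
    "torchvision": "0.12.0",
    "protobuf": "3.17.3",
    "tensorboardX": "2.4",
    "torchtext": "0.6.0",
}


def check_pins(pins=PINNED, lookup=version):
    """Return {package: (expected, installed_or_None)} for every package that does not match."""
    mismatches = {}
    for name, expected in pins.items():
        try:
            installed = lookup(name)
        except PackageNotFoundError:
            installed = None
        # Local build tags such as "+cu102" are ignored when comparing.
        if installed is None or installed.split("+")[0] != expected:
            mismatches[name] = (expected, installed)
    return mismatches


if __name__ == "__main__":
    for name, (expected, installed) in check_pins().items():
        print(f"{name}: expected {expected}, found {installed}")
```

The `lookup` parameter exists only so the check can be exercised without touching the real environment.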

MNIST-VAE Experiment

Configurations

--mode (default: gumbel) selects the estimator. Possible choices are gumbel, rao_gumbel, gst-1.0, and gst-p.

--temperature (default: 1.0) controls the temperature of the softmax function for the soft sample.

--hard (default: True) returns hard (one-hot) samples using the straight-through trick; otherwise, soft samples are returned.
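To make the --temperature and --hard options concrete, here is a minimal sketch of Straight-Through Gumbel-Softmax sampling in PyTorch. The function name `st_gumbel_softmax` is hypothetical and this is not the repository's code. PyTorch also ships an equivalent built-in, `torch.nn.functional.gumbel_softmax(logits, tau, hard)`.

The two options act as follows:

  • The temperature divides the perturbed logits before the softmax, so lower values give sharper soft samples.
  • With `hard=True`, the forward pass returns the one-hot argmax while gradients flow through the soft sample.

```python
import torch
import torch.nn.functional as F


def st_gumbel_softmax(logits: torch.Tensor, temperature: float = 1.0, hard: bool = True) -> torch.Tensor:
    """Straight-Through Gumbel-Softmax sample over the last dimension."""
    # Gumbel(0, 1) noise written as -log(E) with E ~ Exp(1).
    gumbels = -torch.empty_like(logits).exponential_().log()
    # Soft sample: tempered softmax of the perturbed logits.
    y_soft = F.softmax((logits + gumbels) / temperature, dim=-1)
    if not hard:
        return y_soft
    # Hard sample: one-hot argmax of the soft sample in the forward pass...
    index = y_soft.argmax(dim=-1, keepdim=True)
    y_hard = torch.zeros_like(logits).scatter_(-1, index, 1.0)
    # ...with the soft sample's gradient in the backward pass (straight-through trick).
    return y_hard - y_soft.detach() + y_soft
```

For example, `st_gumbel_softmax(logits, temperature=0.5)` corresponds roughly to running with `--temperature 0.5` and the default `--hard`.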

Example 1: train STGS at temperature 1.0

    python gumbel_softmax_vae.py --mode gumbel --temperature 1.0

Example 2: train GST-1.0 at temperature 0.5

    python gumbel_softmax_vae.py --mode gst-1.0 --temperature 0.5

ListOps Experiment

Configurations

Example (a): train rao_gumbel at temperature 0.1

    python -m nlp.train --word-dim 300 --hidden-dim 300 --clf-hidden-dim 300 \
        --clf-num-layers 1 --batch-size 16 --max-epoch 20 \
        --save-dir ./checkpoint_listops --device cuda --pretrained glove.840B.300d \
        --leaf-rnn --dropout 0.5 --lower --mode rao_gumbel --task listops --temperature 0.1

Example (b): train GST-1.0 at temperature 0.1

    python -m nlp.train --word-dim 300 --hidden-dim 300 --clf-hidden-dim 300 \
        --clf-num-layers 1 --batch-size 16 --max-epoch 20 \
        --save-dir ./checkpoint_listops --device cuda --pretrained glove.840B.300d \
        --leaf-rnn --dropout 0.5 --lower --mode gst-1.0 --task listops --temperature 0.1
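The two ListOps commands are identical except for --mode. Comparing estimators therefore only requires varying that flag (and, if desired, --temperature). The hypothetical helper `listops_command` rebuilds the documented command line for a given mode. It uses only the flags shown in the examples and does not reflect any other options `nlp.train` may accept.

```python
import shlex

# Flags shared by ListOps examples (a) and (b).
LISTOPS_BASE = [
    "--word-dim", "300", "--hidden-dim", "300",
    "--clf-hidden-dim", "300", "--clf-num-layers", "1",
    "--batch-size", "16", "--max-epoch", "20",
    "--save-dir", "./checkpoint_listops", "--device", "cuda",
    "--pretrained", "glove.840B.300d", "--leaf-rnn",
    "--dropout", "0.5", "--lower",
]


def listops_command(mode: str, temperature: float = 0.1) -> list:
    """Build the argv for one ListOps run with the given estimator and temperature."""
    return (["python", "-m", "nlp.train"] + LISTOPS_BASE
            + ["--mode", mode, "--task", "listops", "--temperature", str(temperature)])


if __name__ == "__main__":
    # Print the commands for examples (a) and (b); the lists can also be passed to subprocess.run.
    for mode in ("rao_gumbel", "gst-1.0"):
        print(shlex.join(listops_command(mode)))
```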

Owner

  • Login: chijames
  • Kind: user

Dependencies

requirements.txt pypi
  • protobuf ==3.17.3
  • tensorboardX ==2.4
  • torchtext ==0.6.0