https://github.com/beomi/gemma-easylm

Train GEMMA on TPU/GPU! (Codebase for training Gemma-Ko Series)

Science Score: 10.0%

This score indicates how likely this project is to be science-related based on various indicators:

  • CITATION.cff file
  • codemeta.json file
  • .zenodo.json file
  • DOI references
  • Academic publication links
  • Committers with academic emails: 1 of 12 committers (8.3%) from academic institutions
  • Institutional organization owner
  • JOSS paper metadata
  • Scientific vocabulary similarity: low similarity (11.8%) to scientific vocabulary

Keywords

easylm flax gemma huggingface jax language-model tpu transformers
Last synced: 5 months ago

Repository

Basic Info
Statistics
  • Stars: 47
  • Watchers: 2
  • Forks: 10
  • Open Issues: 0
  • Releases: 0
Topics
easylm flax gemma huggingface jax language-model tpu transformers
Created almost 2 years ago · Last pushed almost 2 years ago
Metadata Files
Readme License

README.md

Gemma-EasyLM

This document outlines the integration of the Gemma model into the EasyLM framework, including instructions for training, converting the model format, and serving the model with Gradio.

Training: Integrating HF Flax Weights into EasyLM

Step 1: Consolidate Flax Weights from Hugging Face

You can skip this step by downloading https://huggingface.co/beomi/gemma-ko-7b/resolve/flax-init/flax_model.msgpack directly.
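If you prefer to script the download, here is a minimal sketch using huggingface_hub (the repo_id, filename, and revision are read off the URL above):

```python
from huggingface_hub import hf_hub_download

# Fetches the pre-consolidated Flax weights; equivalent to downloading the URL above.
path = hf_hub_download(
    repo_id="beomi/gemma-ko-7b",
    filename="flax_model.msgpack",
    revision="flax-init",
)
print(path)
```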

First, consolidate the sharded Flax model weights available at Hugging Face - Gemma 7B.

Use the following example code to accomplish this:

```python
from transformers import GemmaForCausalLM

# Load the sharded checkpoint and re-save it; the 99GB shard cap keeps all
# weights in a single file.
model = GemmaForCausalLM.from_pretrained("google/gemma-7b", torch_dtype="auto")
model.save_pretrained("./flax-concatted", max_shard_size="99GB")
```

This script generates a flax-concatted/flax_model.msgpack file, which we will use during training.

Step 2: Upload the .msgpack File to Google Cloud Storage (GCS)

Execute the following command to upload the generated .msgpack file to your GCS bucket:

```bash
gsutil cp ./flax-concatted/flax_model.msgpack gs://YOUR_GCS_REPO_NAME
```

Step 3: Modify the train.sh Script

Adjust the paths for load_checkpoint, train_dataset.json_dataset.path, and logger.output_dir within the train.sh script to match your setup.

The provided example train.sh script assumes training will be conducted on a TPUv4-64 pod slice.
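For orientation, here is a hypothetical sketch of what those three settings might look like inside train.sh; the module path, the flax_params:: prefix, and the GCS paths are assumptions based on upstream EasyLM's CLI conventions, so check the repo's actual script:

```bash
# Hypothetical sketch only -- the real train.sh in the repo is authoritative.
python -m EasyLM.models.gemma.gemma_train \
    --load_checkpoint='flax_params::gs://YOUR_GCS_REPO_NAME/flax_model.msgpack' \
    --train_dataset.json_dataset.path='gs://YOUR_GCS_REPO_NAME/data/train.jsonl' \
    --logger.output_dir='gs://YOUR_GCS_REPO_NAME/checkpoints'
```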

Step 4: Initiate Training

Execute the training script to start the training process:

```bash
./train.sh
```

Conversion: From EasyLM to Hugging Face Format

Step 1: Retrieve the streaming_train_state File

Download the streaming_train_state file from your GCS bucket using the following command:

```bash
gsutil cp gs://YOUR_GCS_REPO_NAME/.../streaming_train_state_80000 .
```

Note: The file name will either be streaming_train_state or streaming_train_state_STEPNO.

Step 2: Update the .stream File Path

In the convert_easylm_stream_to_hf_safetensors.py file, modify the path to the .stream file accordingly:

```python
# Modify this line
_, param = StreamingCheckpointer.load_trainstate_checkpoint(
    load_from='trainstate_params::/home/latheledusjp/streaming_train_state_80000'
)
```
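As in upstream EasyLM, the prefix before :: tells StreamingCheckpointer how to interpret the file: trainstate_params:: restores only the model parameters from a full training-state checkpoint, which is all the conversion needs.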

Step 3: Execute the Conversion Script

Run the conversion script to convert the EasyLM model format to Hugging Face's format:

```bash
python convert_easylm_stream_to_hf_safetensors.py
```

Step 4: Verify the Output Files

Check the generated output files in the ./gemma-ko-8.5b-dev directory.

The Flax-version of the weight file can be found in the ./flax-gemma-ko-8b folder.
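As a quick sanity check, here is a minimal sketch that loads the converted checkpoint with transformers (assuming the converter wrote a standard Hugging Face checkpoint; if tokenizer files are missing, load the tokenizer from google/gemma-7b instead):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("./gemma-ko-8.5b-dev")
model = AutoModelForCausalLM.from_pretrained("./gemma-ko-8.5b-dev")
print(f"{model.num_parameters():,} parameters")
```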

Serving the Model with Gradio

To serve the model using Gradio, follow these steps:

```bash
cd EasyLM/models/gemma
pip install -r serving_requirements.txt
./serve_test.sh
```
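serve_test.sh launches the repo's own Gradio app. As a rough standalone illustration of the same idea (hypothetical code, not the repo's serving stack; note that serving_requirements.txt pins gradio <4):

```python
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer

MODEL_DIR = "./gemma-ko-8.5b-dev"  # hypothetical: any converted HF checkpoint
tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR)
model = AutoModelForCausalLM.from_pretrained(MODEL_DIR)

def generate(prompt):
    inputs = tokenizer(prompt, return_tensors="pt")
    output = model.generate(**inputs, max_new_tokens=128)
    return tokenizer.decode(output[0], skip_special_tokens=True)

# gradio 3.x-style Interface, matching the <4 pin in serving_requirements.txt
gr.Interface(fn=generate, inputs="text", outputs="text").launch()
```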

Original EasyLM Reference

If you found EasyLM useful in your research or applications, please cite using the following BibTeX:

```bibtex
@software{geng2023easylm,
  author = {Geng, Xinyang},
  title = {EasyLM: A Simple And Scalable Training Framework for Large Language Models},
  month = March,
  year = 2023,
  url = {https://github.com/young-geng/EasyLM}
}
```

Credits

  • The LLaMA implementation is from JAX_llama
  • The JAX/Flax GPT-J and RoBERTa implementations are from transformers
  • Most of the JAX utilities are from mlxu
  • The codebase is heavily inspired by JAXSeq

Owner

  • Name: Junbum Lee
  • Login: Beomi
  • Kind: user
  • Location: Seoul, South Korea

AI/ML GDE @ml-gde. Korean AI/NLP Researcher and creator of multiple Korean PLMs. Focused on advancing Open LLMs.

GitHub Events

Total
  • Watch event: 3
  • Fork event: 1
Last Year
  • Watch event: 3
  • Fork event: 1

Committers

Last synced: about 1 year ago

All Time
  • Total Commits: 211
  • Total Committers: 12
  • Avg Commits per committer: 17.583
  • Development Distribution Score (DDS): 0.142
Past Year
  • Commits: 3
  • Committers: 1
  • Avg Commits per committer: 3.0
  • Development Distribution Score (DDS): 0.0
Top Committers (name · email · commits)
  • young-geng · y****y@g****m · 181
  • haoliu · l****9@g****m · 15
  • Shuangchi He · 3****t · 3
  • Junbum Lee · j****n@b****t · 3
  • Szymon Tworkowski · 4****n · 2
  • diedinyourthoughts · 1****s · 1
  • akhilkedia · 1****a · 1
  • ZYHowell · y****g@c****u · 1
  • Matthew Dangerfield · m****2@g****m · 1
  • Julien Salinas · a****l@j****m · 1
  • Gianluca Detommaso · d****a@g****m · 1
  • Charlie Snell · s****l@i****n · 1
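The DDS values above are consistent with this table if DDS is taken to be one minus the top committer's share of commits (an assumed formula, but the numbers support it):

```python
# All time: young-geng authored 181 of 211 commits.
print(round(1 - 181 / 211, 3))  # 0.142 -> matches the all-time DDS
# Past year: a single committer made all 3 commits.
print(round(1 - 3 / 3, 3))      # 0.0   -> matches the past-year DDS
```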

Issues and Pull Requests

Last synced: 10 months ago

All Time
  • Total issues: 1
  • Total pull requests: 0
  • Average time to close issues: 14 minutes
  • Average time to close pull requests: N/A
  • Total issue authors: 1
  • Total pull request authors: 0
  • Average comments per issue: 2.0
  • Average comments per pull request: 0
  • Merged pull requests: 0
  • Bot issues: 0
  • Bot pull requests: 0
Past Year
  • Issues: 0
  • Pull requests: 0
  • Average time to close issues: N/A
  • Average time to close pull requests: N/A
  • Issue authors: 0
  • Pull request authors: 0
  • Average comments per issue: 0
  • Average comments per pull request: 0
  • Merged pull requests: 0
  • Bot issues: 0
  • Bot pull requests: 0
Top Authors
Issue Authors
  • HeegyuKim (1)

Dependencies

EasyLM/models/gemma/serving_requirements.txt pypi
  • fastapi *
  • gradio <4
  • jax *
  • ml_dtypes *
  • pydantic *
  • uvicorn *
requirements.txt pypi
  • GitPython ==3.1.42
  • Jinja2 ==3.1.3
  • MarkupSafe ==2.1.5
  • PyYAML ==6.0.1
  • Pygments ==2.17.2
  • absl-py ==2.1.0
  • aiohttp ==3.9.3
  • aiosignal ==1.3.1
  • appdirs ==1.4.4
  • async-timeout ==4.0.3
  • attrs ==23.2.0
  • cachetools ==5.3.2
  • certifi ==2024.2.2
  • charset-normalizer ==3.3.2
  • chex ==0.1.85
  • click ==8.1.7
  • cloudpickle ==3.0.0
  • contextlib2 ==21.6.0
  • decorator ==5.1.1
  • docker-pycreds ==0.4.0
  • einops ==0.7.0
  • etils ==1.5.2
  • filelock ==3.13.1
  • flax ==0.8.1
  • frozenlist ==1.4.1
  • fsspec ==2024.2.0
  • gcsfs ==2024.2.0
  • gitdb ==4.0.11
  • google-api-core ==2.17.1
  • google-auth ==2.28.1
  • google-auth-oauthlib ==1.2.0
  • google-cloud-core ==2.4.1
  • google-cloud-storage ==2.14.0
  • google-crc32c ==1.5.0
  • google-resumable-media ==2.7.0
  • googleapis-common-protos ==1.62.0
  • huggingface-hub ==0.20.3
  • idna ==3.6
  • importlib-metadata ==7.0.1
  • importlib-resources ==6.1.1
  • jax ==0.4.24
  • jaxlib ==0.4.24
  • lxml ==5.0.0
  • markdown-it-py ==3.0.0
  • mdurl ==0.1.2
  • ml-collections ==0.1.1
  • ml-dtypes ==0.3.2
  • mlxu ==0.1.12
  • mpmath ==1.3.0
  • msgpack ==1.0.7
  • multidict ==6.0.5
  • nest-asyncio ==1.6.0
  • networkx ==3.2.1
  • numpy ==1.26.4
  • oauthlib ==3.2.2
  • olefile ==0.47
  • opt-einsum ==3.3.0
  • optax ==0.1.9
  • orbax-checkpoint ==0.5.3
  • packaging ==23.2
  • protobuf ==4.25.3
  • psutil ==5.9.8
  • pyasn1 ==0.5.1
  • pyasn1-modules ==0.3.0
  • pyhwp ==0.1b15
  • regex ==2023.12.25
  • requests ==2.31.0
  • requests-oauthlib ==1.3.1
  • rich ==13.7.0
  • rsa ==4.9
  • safetensors ==0.4.2
  • scipy ==1.12.0
  • sentencepiece ==0.2.0
  • sentry-sdk ==1.40.5
  • setproctitle ==1.3.3
  • six ==1.16.0
  • smmap ==5.0.1
  • sympy ==1.12
  • tensorstore ==0.1.53
  • tokenizers ==0.15.2
  • toolz ==0.12.1
  • torch ==2.2.0
  • tqdm ==4.66.2
  • transformers ==4.38.1
  • typing_extensions ==4.9.0
  • urllib3 ==2.2.1
  • wandb ==0.16.3
  • yarl ==1.9.4
  • zipp ==3.17.0