https://github.com/beomi/gemma-easylm
Train GEMMA on TPU/GPU! (Codebase for training Gemma-Ko Series)
Science Score: 10.0%
This score indicates how likely this project is to be science-related based on various indicators:
- ○ CITATION.cff file
- ○ codemeta.json file
- ○ .zenodo.json file
- ○ DOI references
- ○ Academic publication links
- ✓ Committers with academic emails: 1 of 12 committers (8.3%) from academic institutions
- ○ Institutional organization owner
- ○ JOSS paper metadata
- ○ Scientific vocabulary similarity: low similarity (11.8%) to scientific vocabulary
Keywords
Repository
Basic Info
- Host: GitHub
- Owner: Beomi
- License: apache-2.0
- Language: Python
- Default Branch: main
- Homepage: https://huggingface.co/beomi/gemma-ko-7b
- Size: 410 KB
Statistics
- Stars: 47
- Watchers: 2
- Forks: 10
- Open Issues: 0
- Releases: 0
Topics
Metadata Files
README.md
Gemma-EasyLM
This document outlines the integration of the Gemma model into the EasyLM framework, including instructions for training, converting the model format, and serving the model with Gradio.
Training: Integrating HF Flax Weights into EasyLM
Step 1: Consolidate Flax Weights from Hugging Face
You can skip this step by downloading https://huggingface.co/beomi/gemma-ko-7b/resolve/flax-init/flax_model.msgpack directly.
First, consolidate all the Flax model weights available at Hugging Face - Gemma 7B.
Use the following example code to accomplish this:
```python
from transformers import GemmaForCausalLM

model = GemmaForCausalLM.from_pretrained("google/gemma-7b", torch_dtype="auto")
model.save_pretrained("./flax-concatted", max_shard_size="99GB")
```
This script generates a flax-concatted/flax_model.msgpack file. We will utilize this .msgpack file during the training process.
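If you want to sanity-check the consolidated file before uploading it, Flax can restore the raw msgpack bytes directly. This is a small sketch, not part of the repo's scripts, assuming the output path from the step above:

```python
# Sketch: restore the consolidated msgpack and inspect its top-level parameter groups.
from flax.serialization import msgpack_restore

with open("./flax-concatted/flax_model.msgpack", "rb") as f:
    params = msgpack_restore(f.read())
print(list(params.keys()))  # e.g. the model's top-level module names
```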
Step 2: Upload the .msgpack File to Google Cloud Storage (GCS)
Execute the following command to upload the generated .msgpack file to your GCS repository:
```bash
gsutil cp ./flax-concatted/flax_model.msgpack gs://YOUR_GCS_REPO_NAME
```
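If you prefer doing the upload from Python, the same copy can be done with the google-cloud-storage client (a pinned dependency of this repo). A sketch, assuming application-default credentials and the placeholder bucket name above:

```python
# Sketch: upload the .msgpack with the google-cloud-storage client instead of gsutil.
from google.cloud import storage

client = storage.Client()
bucket = client.bucket("YOUR_GCS_REPO_NAME")  # placeholder bucket name from the step above
blob = bucket.blob("flax_model.msgpack")
blob.upload_from_filename("./flax-concatted/flax_model.msgpack")
```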
Step 3: Modify the train.sh Script
Adjust the paths for load_checkpoint, train_dataset.json_dataset.path, and logger.output_dir within the train.sh script to match your setup.
The provided example train.sh script assumes training will be conducted on a TPUv4-64 pod slice.
Step 4: Initiate Training
Execute the training script to start the training process:
```bash
./train.sh
```
Conversion: From EasyLM to Hugging Face Format
Step 1: Retrieve the streaming_train_state File
Download the streaming_train_state file from your GCS repository using the following command:
```bash
gsutil cp gs://YOUR_GCS_REPO_NAME/.../streaming_train_state_80000 .
```
Note: The file name will either be streaming_train_state or streaming_train_state_STEPNO.
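If you are unsure which checkpoint files exist in the bucket, you can list the candidates with gcsfs (also a pinned dependency). A sketch, using the same placeholder bucket name:

```python
# Sketch: find streaming_train_state checkpoints in the bucket using gcsfs.
import gcsfs

fs = gcsfs.GCSFileSystem()
for path in fs.glob("YOUR_GCS_REPO_NAME/**/streaming_train_state*"):
    print(path)
```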
Step 2: Update the .stream File Path
In the convert_easylm_stream_to_hf_safetensors.py file, modify the path to the .stream file accordingly:
```python
# Modify this line
_, param = StreamingCheckpointer.load_trainstate_checkpoint(load_from='trainstate_params::/home/latheledusjp/streaming_train_state_80000')
```
Step 3: Execute the Conversion Script
Run the conversion script to convert the EasyLM model format to Hugging Face's format:
```bash
python convert_easylm_stream_to_hf_safetensors.py
```
Step 4: Verify the Output Files
Check the generated output files in the ./gemma-ko-8.5b-dev directory.
The Flax version of the weights can be found in the ./flax-gemma-ko-8b folder.
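As a quick sanity check (not part of the repo's scripts), the converted weights should load with transformers. A minimal sketch; since the conversion only touches the weights, the tokenizer is loaded from the base model here:

```python
# Sketch: verify the converted Hugging Face weights load and generate.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("./gemma-ko-8.5b-dev", torch_dtype=torch.bfloat16)
tokenizer = AutoTokenizer.from_pretrained("google/gemma-7b")  # tokenizer is unchanged by the conversion
inputs = tokenizer("Hello", return_tensors="pt")
print(tokenizer.decode(model.generate(**inputs, max_new_tokens=16)[0]))
```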
Serving the Model with Gradio
To serve the model using Gradio, follow these steps:
```bash
cd EasyLM/models/gemma
pip install -r serving_requirements.txt
./serve_test.sh
```
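The repo's serve_test.sh wires the loaded model into its own Gradio app; for orientation, the core pattern is roughly the following. This is a sketch, not the repo's code, and the generate function is a hypothetical placeholder for a call into the served model:

```python
# Sketch of the Gradio serving pattern (not the repo's serve script).
import gradio as gr

def generate(prompt: str) -> str:
    # Placeholder: call the loaded Gemma model here.
    return "model output for: " + prompt

demo = gr.Interface(fn=generate, inputs="text", outputs="text")
demo.launch(server_name="0.0.0.0")  # listen on all interfaces
```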
Original EasyLM Reference
If you found EasyLM useful in your research or applications, please cite using the following BibTeX:
```bibtex
@software{geng2023easylm,
  author = {Geng, Xinyang},
  title = {EasyLM: A Simple And Scalable Training Framework for Large Language Models},
  month = March,
  year = 2023,
  url = {https://github.com/young-geng/EasyLM}
}
```
Credits
- The LLaMA implementation is from JAX_llama
- The JAX/Flax GPT-J and RoBERTa implementations are from transformers
- Most of the JAX utilities are from mlxu
- The codebase is heavily inspired by JAXSeq
Owner
- Name: Junbum Lee
- Login: Beomi
- Kind: user
- Location: Seoul, South Korea
- Website: https://junbuml.ee
- Twitter: __Beomi__
- Repositories: 110
- Profile: https://github.com/Beomi
AI/ML GDE @ml-gde. Korean AI/NLP Researcher and creator of multiple Korean PLMs. Focused on advancing Open LLMs.
GitHub Events
Total
- Watch event: 3
- Fork event: 1
Last Year
- Watch event: 3
- Fork event: 1
Committers
Last synced: about 1 year ago
Top Committers
| Name | Email | Commits |
|---|---|---|
| young-geng | y****y@g****m | 181 |
| haoliu | l****9@g****m | 15 |
| Shuangchi He | 3****t | 3 |
| Junbum Lee | j****n@b****t | 3 |
| Szymon Tworkowski | 4****n | 2 |
| diedinyourthoughts | 1****s | 1 |
| akhilkedia | 1****a | 1 |
| ZYHowell | y****g@c****u | 1 |
| Matthew Dangerfield | m****2@g****m | 1 |
| Julien Salinas | a****l@j****m | 1 |
| Gianluca Detommaso | d****a@g****m | 1 |
| Charlie Snell | s****l@i****n | 1 |
Committer Domains (Top 20 + Academic)
Issues and Pull Requests
Last synced: 10 months ago
All Time
- Total issues: 1
- Total pull requests: 0
- Average time to close issues: 14 minutes
- Average time to close pull requests: N/A
- Total issue authors: 1
- Total pull request authors: 0
- Average comments per issue: 2.0
- Average comments per pull request: 0
- Merged pull requests: 0
- Bot issues: 0
- Bot pull requests: 0
Past Year
- Issues: 0
- Pull requests: 0
- Average time to close issues: N/A
- Average time to close pull requests: N/A
- Issue authors: 0
- Pull request authors: 0
- Average comments per issue: 0
- Average comments per pull request: 0
- Merged pull requests: 0
- Bot issues: 0
- Bot pull requests: 0
Top Authors
Issue Authors
- HeegyuKim (1)
Pull Request Authors
Top Labels
Issue Labels
Pull Request Labels
Dependencies
- fastapi *
- gradio <4
- jax *
- ml_dtypes *
- pydantic *
- uvicorn *
- GitPython ==3.1.42
- Jinja2 ==3.1.3
- MarkupSafe ==2.1.5
- PyYAML ==6.0.1
- Pygments ==2.17.2
- absl-py ==2.1.0
- aiohttp ==3.9.3
- aiosignal ==1.3.1
- appdirs ==1.4.4
- async-timeout ==4.0.3
- attrs ==23.2.0
- cachetools ==5.3.2
- certifi ==2024.2.2
- charset-normalizer ==3.3.2
- chex ==0.1.85
- click ==8.1.7
- cloudpickle ==3.0.0
- contextlib2 ==21.6.0
- decorator ==5.1.1
- docker-pycreds ==0.4.0
- einops ==0.7.0
- etils ==1.5.2
- filelock ==3.13.1
- flax ==0.8.1
- frozenlist ==1.4.1
- fsspec ==2024.2.0
- gcsfs ==2024.2.0
- gitdb ==4.0.11
- google-api-core ==2.17.1
- google-auth ==2.28.1
- google-auth-oauthlib ==1.2.0
- google-cloud-core ==2.4.1
- google-cloud-storage ==2.14.0
- google-crc32c ==1.5.0
- google-resumable-media ==2.7.0
- googleapis-common-protos ==1.62.0
- huggingface-hub ==0.20.3
- idna ==3.6
- importlib-metadata ==7.0.1
- importlib-resources ==6.1.1
- jax ==0.4.24
- jaxlib ==0.4.24
- lxml ==5.0.0
- markdown-it-py ==3.0.0
- mdurl ==0.1.2
- ml-collections ==0.1.1
- ml-dtypes ==0.3.2
- mlxu ==0.1.12
- mpmath ==1.3.0
- msgpack ==1.0.7
- multidict ==6.0.5
- nest-asyncio ==1.6.0
- networkx ==3.2.1
- numpy ==1.26.4
- oauthlib ==3.2.2
- olefile ==0.47
- opt-einsum ==3.3.0
- optax ==0.1.9
- orbax-checkpoint ==0.5.3
- packaging ==23.2
- protobuf ==4.25.3
- psutil ==5.9.8
- pyasn1 ==0.5.1
- pyasn1-modules ==0.3.0
- pyhwp ==0.1b15
- regex ==2023.12.25
- requests ==2.31.0
- requests-oauthlib ==1.3.1
- rich ==13.7.0
- rsa ==4.9
- safetensors ==0.4.2
- scipy ==1.12.0
- sentencepiece ==0.2.0
- sentry-sdk ==1.40.5
- setproctitle ==1.3.3
- six ==1.16.0
- smmap ==5.0.1
- sympy ==1.12
- tensorstore ==0.1.53
- tokenizers ==0.15.2
- toolz ==0.12.1
- torch ==2.2.0
- tqdm ==4.66.2
- transformers ==4.38.1
- typing_extensions ==4.9.0
- urllib3 ==2.2.1
- wandb ==0.16.3
- yarl ==1.9.4
- zipp ==3.17.0