https://github.com/aieng-lab/transformer-math-pretraining
Framework to pretrain mathematically aware transformer models using MAMUT datasets
Repository
Basic Info
Statistics
- Stars: 1
- Watchers: 0
- Forks: 0
- Open Issues: 0
- Releases: 0
Metadata Files
README.md
Jonathan Drechsel, Katja Noack, Anja Reusch, Steffen Herbold
Transformer Math Pretraining
Framework to pretrain mathematically aware transformer models, first introduced in the paper "MAMUT: A Novel Framework for Modifying Mathematical Formulas for the Generation of Specialized Datasets for Language Model Training".
Installation
1. Clone the repository:
```bash
git clone https://github.com/aieng-lab/transformer-math-pretraining.git
```
2. Create a Python Environment
Option 1: Using Conda
```bash
conda env create -f environment.yml
conda activate transformer_pretraining
```
Option 2: Using pip
```bash
python3 -m venv env_tp
source env_tp/bin/activate
pip install -r requirements.txt
```
3. Download the datasets
```bash
python src/execution/data/download_data.py
```
4. [Optional] Create a mathematical tokenizer
```bash
python src/execution/create_tokenizer.py
```
This script will first analyze the most common math tokens contained in MT, and then create a mathematical tokenizer for bert-base-cased with 300 additional math tokens (the most frequent 300 math tokens not found in the original tokenizer). Adjust the script to use, e.g., a different base model.
The script not only creates the mathematical tokenizer but also saves a model with randomly initialized weights for the added tokens (models/tokenized).
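The token selection described above can be illustrated with a minimal sketch. It assumes that "math tokens" are LaTeX commands matched by a simple regular expression and that the base vocabulary is available as a plain set of strings; the function name `select_new_math_tokens`, the pattern, and the toy data are hypothetical and not taken from `create_tokenizer.py`.

```python
import re
from collections import Counter

# Assumption: a "math token" is a LaTeX command such as \frac or \alpha.
LATEX_COMMAND = re.compile(r"\\[A-Za-z]+")


def select_new_math_tokens(texts, base_vocab, k=300):
    """Return the k most frequent LaTeX commands that are missing from base_vocab."""
    counts = Counter()
    for text in texts:
        counts.update(LATEX_COMMAND.findall(text))
    # most_common() sorts by descending frequency; skip tokens already known.
    missing = [tok for tok, _ in counts.most_common() if tok not in base_vocab]
    return missing[:k]


# Toy corpus standing in for the MT dataset.
texts = [r"\frac{a}{b} + \alpha", r"\alpha \cdot \beta", r"\frac{1}{2} \alpha"]
base_vocab = {r"\beta"}  # pretend the base tokenizer already knows \beta
new_tokens = select_new_math_tokens(texts, base_vocab, k=2)
print(new_tokens)
```

With Hugging Face Transformers, such a list would typically be registered via `tokenizer.add_tokens(new_tokens)` followed by `model.resize_token_embeddings(len(tokenizer))`; whether the repository's script does exactly this is not confirmed here.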
Pretraining Models
This repository can be used to pretrain mathematically aware transformer models based on the MAMUT-enhanced datasets: Mathematical Formulas (MF), Mathematical Texts (MT), Named Math Formulas (NMF), and Math Formula Retrieval (MFR).
To recreate one of the pretrained models in the MAMUT paper, you can refer to the scripts (e.g., scripts/BERT_MF_MT.sh). To pretrain all models from the MAMUT paper, you can use scripts/mamut.sh.
Please note that this pretraining should be run on a machine with 8 A100 GPUs with at least 40GB of GPU memory each. The default pretraining takes about 12 hours per task.
Note: You need to adjust the `base_dir` in each script!
Pretraining Details
The src/execution/training/execute_pretraining script is designed to automate the pretraining of transformer-based models (e.g., BERT) on custom objectives and datasets. It supports both single and multi-objective pretraining (either one-by-one or in parallel/mixed) and offers full control over training parameters via command-line arguments.
The script handles:
- Loading datasets: Expects a local DatasetDict.
- Dataset preprocessing: Adjusts dataset to fit batch sizes and objective-specific needs (e.g., preparing the epoch-wise changing false examples for NMF and MFR)
- Objective management: Primarily supports the four objectives used for MAMUT (MF, MT, NMF, MFR). Other objectives (MLM, NSP, SOP, and more) are supported as well, but require additional information such as input files; this documentation focuses on the math objectives.
- Training: Runs training using an `Executor` that wraps around Hugging Face Transformers and accelerates training across devices.
- Saving: After training, the final model and tokenizer are saved.
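The difference between one-by-one and mixed multi-objective training can be sketched as a per-step schedule. This is a simplified single-device illustration, not the repository's `Executor`: the function `build_schedule` and its parameter `steps_per_objective` are hypothetical, and the sketch ignores the multi-GPU case in which several tasks contribute to the same optimization step.

```python
from itertools import cycle, islice


def build_schedule(objectives, steps_per_objective, one_by_one):
    """Return the objective that is trained at each optimization step."""
    if one_by_one:
        # Finish every step of one objective before starting the next.
        return [obj for obj in objectives for _ in range(steps_per_objective)]
    # Mixed mode (assumed here): switch the objective after every batch.
    total_steps = steps_per_objective * len(objectives)
    return list(islice(cycle(objectives), total_steps))


sequential = build_schedule(["MF", "MT"], steps_per_objective=2, one_by_one=True)
mixed = build_schedule(["MF", "MT"], steps_per_objective=2, one_by_one=False)
print(sequential)
print(mixed)
```

In one-by-one mode the order of the objectives determines which task the model sees last, which is why the order matters there.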
Parameters (via CLI)
Parameter | Description
---------|------------
--pretraining_obj | Pretraining objective(s) to be used. Can be a single objective or a list of objectives (e.g., MF_MT). If one_by_one is True, the order of the list matters.
--base_bert | Hugging Face identifier or local path of input model
--one_by_one | For multiple pretraining tasks, whether to train them one by one (i.e., finish the first task completely before starting the second) or in a mixed way (switching the task after each batch; with multiple GPUs, several tasks contribute to each optimization step, e.g., one task on four GPUs and the second task on the other four)
--opt_steps | Number of optimization steps for pretraining
--interval_len | Number of optimization steps between evaluation
--batch | Batch size for pretraining per GPU
--lr | Learning rate for pretraining
--warmup | Number of warmup steps for learning rate scheduler
--num_gpus | Number of GPUs used for pretraining
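To make the parameter table concrete, the following sketch mirrors the flags with `argparse`. Only the flag names come from the table; the types, the default values, the treatment of `--one_by_one` as a boolean switch, and the splitting of `MF_MT` on underscores are assumptions and may differ from the actual script.

```python
import argparse


def build_parser():
    """Hypothetical parser mirroring the documented flags (defaults are made up)."""
    parser = argparse.ArgumentParser(description="Sketch of the pretraining CLI")
    parser.add_argument("--pretraining_obj", type=str, required=True)
    parser.add_argument("--base_bert", type=str, default="bert-base-cased")
    parser.add_argument("--one_by_one", action="store_true")
    parser.add_argument("--opt_steps", type=int, default=1000)
    parser.add_argument("--interval_len", type=int, default=100)
    parser.add_argument("--batch", type=int, default=16)
    parser.add_argument("--lr", type=float, default=1e-4)
    parser.add_argument("--warmup", type=int, default=0)
    parser.add_argument("--num_gpus", type=int, default=1)
    return parser


args = build_parser().parse_args(
    ["--pretraining_obj", "MF_MT", "--one_by_one", "--batch", "16", "--num_gpus", "8"]
)
# Assumption: paper-style objective names are joined with underscores.
objectives = args.pretraining_obj.split("_")
# --batch is per GPU, so the batch per optimization step scales with --num_gpus
# (assuming no gradient accumulation).
effective_batch = args.batch * args.num_gpus
print(objectives, effective_batch)
```

Because `--batch` is a per-GPU value, changing `--num_gpus` changes the effective batch size unless `--batch` is adjusted accordingly.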
Some Implementation Details
- Specific implementation of the mathematical objectives can be found in `src/pretraining_methods/mlm_like/MLM/prepare.py` and `src/pretraining_methods/nsp_like/NSP/prepare.py`
  - The mathematical words used for whole word masking for MT can be found in `src/pretraining_methods/mlm_like/MLM/prepare#math_words`
- The `src/data_sets/PreTrainingDataset.py` contains some advanced logic to realize the mixed multi-objective training and changing false examples for NMF and MFR every epoch. To support multi-GPUs, an advanced deterministically randomized mapping is applied only based on the index provided.
- The pretraining objectives have different names in the code than in the paper:
  - `MF` = `MFM` = `MLM_MATH`
  - `MT` = `MTM` = `MLM_MATH_TEXT`
  - `NMF` = `NFIR` (Named-Formula-IR)
  - `MFR` = `FFIR` (Formula-Formula-IR)
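The idea of a deterministically randomized, index-based mapping for false examples can be illustrated as follows. This is one possible scheme under stated assumptions, not the logic in `PreTrainingDataset.py`: the function `false_example_index` and its `seed` parameter are hypothetical. The point is that every worker derives the same false example from `(seed, epoch, index)` alone, without shared state.

```python
import random


def false_example_index(index, epoch, dataset_size, seed=0):
    """Pick a false partner for `index` that depends only on (seed, epoch, index).

    Requires dataset_size >= 2 so that a different example exists.
    """
    # A string seed is hashed deterministically, independent of process state.
    rng = random.Random(f"{seed}-{epoch}-{index}")
    # An offset in [1, dataset_size - 1] guarantees the result differs from `index`.
    offset = rng.randrange(1, dataset_size)
    return (index + offset) % dataset_size


# The same call yields the same result on any process or GPU.
first = false_example_index(index=3, epoch=0, dataset_size=10)
second = false_example_index(index=3, epoch=0, dataset_size=10)
print(first == second)
```

Since the epoch is part of the seed, the false example for a given index can change from epoch to epoch while remaining reproducible across devices.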
CITATION
If you use this framework, please cite the following paper:
```bibtex
@article{drechsel2025mamut,
  title={{MAMUT}: A Novel Framework for Modifying Mathematical Formulas for the Generation of Specialized Datasets for Language Model Training},
  author={Jonathan Drechsel and Anja Reusch and Steffen Herbold},
  journal={Transactions on Machine Learning Research},
  issn={2835-8856},
  year={2025},
  url={https://openreview.net/forum?id=khODmRpQEx}
}
```
Authors
- Katja Noack (original implementation, @katja98)
- Jonathan Drechsel (math adaption, @jdrechsel13)
Owner
- Name: aieng-lab
- Login: aieng-lab
- Kind: organization
- Website: https://www.fim.uni-passau.de/ai-engineering
- Repositories: 1
- Profile: https://github.com/aieng-lab
GitHub organization of the Chair for AI Engineering of the University of Passau
GitHub Events
Total
- Watch event: 1
- Public event: 1
- Push event: 9
Last Year
- Watch event: 1
- Public event: 1
- Push event: 9
Dependencies
- accelerate ==0.17.1
- aiohappyeyeballs ==2.6.1
- aiohttp ==3.11.16
- aiosignal ==1.3.2
- apache-beam ==2.43.0
- async-timeout ==5.0.1
- asyncio ==3.4.3
- attrs ==25.3.0
- blis ==0.7.11
- catalogue ==2.0.10
- certifi ==2025.1.31
- charset-normalizer ==3.4.1
- click ==8.1.8
- cloudpickle ==2.2.1
- comet-ml ==3.33.3
- confection ==0.1.5
- configobj ==5.0.9
- contourpy ==1.2.1
- crcmod ==1.7
- cycler ==0.12.1
- cymem ==2.0.11
- datasets ==3.5.0
- dill ==0.3.8
- disutils ==1.4.32.post2
- docopt ==0.6.2
- dulwich ==0.22.8
- everett ==3.1.0
- fastavro ==1.10.0
- fasteners ==0.19
- filelock ==3.18.0
- fonttools ==4.57.0
- frozenlist ==1.5.0
- fsspec ==2024.12.0
- gensim ==4.3.1
- grpcio ==1.71.0
- hdfs ==2.7.3
- httplib2 ==0.20.4
- huggingface-hub ==0.30.2
- humanize ==4.7.0
- idna ==3.10
- jinja2 ==3.1.6
- joblib ==1.4.2
- jsonschema ==4.23.0
- jsonschema-specifications ==2024.10.1
- kiwisolver ==1.4.8
- langcodes ==3.5.0
- language-data ==1.3.0
- marisa-trie ==1.2.1
- markdown-it-py ==3.0.0
- markupsafe ==3.0.2
- matplotlib ==3.6.3
- mdurl ==0.1.2
- mpmath ==1.3.0
- multidict ==6.4.3
- multiprocess ==0.70.14
- murmurhash ==1.0.12
- networkx ==3.4.2
- nltk ==3.8.1
- numpy ==1.22.4
- nvidia-cublas-cu12 ==12.4.5.8
- nvidia-cuda-cupti-cu12 ==12.4.127
- nvidia-cuda-nvrtc-cu12 ==12.4.127
- nvidia-cuda-runtime-cu12 ==12.4.127
- nvidia-cudnn-cu12 ==9.1.0.70
- nvidia-cufft-cu12 ==11.2.1.3
- nvidia-curand-cu12 ==10.3.5.147
- nvidia-cusolver-cu12 ==11.6.1.9
- nvidia-cusparse-cu12 ==12.3.1.170
- nvidia-cusparselt-cu12 ==0.6.2
- nvidia-nccl-cu12 ==2.21.5
- nvidia-nvjitlink-cu12 ==12.4.127
- nvidia-nvtx-cu12 ==12.4.127
- objsize ==0.5.2
- orjson ==3.10.16
- packaging ==24.2
- pandas ==2.2.3
- pathlib-abc ==0.1.1
- pathy ==0.11.0
- pause ==0.3
- pillow ==11.1.0
- preshed ==3.0.9
- propcache ==0.3.1
- proto-plus ==1.26.1
- protobuf ==3.20.0
- psutil ==5.9.4
- pyarrow ==19.0.1
- pydantic ==1.10.21
- pydot ==1.4.2
- pygments ==2.19.1
- pymongo ==3.13.0
- pynvml ==11.4.1
- pyparsing ==3.2.3
- python-box ==6.1.0
- python-dateutil ==2.9.0.post0
- pytz ==2025.2
- pyyaml ==6.0.2
- referencing ==0.36.2
- regex ==2024.11.6
- requests ==2.32.3
- requests-toolbelt ==1.0.0
- responses ==0.18.0
- rich ==14.0.0
- rpds-py ==0.24.0
- safetensors ==0.5.3
- scikit-learn ==1.2.1
- scipy ==1.12.0
- semantic-version ==2.10.0
- sentencepiece ==0.2.0
- sentry-sdk ==2.25.1
- simplejson ==3.20.1
- six ==1.16.0
- smart-open ==6.4.0
- spacy ==3.4.4
- spacy-legacy ==3.0.12
- spacy-loggers ==1.0.5
- srsly ==2.5.1
- sympy ==1.13.1
- thinc ==8.1.12
- threadpoolctl ==3.6.0
- tokenizers ==0.13.3
- torch ==1.13.1
- tqdm ==4.67.1
- transformers ==4.25.1
- triton ==3.2.0
- typer ==0.7.0
- typing-extensions ==4.13.2
- tzdata ==2025.2
- urllib3 ==1.26.20
- wasabi ==0.10.1
- websocket-client ==1.3.3
- wrapt ==1.17.2
- wurlitzer ==3.1.1
- xxhash ==3.5.0
- yarl ==1.19.0
- zstandard ==0.23.0