https://github.com/aieng-lab/transformer-math-pretraining

Framework to pretrain mathematically aware transformer models using MAMUT datasets


Science Score: 36.0%

This score indicates how likely this project is to be science-related based on various indicators:

  • CITATION.cff file
  • codemeta.json file
    Found codemeta.json file
  • .zenodo.json file
    Found .zenodo.json file
  • DOI references
  • Academic publication links
    Links to: arxiv.org
  • Academic email domains
  • Institutional organization owner
  • JOSS paper metadata
  • Scientific vocabulary similarity
    Low similarity (13.1%) to scientific vocabulary
Last synced: 10 months ago

Repository


Basic Info
  • Host: GitHub
  • Owner: aieng-lab
  • License: apache-2.0
  • Language: Python
  • Default Branch: main
  • Homepage:
  • Size: 411 KB
Statistics
  • Stars: 1
  • Watchers: 0
  • Forks: 0
  • Open Issues: 0
  • Releases: 0
Created about 1 year ago · Last pushed about 1 year ago
Metadata Files
Readme License

README.md

Jonathan Drechsel, Katja Noack, Anja Reusch, Steffen Herbold


Transformer Math Pretraining

Framework to pretrain mathematically aware transformer models, first introduced in the paper MAMUT: A Novel Framework for Modifying Mathematical Formulas for the Generation of Specialized Datasets for Language Model Training.

Installation

1. Clone the repository:

    git clone https://github.com/aieng-lab/transformer-math-pretraining.git

2. Create a Python environment:

Option 1: Using Conda

    conda env create -f environment.yml
    conda activate transformer_pretraining

Option 2: Using pip

    python3 -m venv env_tp
    source env_tp/bin/activate
    pip install -r requirements.txt

3. Download the datasets:

    python src/execution/data/download_data.py

4. [Optional] Create a mathematical tokenizer:

    python src/execution/create_tokenizer.py

This script first identifies the most frequent math tokens in MT and then creates a mathematical tokenizer for bert-base-cased with 300 additional math tokens (the 300 most frequent math tokens not already contained in the original tokenizer). The script can be adjusted, e.g., to use a different base model.

Besides the mathematical tokenizer, the script also saves a model with randomly initialized weights for the added tokens (models/tokenized).
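The token selection described above (count math tokens in MT, keep the most frequent ones the base vocabulary lacks) can be sketched in plain Python. This is a minimal illustration under assumptions, not the code of src/execution/create_tokenizer.py: it assumes the MT formulas are already split into math tokens, and the function name select_new_math_tokens is hypothetical.

```python
from collections import Counter
from typing import Iterable, List, Set


def select_new_math_tokens(
    tokenized_formulas: Iterable[List[str]],
    vocab: Set[str],
    k: int = 300,
) -> List[str]:
    """Return the k most frequent tokens that are missing from `vocab`.

    Hypothetical sketch: `tokenized_formulas` is assumed to be an iterable of
    pre-split math token lists (e.g., LaTeX commands such as "\\frac"); how the
    repository splits formulas into tokens is not reproduced here.
    """
    counts = Counter(tok for formula in tokenized_formulas for tok in formula)
    # most_common() sorts by descending count; equal counts keep first-seen order.
    missing = [tok for tok, _ in counts.most_common() if tok not in vocab]
    return missing[:k]


if __name__ == "__main__":
    formulas = [
        ["\\frac", "{", "a", "}", "{", "b", "}"],
        ["\\sum", "_", "{", "i", "}", "\\frac", "{", "1", "}", "{", "i", "}"],
        ["\\int", "f", "(", "x", ")", "\\mathrm", "{", "d", "}", "x"],
    ]
    base_vocab = {"{", "}", "a", "b", "i", "1", "f", "(", ")", "x", "d", "_"}
    print(select_new_math_tokens(formulas, base_vocab, k=2))
```

With Hugging Face Transformers, selected tokens are typically registered via tokenizer.add_tokens(...) followed by model.resize_token_embeddings(len(tokenizer)), which appends newly initialized embedding rows for the added tokens; whether the repository's script uses exactly these calls is not stated here.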

Pretraining Models

This repository can be used to pretrain mathematically aware transformer models on the MAMUT-enhanced datasets: Mathematical Formulas (MF), Mathematical Texts (MT), Named Math Formulas (NMF), and Math Formula Retrieval (MFR). To recreate one of the pretrained models from the MAMUT paper, refer to the corresponding script (e.g., scripts/BERT_MF_MT.sh). To pretrain all models from the MAMUT paper, use scripts/mamut.sh. Note that pretraining is intended to run on a machine with 8 A100 GPUs with at least 40 GB of GPU memory each. With the default settings, pretraining takes about 12 hours per task.

Note: You need to adjust the base_dir in each script!

Pretraining Details

The src/execution/training/execute_pretraining script automates the pretraining of transformer-based models (e.g., BERT) on custom objectives and datasets. It supports both single- and multi-objective pretraining (either one objective after another or mixed in parallel) and offers full control over the training parameters via command-line arguments.
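To make the difference between the two multi-objective modes concrete, the following sketch computes which objective each GPU would train on at a given optimization step. It is a hypothetical illustration based on the description of --one_by_one in the parameter table, not the repository's scheduler; in particular, splitting opt_steps evenly across tasks in one-by-one mode and rotating the GPU groups after every batch in mixed mode are assumptions, and objectives_for_step is an invented name.

```python
from typing import List


def objectives_for_step(
    objectives: List[str],
    step: int,
    num_gpus: int,
    opt_steps: int,
    one_by_one: bool,
) -> List[str]:
    """Return the objective each GPU trains on at optimization `step`.

    Hypothetical illustration of the two multi-objective modes; not the
    repository's scheduling code.
    """
    n = len(objectives)
    if one_by_one:
        # Assumption: the step budget is split evenly and tasks run in list order.
        steps_per_task = max(1, opt_steps // n)
        current = objectives[min(step // steps_per_task, n - 1)]
        return [current] * num_gpus
    # Mixed mode: GPUs are split into n equal groups, and every GPU switches
    # to the next objective after each batch (step).
    return [objectives[(gpu * n // num_gpus + step) % n] for gpu in range(num_gpus)]


if __name__ == "__main__":
    for step in range(2):
        print(step, objectives_for_step(["MF", "MT"], step, num_gpus=8,
                                        opt_steps=1000, one_by_one=False))
```

In mixed mode with 8 GPUs and the objectives MF and MT, this places MF on GPUs 0 to 3 and MT on GPUs 4 to 7 at step 0 and swaps them at step 1, so both tasks contribute to every optimization step; with a single GPU it simply alternates between the tasks.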

The script handles:

  • Loading datasets: Expects a local DatasetDict.
  • Dataset preprocessing: Adjusts the dataset to the batch size and to objective-specific needs (e.g., preparing the false examples for NMF and MFR, which change every epoch).
  • Objective Management: Primarily supports the four objectives used for MAMUT (MF, MT, NMF, MFR). Other objectives (MLM, NSP, SOP, and more) are supported as well, but require additional information such as input files; this documentation focuses on the math objectives.
  • Training: Runs training with an Executor that wraps Hugging Face Transformers and accelerates training across devices.
  • Saving: After training, the final model and tokenizer are saved.

Parameters (via CLI)

Parameter | Description
---------|------------
--pretraining_obj | Pretraining objective(s) to use. Can be a single objective or a list of objectives (e.g., MF_MT). If --one_by_one is True, the order of the list matters.
--base_bert | Hugging Face identifier or local path of the input model
--one_by_one | With multiple pretraining objectives, whether to train them one by one (i.e., finish the first task completely before training the second one) or in a mixed way (changing the task after each batch; with multiple GPUs, several tasks contribute to each optimization step, e.g., one task on four GPUs and the second task on the other four GPUs)
--opt_steps | Number of optimization steps for pretraining
--interval_len | Number of optimization steps between evaluations
--batch | Batch size per GPU for pretraining
--lr | Learning rate for pretraining
--warmup | Number of warmup steps for the learning rate scheduler
--num_gpus | Number of GPUs used for pretraining
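The parameter table can be mirrored with Python's argparse to see how a full invocation fits together. This sketch is not the repository's actual parser: the argument types, the hypothetical str_to_bool helper for --one_by_one, the absence of defaults, and all example values are assumptions. It also derives the number of examples per optimization step, assuming --batch is per GPU as stated and that no gradient accumulation is used.

```python
import argparse
from typing import List


def str_to_bool(value: str) -> bool:
    """Parse "True"/"False"-style strings (assumed CLI convention)."""
    lowered = value.strip().lower()
    if lowered in {"true", "1", "yes"}:
        return True
    if lowered in {"false", "0", "no"}:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean, got {value!r}")


def build_parser() -> argparse.ArgumentParser:
    # Flag names follow the parameter table; types are assumptions.
    parser = argparse.ArgumentParser(description="Sketch of the pretraining CLI")
    parser.add_argument("--pretraining_obj", type=str, help="objective(s), e.g. MF_MT")
    parser.add_argument("--base_bert", type=str, help="HF identifier or local path")
    parser.add_argument("--one_by_one", type=str_to_bool, help="train tasks sequentially")
    parser.add_argument("--opt_steps", type=int, help="number of optimization steps")
    parser.add_argument("--interval_len", type=int, help="steps between evaluations")
    parser.add_argument("--batch", type=int, help="batch size per GPU")
    parser.add_argument("--lr", type=float, help="learning rate")
    parser.add_argument("--warmup", type=int, help="warmup steps for the LR scheduler")
    parser.add_argument("--num_gpus", type=int, help="number of GPUs")
    return parser


def examples_per_step(args: argparse.Namespace) -> int:
    # Per-GPU batch times GPU count, assuming no gradient accumulation.
    return args.batch * args.num_gpus


if __name__ == "__main__":
    example_argv: List[str] = [
        "--pretraining_obj", "MF_MT",
        "--base_bert", "bert-base-cased",
        "--one_by_one", "False",
        "--batch", "16",
        "--lr", "5e-5",
        "--num_gpus", "8",
    ]
    args = build_parser().parse_args(example_argv)
    print(args.one_by_one, examples_per_step(args))
```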

Some Implementation Details

  • The mathematical objectives are implemented in src/pretraining_methods/mlm_like/MLM/prepare.py and src/pretraining_methods/nsp_like/NSP/prepare.py
    • The math words used for whole-word masking in MT are listed in src/pretraining_methods/mlm_like/MLM/prepare#math_words
  • src/data_sets/PreTrainingDataset.py contains the logic for mixed multi-objective training and for changing the false examples for NMF and MFR every epoch. To support multiple GPUs, it applies a deterministic pseudo-random mapping that depends only on the provided index.
  • The pretraining objectives have different names in the code than in the paper:
    • MF = MFM = MLM_MATH
    • MT = MTM = MLM_MATH_TEXT
    • NMF = NFIR (Named-Formula-IR)
    • MFR = FFIR (Formula-Formula-IR)
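The deterministic, index-based mapping mentioned above can be illustrated with the following sketch. It assumes that a false example for NMF or MFR is obtained by pairing an entry with a different entry of the same dataset; whether the repository does this or selects among pre-generated falsified formulas is not stated here, and false_example_index and its seed parameter are hypothetical. Because the choice depends only on (seed, epoch, index), every GPU process computes the same pairing without communication, and each new epoch yields a fresh pairing.

```python
import random


def false_example_index(index: int, epoch: int, num_examples: int, seed: int = 42) -> int:
    """Deterministically pick a different example to serve as the false partner.

    Hypothetical sketch: the result depends only on (seed, epoch, index), so all
    processes agree on it, and it changes from epoch to epoch.
    """
    if num_examples < 2:
        raise ValueError("need at least two examples to pick a false partner")
    # Seeding with a string is reproducible across runs (it is hashed with
    # SHA-512), unlike Python's salted built-in hash() for strings.
    rng = random.Random(f"{seed}-{epoch}-{index}")
    # Draw uniformly from the other num_examples - 1 indices, skipping `index`.
    other = rng.randrange(num_examples - 1)
    return other + 1 if other >= index else other


if __name__ == "__main__":
    for epoch in range(3):
        print(epoch, false_example_index(index=5, epoch=epoch, num_examples=100))
```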
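Whole-word masking over math words, as used for MT, can be sketched as follows. This is a simplified, hypothetical illustration rather than the code in prepare.py: it assumes WordPiece-style "##" continuation tokens, masks only words contained in a given math-word set, and omits BERT's usual 80/10/10 replacement scheme ([MASK], random token, unchanged). What it does show is that a selected word has all of its subword tokens masked together.

```python
import random
from typing import List, Optional, Sequence, Set, Tuple

MASK_TOKEN = "[MASK]"


def mask_math_words(
    words: Sequence[List[str]],
    math_words: Set[str],
    prob: float,
    rng: random.Random,
) -> Tuple[List[str], List[Optional[str]]]:
    """Mask every subword token of randomly selected math words.

    `words` holds the subword tokens of each word, e.g. ["deriv", "##ative"].
    Returns the masked token sequence and, per position, the original token
    as the prediction target (None where nothing was masked).
    """
    tokens: List[str] = []
    labels: List[Optional[str]] = []
    for pieces in words:
        # Rebuild the surface word by stripping WordPiece "##" prefixes.
        word = "".join(p[2:] if p.startswith("##") else p for p in pieces)
        selected = word.lower() in math_words and rng.random() < prob
        for piece in pieces:
            tokens.append(MASK_TOKEN if selected else piece)
            labels.append(piece if selected else None)
    return tokens, labels


if __name__ == "__main__":
    sentence = [["the"], ["deriv", "##ative"], ["of"], ["x"]]
    masked, targets = mask_math_words(sentence, {"derivative"}, prob=1.0, rng=random.Random(0))
    print(masked)
    print(targets)
```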

CITATION

If you use this framework, please cite the following paper:

    @article{drechsel2025mamut,
      title={{MAMUT}: A Novel Framework for Modifying Mathematical Formulas for the Generation of Specialized Datasets for Language Model Training},
      author={Jonathan Drechsel and Anja Reusch and Steffen Herbold},
      journal={Transactions on Machine Learning Research},
      issn={2835-8856},
      year={2025},
      url={https://openreview.net/forum?id=khODmRpQEx}
    }

Authors

Owner

  • Name: aieng-lab
  • Login: aieng-lab
  • Kind: organization

GitHub organization of the Chair for AI Engineering of the University of Passau

GitHub Events

Total
  • Watch event: 1
  • Public event: 1
  • Push event: 9
Last Year
  • Watch event: 1
  • Public event: 1
  • Push event: 9

Dependencies

environment.yml pypi
  • accelerate ==0.17.1
  • aiohappyeyeballs ==2.6.1
  • aiohttp ==3.11.16
  • aiosignal ==1.3.2
  • apache-beam ==2.43.0
  • async-timeout ==5.0.1
  • asyncio ==3.4.3
  • attrs ==25.3.0
  • blis ==0.7.11
  • catalogue ==2.0.10
  • certifi ==2025.1.31
  • charset-normalizer ==3.4.1
  • click ==8.1.8
  • cloudpickle ==2.2.1
  • comet-ml ==3.33.3
  • confection ==0.1.5
  • configobj ==5.0.9
  • contourpy ==1.2.1
  • crcmod ==1.7
  • cycler ==0.12.1
  • cymem ==2.0.11
  • datasets ==3.5.0
  • dill ==0.3.8
  • disutils ==1.4.32.post2
  • docopt ==0.6.2
  • dulwich ==0.22.8
  • everett ==3.1.0
  • fastavro ==1.10.0
  • fasteners ==0.19
  • filelock ==3.18.0
  • fonttools ==4.57.0
  • frozenlist ==1.5.0
  • fsspec ==2024.12.0
  • gensim ==4.3.1
  • grpcio ==1.71.0
  • hdfs ==2.7.3
  • httplib2 ==0.20.4
  • huggingface-hub ==0.30.2
  • humanize ==4.7.0
  • idna ==3.10
  • jinja2 ==3.1.6
  • joblib ==1.4.2
  • jsonschema ==4.23.0
  • jsonschema-specifications ==2024.10.1
  • kiwisolver ==1.4.8
  • langcodes ==3.5.0
  • language-data ==1.3.0
  • marisa-trie ==1.2.1
  • markdown-it-py ==3.0.0
  • markupsafe ==3.0.2
  • matplotlib ==3.6.3
  • mdurl ==0.1.2
  • mpmath ==1.3.0
  • multidict ==6.4.3
  • multiprocess ==0.70.14
  • murmurhash ==1.0.12
  • networkx ==3.4.2
  • nltk ==3.8.1
  • numpy ==1.22.4
  • nvidia-cublas-cu12 ==12.4.5.8
  • nvidia-cuda-cupti-cu12 ==12.4.127
  • nvidia-cuda-nvrtc-cu12 ==12.4.127
  • nvidia-cuda-runtime-cu12 ==12.4.127
  • nvidia-cudnn-cu12 ==9.1.0.70
  • nvidia-cufft-cu12 ==11.2.1.3
  • nvidia-curand-cu12 ==10.3.5.147
  • nvidia-cusolver-cu12 ==11.6.1.9
  • nvidia-cusparse-cu12 ==12.3.1.170
  • nvidia-cusparselt-cu12 ==0.6.2
  • nvidia-nccl-cu12 ==2.21.5
  • nvidia-nvjitlink-cu12 ==12.4.127
  • nvidia-nvtx-cu12 ==12.4.127
  • objsize ==0.5.2
  • orjson ==3.10.16
  • packaging ==24.2
  • pandas ==2.2.3
  • pathlib-abc ==0.1.1
  • pathy ==0.11.0
  • pause ==0.3
  • pillow ==11.1.0
  • preshed ==3.0.9
  • propcache ==0.3.1
  • proto-plus ==1.26.1
  • protobuf ==3.20.0
  • psutil ==5.9.4
  • pyarrow ==19.0.1
  • pydantic ==1.10.21
  • pydot ==1.4.2
  • pygments ==2.19.1
  • pymongo ==3.13.0
  • pynvml ==11.4.1
  • pyparsing ==3.2.3
  • python-box ==6.1.0
  • python-dateutil ==2.9.0.post0
  • pytz ==2025.2
  • pyyaml ==6.0.2
  • referencing ==0.36.2
  • regex ==2024.11.6
  • requests ==2.32.3
  • requests-toolbelt ==1.0.0
  • responses ==0.18.0
  • rich ==14.0.0
  • rpds-py ==0.24.0
  • safetensors ==0.5.3
  • scikit-learn ==1.2.1
  • scipy ==1.12.0
  • semantic-version ==2.10.0
  • sentencepiece ==0.2.0
  • sentry-sdk ==2.25.1
  • simplejson ==3.20.1
  • six ==1.16.0
  • smart-open ==6.4.0
  • spacy ==3.4.4
  • spacy-legacy ==3.0.12
  • spacy-loggers ==1.0.5
  • srsly ==2.5.1
  • sympy ==1.13.1
  • thinc ==8.1.12
  • threadpoolctl ==3.6.0
  • tokenizers ==0.13.3
  • torch ==1.13.1
  • tqdm ==4.67.1
  • transformers ==4.25.1
  • triton ==3.2.0
  • typer ==0.7.0
  • typing-extensions ==4.13.2
  • tzdata ==2025.2
  • urllib3 ==1.26.20
  • wasabi ==0.10.1
  • websocket-client ==1.3.3
  • wrapt ==1.17.2
  • wurlitzer ==3.1.1
  • xxhash ==3.5.0
  • yarl ==1.19.0
  • zstandard ==0.23.0
requirements.txt pypi
  • Jinja2 ==3.1.6
  • MarkupSafe ==3.0.2
  • PyYAML ==6.0.2
  • Pygments ==2.19.1
  • accelerate ==0.17.1
  • aiohappyeyeballs ==2.6.1
  • aiohttp ==3.11.16
  • aiosignal ==1.3.2
  • apache-beam ==2.43.0
  • async-timeout ==5.0.1
  • asyncio ==3.4.3
  • attrs ==25.3.0
  • blis ==0.7.11
  • catalogue ==2.0.10
  • certifi ==2025.1.31
  • charset-normalizer ==3.4.1
  • click ==8.1.8
  • cloudpickle ==2.2.1
  • comet-ml ==3.33.3
  • confection ==0.1.5
  • configobj ==5.0.9
  • contourpy ==1.2.1
  • crcmod ==1.7
  • cycler ==0.12.1
  • cymem ==2.0.11
  • datasets ==3.5.0
  • dill ==0.3.8
  • disutils ==1.4.32.post2
  • docopt ==0.6.2
  • dulwich ==0.22.8
  • everett ==3.1.0
  • fastavro ==1.10.0
  • fasteners ==0.19
  • filelock ==3.18.0
  • fonttools ==4.57.0
  • frozenlist ==1.5.0
  • fsspec ==2024.12.0
  • gensim ==4.3.1
  • grpcio ==1.71.0
  • hdfs ==2.7.3
  • httplib2 ==0.20.4
  • huggingface-hub ==0.30.2
  • humanize ==4.7.0
  • idna ==3.10
  • joblib ==1.4.2
  • jsonschema ==4.23.0
  • jsonschema-specifications ==2024.10.1
  • kiwisolver ==1.4.8
  • langcodes ==3.5.0
  • language_data ==1.3.0
  • marisa-trie ==1.2.1
  • markdown-it-py ==3.0.0
  • matplotlib ==3.6.3
  • mdurl ==0.1.2
  • mpmath ==1.3.0
  • multidict ==6.4.3
  • multiprocess ==0.70.14
  • murmurhash ==1.0.12
  • networkx ==3.4.2
  • nltk ==3.8.1
  • numpy ==1.22.4
  • nvidia-cublas-cu12 ==12.4.5.8
  • nvidia-cuda-cupti-cu12 ==12.4.127
  • nvidia-cuda-nvrtc-cu12 ==12.4.127
  • nvidia-cuda-runtime-cu12 ==12.4.127
  • nvidia-cudnn-cu12 ==9.1.0.70
  • nvidia-cufft-cu12 ==11.2.1.3
  • nvidia-curand-cu12 ==10.3.5.147
  • nvidia-cusolver-cu12 ==11.6.1.9
  • nvidia-cusparse-cu12 ==12.3.1.170
  • nvidia-cusparselt-cu12 ==0.6.2
  • nvidia-nccl-cu12 ==2.21.5
  • nvidia-nvjitlink-cu12 ==12.4.127
  • nvidia-nvtx-cu12 ==12.4.127
  • objsize ==0.5.2
  • orjson ==3.10.16
  • packaging ==24.2
  • pandas ==2.2.3
  • pathlib_abc ==0.1.1
  • pathy ==0.11.0
  • pause ==0.3
  • pillow ==11.1.0
  • preshed ==3.0.9
  • propcache ==0.3.1
  • proto-plus ==1.26.1
  • protobuf ==3.20.0
  • psutil ==5.9.4
  • pyarrow ==19.0.1
  • pydantic ==1.10.21
  • pydot ==1.4.2
  • pymongo ==3.13.0
  • pynvml ==11.4.1
  • pyparsing ==3.2.3
  • python-box ==6.1.0
  • python-dateutil ==2.9.0.post0
  • pytz ==2025.2
  • referencing ==0.36.2
  • regex ==2024.11.6
  • requests ==2.32.3
  • requests-toolbelt ==1.0.0
  • responses ==0.18.0
  • rich ==14.0.0
  • rpds-py ==0.24.0
  • safetensors ==0.5.3
  • scikit-learn ==1.2.1
  • scipy ==1.12.0
  • semantic-version ==2.10.0
  • sentencepiece ==0.2.0
  • sentry-sdk ==2.25.1
  • simplejson ==3.20.1
  • six ==1.16.0
  • smart-open ==6.4.0
  • spacy ==3.4.4
  • spacy-legacy ==3.0.12
  • spacy-loggers ==1.0.5
  • srsly ==2.5.1
  • sympy ==1.13.1
  • thinc ==8.1.12
  • threadpoolctl ==3.6.0
  • tokenizers ==0.13.3
  • torch ==1.13.1
  • tqdm ==4.67.1
  • transformers ==4.25.1
  • triton ==3.2.0
  • typer ==0.7.0
  • typing_extensions ==4.13.2
  • tzdata ==2025.2
  • urllib3 ==1.26.20
  • wasabi ==0.10.1
  • websocket-client ==1.3.3
  • wrapt ==1.17.2
  • wurlitzer ==3.1.1
  • xxhash ==3.5.0
  • yarl ==1.19.0
  • zstandard ==0.23.0