gradiend

GRADIEND: Monosemantic Feature Learning within Neural Networks Applied to Gender Debiasing of Transformer Models

https://github.com/aieng-lab/gradiend

Science Score: 75.0%

This score indicates how likely this project is to be science-related based on various indicators:

  • CITATION.cff file
    Found CITATION.cff file
  • codemeta.json file
    Found codemeta.json file
  • .zenodo.json file
    Found .zenodo.json file
  • DOI references
    Found 2 DOI reference(s) in README
  • Academic publication links
    Links to: arxiv.org
  • Academic email domains
  • Institutional organization owner
    Organization aieng-lab has institutional domain (www.fim.uni-passau.de)
  • JOSS paper metadata
  • Scientific vocabulary similarity
    Low similarity (9.0%) to scientific vocabulary
Last synced: 6 months ago

Repository

Basic Info
  • Host: GitHub
  • Owner: aieng-lab
  • License: apache-2.0
  • Language: Jupyter Notebook
  • Default Branch: main
  • Size: 1.55 MB
Statistics
  • Stars: 1
  • Watchers: 0
  • Forks: 1
  • Open Issues: 0
  • Releases: 0
Created about 1 year ago · Last pushed 9 months ago
Metadata Files
Readme License Citation

README.md

GRADIEND: Monosemantic Feature Learning within Neural Networks Applied to Gender Debiasing of Transformer Models

Jonathan Drechsel, Steffen Herbold (arXiv: 2502.01406)

This repository contains the official source code for training and evaluating GRADIEND, as described in GRADIEND: Monosemantic Feature Learning within Neural Networks Applied to Gender Debiasing of Transformer Models. Further evaluations from this study can be reproduced with our expanded version of bias-bench.

Install

```bash
git clone https://github.com/aieng-lab/gradiend.git
cd gradiend
conda env create --file environment.yml
conda activate gradiend
```

Download Gendered Words and copy the file into the data/ directory of this repository.

Optional: Install aieng-lab/bias-bench for further evaluations and comparison to other debiasing techniques.

To use Llama-based models, you must first accept the Llama 3.2 Community License Agreement (e.g., on the model's Hugging Face page). You also need to export an environment variable HF_TOKEN containing a Hugging Face access token associated with your account (alternatively, though not recommended, you can insert your token directly into HF_TOKEN in gradiend/model.py).
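Because a missing token usually only surfaces once a gated Llama checkpoint is being downloaded, it can help to check the environment up front. The snippet below is a minimal sketch of such a check; the helper name get_hf_token is hypothetical and not part of the gradiend package, which only requires that the variable is exported.

```python
import os


def get_hf_token(var_name: str = "HF_TOKEN") -> str:
    """Return the Hugging Face access token stored in the environment.

    Hypothetical helper (not part of gradiend): it fails early with a clear
    message instead of letting a gated-model download fail later.
    """
    token = os.environ.get(var_name, "").strip()
    if not token:
        raise RuntimeError(
            f"Environment variable {var_name} is not set; export a Hugging Face "
            "access token before loading Llama-based models."
        )
    return token
```

In a shell, the variable would be set with `export HF_TOKEN=<your token>` before running any of the scripts.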

Overview

Package | Description
--------|------------
gradiend.model | GRADIEND model implementation
gradiend.data | Data generation and access
gradiend.training | Training of GRADIEND
gradiend.evaluation | Evaluation of GRADIEND
gradiend.export | Export functions for results, e.g., printing LaTeX tables and plotting images

NOTE: All Python files in this repository should be run from the project's root directory so that the correct (relative) paths are used (e.g., python gradiend/training/gradiend_training.py).

See demo.ipynb for a quick overview of the GRADIEND model and the evaluation process.

Data

The gradiend.data package serves two purposes:

  • Data access: the relevant datasets can be loaded via the read_[dataset]() functions, i.e., read_genter(), read_geneutral(), read_namexact(), read_namextend(), and read_gentypes().
  • Data generation: generating these datasets is not necessary for GRADIEND training (the datasets are already generated), but the code is still available in the data package (see Dataset Generation).

Training

The GRADIEND models are trained by running the gradiend.training.gradiend_training script, which trains three GRADIENDs for each considered base model (bert-base-cased, bert-large-cased, distilbert-base-cased, roberta-large) and selects the best one at the end. Intermediate results are saved in results/experiments/gradiend, and the final models in results/models. The gradiend_training script relies on:

  • gradiend.training.data: the TrainingDataset class combines several datasets (e.g., GENTER, NAMEXACT, ...) and contains the logic to create appropriate training data during training, i.e., matching a GENTER template sentence with a name of a certain gender and computing the tokens.
  • gradiend.training.trainer: the train() function trains a single GRADIEND model and exposes many hyperparameters.
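To make the template matching more concrete, here is a minimal sketch of how a GENTER-style template could be paired with a name of a given gender. The [NAME] and [PRONOUN] placeholder markers, the BERT-style [MASK] token, the pronoun mapping, and the names fill_template and TrainingExample are all assumptions for illustration; the actual TrainingDataset implementation and data format may differ, and tokenization is omitted.

```python
import random
from dataclasses import dataclass
from typing import Dict, List

# Hypothetical placeholder markers; the actual GENTER data layout may differ.
NAME_SLOT = "[NAME]"
PRONOUN_SLOT = "[PRONOUN]"
MASK_TOKEN = "[MASK]"  # BERT-style mask token (assumption)

# Target pronoun per gender label (illustrative only).
PRONOUNS: Dict[str, str] = {"F": "she", "M": "he"}


@dataclass
class TrainingExample:
    masked_text: str  # template with the name filled in and the pronoun masked
    target: str       # pronoun the model should predict at the mask
    gender: str       # gender label of the sampled name


def fill_template(
    template: str,
    names_by_gender: Dict[str, List[str]],
    gender: str,
    rng: random.Random,
) -> TrainingExample:
    """Pair a template sentence with a randomly drawn name of the requested gender."""
    if NAME_SLOT not in template or PRONOUN_SLOT not in template:
        raise ValueError("template must contain both the name and the pronoun slot")
    name = rng.choice(names_by_gender[gender])
    text = template.replace(NAME_SLOT, name)
    return TrainingExample(
        masked_text=text.replace(PRONOUN_SLOT, MASK_TOKEN),
        target=PRONOUNS[gender],
        gender=gender,
    )
```

In a real pipeline, masked_text would then be passed to the base model's tokenizer, which corresponds to the "computing the tokens" step of TrainingDataset.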

Evaluation

Analysis of Encoder

The gradiend.evaluation.analyze_encoder.analyze() function analyzes the encoder of a trained GRADIEND model with three datasets:

  • GENTER as in the training process
  • GENTER with correctly filled template tokens, and with masked tokens that are gender-neutral
  • GENEUTRAL

To analyze multiple models at once, call gradiend.evaluation.analyze_encoder.analyze_models(*models). The raw results are saved in the same base folder as the GRADIEND model (e.g., results/models/bert-base-cased_params_spl_test.csv). The model metrics can then be generated and printed by calling gradiend.evaluation.analyze_encoder.print_all_models().
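The metrics reported by print_all_models() are defined in the repository itself. As a rough, illustrative way to inspect raw encoder outputs, the sketch below computes the mean encoded value per input group and the Pearson correlation between encoded values and gender labels. The +1 (female) and -1 (male) label encoding, the group names, and the helper names are assumptions, not necessarily the conventions used by GRADIEND.

```python
from math import sqrt
from typing import Dict, List, Sequence


def pearson(xs: Sequence[float], ys: Sequence[float]) -> float:
    """Pearson correlation coefficient of two equally long sequences."""
    n = len(xs)
    if n != len(ys) or n < 2:
        raise ValueError("need two sequences of equal length >= 2")
    mean_x = sum(xs) / n
    mean_y = sum(ys) / n
    cov = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys))
    std_x = sqrt(sum((x - mean_x) ** 2 for x in xs))
    std_y = sqrt(sum((y - mean_y) ** 2 for y in ys))
    if std_x == 0 or std_y == 0:
        raise ValueError("correlation is undefined for constant input")
    return cov / (std_x * std_y)


# Illustrative label encoding (assumption): female = +1, male = -1.
LABEL_VALUES: Dict[str, float] = {"F": 1.0, "M": -1.0}


def summarize_encodings(encoded: List[float], groups: List[str]) -> Dict[str, float]:
    """Mean encoded value per group, plus the correlation on the gendered subset."""
    if len(encoded) != len(groups):
        raise ValueError("encoded values and group labels must align")
    summary: Dict[str, float] = {}
    for group in sorted(set(groups)):
        values = [v for v, g in zip(encoded, groups) if g == group]
        summary[f"mean_{group}"] = sum(values) / len(values)
    # Neutral inputs have no gender label, so they are excluded from the correlation.
    gendered = [(v, LABEL_VALUES[g]) for v, g in zip(encoded, groups) if g in LABEL_VALUES]
    if len(gendered) >= 2:
        xs, ys = zip(*gendered)
        summary["pearson"] = pearson(xs, ys)
    return summary
```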

Analysis of Decoder and Generation of (De-)Biased Models

gradiend.evaluation.analyze_decoder.default_evaluation() evaluates the decoder of a trained GRADIEND model by generating (de-)biased models for different learning rates and gender factors. The evaluation results are cached per learning rate and gender factor (in results/cache/decoder), and plots visualizing the results are shown.
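The per-configuration caching can be pictured as a grid search whose results are stored on disk, one file per (learning rate, gender factor) pair, so that repeated runs skip configurations that were already evaluated. This is a hedged sketch: evaluate_grid, the evaluate callback, and the JSON file naming are hypothetical and do not mirror default_evaluation() exactly.

```python
import json
from itertools import product
from pathlib import Path
from typing import Callable, Dict, Iterable, Tuple

Result = Dict[str, float]


def evaluate_grid(
    learning_rates: Iterable[float],
    gender_factors: Iterable[float],
    evaluate: Callable[[float, float], Result],
    cache_dir: str = "results/cache/decoder",
) -> Dict[Tuple[float, float], Result]:
    """Evaluate every (learning rate, gender factor) pair, reusing cached results.

    `evaluate` is a hypothetical callback that would build the modified model
    for one configuration and return its metrics as a JSON-serializable dict.
    """
    cache = Path(cache_dir)
    cache.mkdir(parents=True, exist_ok=True)
    results: Dict[Tuple[float, float], Result] = {}
    for lr, factor in product(learning_rates, gender_factors):
        cache_file = cache / f"lr_{lr}_factor_{factor}.json"
        if cache_file.exists():
            # Cache hit: skip the expensive model modification and evaluation.
            result = json.loads(cache_file.read_text())
        else:
            result = evaluate(lr, factor)
            cache_file.write_text(json.dumps(result))
        results[(lr, factor)] = result
    return results
```

Keying the cache by configuration means an interrupted evaluation can be resumed, and plots can be regenerated without recomputing any model.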

The best debiased, male-biased, and female-biased models according to this evaluation can be generated by executing the gradiend.evaluation.select_models script, which saves these models into results/changed_models. The models are named [base model]-[type], where type is N for the debiased model, F for the female-biased model, and M for the male-biased model.
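The naming scheme can be expressed in a few lines. In the sketch below, changed_model_name follows the [base model]-[type] convention described above, while changed_model_dir assumes a flat results/changed_models/[name] layout, and select_best with its score callback is a hypothetical stand-in for whatever selection criterion gradiend.evaluation.select_models actually applies.

```python
from pathlib import Path
from typing import Callable, Dict, Hashable, Mapping

# Type suffixes as described in the README: N = debiased, F = female-biased, M = male-biased.
MODEL_TYPES: Dict[str, str] = {"N": "debiased", "F": "female-biased", "M": "male-biased"}


def changed_model_name(base_model: str, model_type: str) -> str:
    """Build the '[base model]-[type]' name of a changed model."""
    if model_type not in MODEL_TYPES:
        raise ValueError(f"unknown model type {model_type!r}; expected one of {sorted(MODEL_TYPES)}")
    return f"{base_model}-{model_type}"


def changed_model_dir(base_model: str, model_type: str, root: str = "results/changed_models") -> Path:
    """Directory where a changed model would be stored (assumed layout: root/name)."""
    return Path(root) / changed_model_name(base_model, model_type)


def select_best(
    results: Mapping[Hashable, Dict[str, float]],
    score: Callable[[Dict[str, float]], float],
) -> Hashable:
    """Return the configuration key whose metrics maximize `score` (hypothetical criterion)."""
    if not results:
        raise ValueError("no results to select from")
    return max(results, key=lambda key: score(results[key]))
```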

Some basic evaluations of these (de-)biased models can be run with:

  • gradiend.evaluation.analyze_decoder.evaluate_all_gender_predictions() and gradiend.export.gender_predictions.py for an overfitting analysis
  • gradiend.export.example_predictions.py to generate example predictions
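For the overfitting analysis, predicted female and male probabilities on a simple masking task are compared. A minimal sketch of aggregating them is shown below; it assumes the predictions at the masked position are available as a token-to-probability mapping (for instance built from the token_str and score fields returned by a transformers fill-mask pipeline), and the gendered word lists and function names are hypothetical placeholders.

```python
from typing import Dict, Iterable, Tuple

# Hypothetical gendered word lists; the repository's own lists may differ.
FEMALE_WORDS = frozenset({"she", "her", "woman", "girl"})
MALE_WORDS = frozenset({"he", "his", "him", "man", "boy"})


def gender_probabilities(token_probs: Dict[str, float]) -> Tuple[float, float]:
    """Sum the predicted probability mass on female and on male words at one mask."""
    female = sum(p for tok, p in token_probs.items() if tok.strip().lower() in FEMALE_WORDS)
    male = sum(p for tok, p in token_probs.items() if tok.strip().lower() in MALE_WORDS)
    return female, male


def mean_gender_probabilities(predictions: Iterable[Dict[str, float]]) -> Tuple[float, float]:
    """Average female and male probability mass over several masked sentences."""
    pairs = [gender_probabilities(p) for p in predictions]
    if not pairs:
        raise ValueError("no predictions given")
    n = len(pairs)
    return sum(f for f, _ in pairs) / n, sum(m for _, m in pairs) / n
```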

Evaluation of (De-)Biased Models

See bias-bench for a comparison of the (de-)biased models generated with GRADIEND to other debiasing techniques.

Export

The export package contains functions to export the results of the evaluations, e.g., to print LaTeX tables or to plot images.

Script | Description
-------|------------
dataset_stats | Prints the statistics of the datasets used in the paper
encoder_plot | Plots a violin plot of the distribution of encoded values from the encoder analysis
changed_model_selection | Generates a table with the statistics of the selected (de-)biased models (from gradiend.evaluation.analyze_decoder.default_evaluation())
gender_predictions | Plots predicted female and male probabilities for a simple masking task to evaluate overfitting
example_predictions | Generates example predictions for the selected (de-)biased models as a LaTeX table

NOTE: To enable LaTeX plotting with your desired font, adjust the default arguments of the init_matplotlib() function in gradiend/util.py.

Dataset Generation

Although the experiments described above use data that is now published on Hugging Face, we also provide the code to generate the datasets used in the paper.

Required Datasets

If you want to re-create the datasets generated in the paper, you first need to download the following datasets:

Dataset | Download Link | Notes | Download Directory
--------|---------------|-------|-------------------
Gender by Name | Download | Required for the generation of the name datasets | data/

Generation Scripts

The following scripts will generate the datasets used in the paper:

Dataset | Generation Script
--------|------------------
GENTER | gradiend.data.filtering.generate_genter()
GENEUTRAL | gradiend.data.generate_geneutral()
NAMEXACT | gradiend.data.generate_namexact()
NAMEXTEND | gradiend.data.generate_namextend()

Citation

```bibtex
@misc{drechsel2025gradiendmonosemanticfeaturelearning,
  title={{GRADIEND}: Monosemantic Feature Learning within Neural Networks Applied to Gender Debiasing of Transformer Models},
  author={Jonathan Drechsel and Steffen Herbold},
  year={2025},
  eprint={2502.01406},
  archivePrefix={arXiv},
  primaryClass={cs.LG},
  url={https://arxiv.org/abs/2502.01406},
}
```

Owner

  • Name: aieng-lab
  • Login: aieng-lab
  • Kind: organization

GitHub organization of the Chair for AI Engineering of the University of Passau

Citation (CITATION.cff)

cff-version: 1.2.0
message: If you use this software, please cite both the article from preferred-citation and the software itself.
authors:
  - family-names: Drechsel
    given-names: Jonathan
  - family-names: Herbold
    given-names: Steffen
title: 'GRADIEND: Monosemantic Feature Learning within Neural Networks Applied to Gender Debiasing of Transformer Models'
version: 1.0.0
url: https://arxiv.org/abs/2502.01406
date-released: '2025-02-04'
preferred-citation:
  authors:
    - family-names: Drechsel
      given-names: Jonathan
    - family-names: Herbold
      given-names: Steffen
  title: 'GRADIEND: Monosemantic Feature Learning within Neural Networks Applied to Gender Debiasing of Transformer Models'
  url: https://arxiv.org/abs/2502.01406
  type: generic
  year: '2025'
  conference: {}
  publisher: {}

GitHub Events

Total
  • Watch event: 1
  • Public event: 1
  • Push event: 12
  • Fork event: 1
Last Year
  • Watch event: 1
  • Public event: 1
  • Push event: 12
  • Fork event: 1

Dependencies

requirements.txt pypi
  • PySide6 ==6.7.3
  • Pympler ==1.1
  • hpack ==4.0.0
  • matplotlib ==3.9.2
  • munkres ==1.1.4
  • pyarrow ==17.0.0
  • sentencepiece ==0.2.0
  • shiboken6 ==6.7.3
  • torch ==2.4.1
  • torchaudio ==2.4.1
  • torchvision ==0.19.1
  • triton ==3.0.0
  • zstandard ==0.23.0
environment.yml conda
  • _libgcc_mutex 0.1
  • _openmp_mutex 4.5
  • absl-py 2.1.0
  • accelerate 1.0.0
  • aiohappyeyeballs 2.4.3
  • aiohttp 3.10.9
  • aiosignal 1.3.1
  • alembic 1.13.3
  • alsa-lib 1.2.12
  • anyio 4.6.0
  • aom 3.6.1
  • argon2-cffi 23.1.0
  • argon2-cffi-bindings 21.2.0
  • arrow 1.3.0
  • asttokens 2.4.1
  • async-lru 2.0.4
  • async-timeout 4.0.3
  • attrs 24.2.0
  • aws-c-auth 0.7.31
  • aws-c-cal 0.7.4
  • aws-c-common 0.9.28
  • aws-c-compression 0.2.19
  • aws-c-event-stream 0.4.3
  • aws-c-http 0.8.10
  • aws-c-io 0.14.18
  • aws-c-mqtt 0.10.7
  • aws-c-s3 0.6.6
  • aws-c-sdkutils 0.1.19
  • aws-checksums 0.1.20
  • aws-crt-cpp 0.28.3
  • aws-sdk-cpp 1.11.407
  • azure-core-cpp 1.13.0
  • azure-identity-cpp 1.9.0
  • azure-storage-blobs-cpp 12.13.0
  • azure-storage-common-cpp 12.8.0
  • azure-storage-files-datalake-cpp 12.12.0
  • babel 2.14.0
  • beautifulsoup4 4.12.3
  • blas 1.0
  • bleach 6.1.0
  • brotli 1.1.0
  • brotli-bin 1.1.0
  • brotli-python 1.1.0
  • bzip2 1.0.8
  • c-ares 1.34.1
  • ca-certificates 2024.12.14
  • cached-property 1.5.2
  • cached_property 1.5.2
  • cairo 1.18.0
  • certifi 2024.12.14
  • cffi 1.17.1
  • charset-normalizer 3.4.0
  • click 8.1.7
  • colorama 0.4.6
  • colorlog 6.8.2
  • comm 0.2.2
  • contourpy 1.3.0
  • cuda-cudart 12.4.99
  • cuda-cudart_linux-64 12.4.99
  • cuda-cupti 12.4.127
  • cuda-libraries 12.4.0
  • cuda-nvrtc 12.4.99
  • cuda-nvtx 12.4.127
  • cuda-opencl 12.4.99
  • cuda-runtime 12.4.0
  • cuda-version 12.4
  • cycler 0.12.1
  • cyrus-sasl 2.1.27
  • datasets 3.0.1
  • dbus 1.13.6
  • debugpy 1.8.6
  • decorator 5.1.1
  • defusedxml 0.7.1
  • dill 0.3.8
  • double-conversion 3.3.0
  • entrypoints 0.4
  • exceptiongroup 1.2.2
  • executing 2.1.0
  • expat 2.6.3
  • ffmpeg 4.4.2
  • filelock 3.16.1
  • font-ttf-dejavu-sans-mono 2.37
  • font-ttf-inconsolata 3.000
  • font-ttf-source-code-pro 2.038
  • font-ttf-ubuntu 0.83
  • fontconfig 2.14.2
  • fonts-conda-ecosystem 1
  • fonts-conda-forge 1
  • fonttools 4.54.1
  • fqdn 1.5.1
  • freetype 2.12.1
  • frozenlist 1.4.1
  • fsspec 2024.6.1
  • gettext 0.22.5
  • gettext-tools 0.22.5
  • gflags 2.2.2
  • glog 0.7.1
  • gmp 6.3.0
  • gmpy2 2.1.5
  • gnutls 3.7.9
  • graphite2 1.3.13
  • greenlet 3.1.1
  • grpcio 1.65.5
  • h11 0.14.0
  • h2 4.1.0
  • harfbuzz 9.0.0
  • hpack 4.0.0
  • httpcore 1.0.6
  • httpx 0.27.2
  • huggingface_hub 0.25.2
  • humanize
  • hyperframe 6.0.1
  • icu 75.1
  • idna 3.10
  • importlib-metadata 8.5.0
  • importlib-resources 6.4.5
  • importlib_metadata 8.5.0
  • importlib_resources 6.4.5
  • inflect 7.4.0
  • intel-openmp 2022.0.1
  • ipykernel 6.29.5
  • ipython 8.18.1
  • ipywidgets 8.1.5
  • isoduration 20.11.0
  • jedi 0.19.1
  • jinja2 3.1.4
  • joblib 1.4.2
  • json5 0.9.25
  • jsonpointer 3.0.0
  • jsonschema 4.23.0
  • jsonschema-specifications 2024.10.1
  • jsonschema-with-format-nongpl 4.23.0
  • jupyter 1.1.1
  • jupyter-lsp 2.2.5
  • jupyter_client 8.6.3
  • jupyter_console 6.6.3
  • jupyter_core 5.7.2
  • jupyter_events 0.10.0
  • jupyter_server 2.14.2
  • jupyter_server_terminals 0.5.3
  • jupyterlab 4.2.5
  • jupyterlab_pygments 0.3.0
  • jupyterlab_server 2.27.3
  • jupyterlab_widgets 3.0.13
  • keyutils 1.6.1
  • kiwisolver 1.4.7
  • krb5 1.21.3
  • lame 3.100
  • lcms2 2.16
  • ld_impl_linux-64 2.43
  • lerc 4.0.0
  • libabseil 20240722.0
  • libarrow 17.0.0
  • libarrow-acero 17.0.0
  • libarrow-dataset 17.0.0
  • libarrow-substrait 17.0.0
  • libasprintf 0.22.5
  • libasprintf-devel 0.22.5
  • libblas 3.9.0
  • libbrotlicommon 1.1.0
  • libbrotlidec 1.1.0
  • libbrotlienc 1.1.0
  • libcblas 3.9.0
  • libclang-cpp19.1 19.1.1
  • libclang13 19.1.1
  • libcrc32c 1.1.2
  • libcublas 12.4.2.65
  • libcufft 11.2.0.44
  • libcufile 1.9.0.20
  • libcups 2.3.3
  • libcurand 10.3.5.119
  • libcurl 8.10.1
  • libcusolver 11.6.0.99
  • libcusparse 12.3.0.142
  • libdeflate 1.22
  • libdrm 2.4.123
  • libedit 3.1.20191231
  • libegl 1.7.0
  • libev 4.33
  • libevent 2.1.12
  • libexpat 2.6.3
  • libffi 3.4.2
  • libgcc 14.1.0
  • libgcc-ng 14.1.0
  • libgettextpo 0.22.5
  • libgettextpo-devel 0.22.5
  • libgfortran 14.1.0
  • libgfortran-ng 14.1.0
  • libgfortran5 14.1.0
  • libgl 1.7.0
  • libglib 2.82.1
  • libglvnd 1.7.0
  • libglx 1.7.0
  • libgomp 14.1.0
  • libgoogle-cloud 2.29.0
  • libgoogle-cloud-storage 2.29.0
  • libgrpc 1.65.5
  • libiconv 1.17
  • libidn2 2.3.7
  • libjpeg-turbo 3.0.0
  • liblapack 3.9.0
  • libllvm19 19.1.1
  • libnghttp2 1.58.0
  • libnpp 12.2.5.2
  • libnsl 2.0.1
  • libntlm 1.4
  • libnvfatbin 12.4.99
  • libnvjitlink 12.4.99
  • libnvjpeg 12.3.1.89
  • libopengl 1.7.0
  • libparquet 17.0.0
  • libpciaccess 0.18
  • libpng 1.6.44
  • libpq 17.0
  • libprotobuf 5.27.5
  • libre2-11 2023.11.01
  • libsodium 1.0.20
  • libsqlite 3.46.1
  • libssh2 1.11.0
  • libstdcxx 14.1.0
  • libstdcxx-ng 14.1.0
  • libtasn1 4.19.0
  • libthrift 0.21.0
  • libtiff 4.7.0
  • libunistring 0.9.10
  • libutf8proc 2.8.0
  • libuuid 2.38.1
  • libva 2.22.0
  • libvpx 1.13.1
  • libwebp-base 1.4.0
  • libxcb 1.17.0
  • libxcrypt 4.4.36
  • libxkbcommon 1.7.0
  • libxml2 2.12.7
  • libxslt 1.1.39
  • libzlib 1.3.1
  • lightning-utilities 0.11.9
  • llvm-openmp 15.0.7
  • lz4-c 1.9.4
  • mako 1.3.5
  • markdown 3.6
  • markupsafe 3.0.1
  • matplotlib 3.9.2
  • matplotlib-base 3.9.2
  • matplotlib-inline 0.1.7
  • mistune 3.0.2
  • mkl 2022.1.0
  • more-itertools 10.5.0
  • mpc 1.3.1
  • mpfr 4.2.1
  • mpmath 1.3.0
  • multidict 6.1.0
  • multiprocess 0.70.16
  • munkres 1.1.4
  • mysql-common 9.0.1
  • mysql-libs 9.0.1
  • nbclient 0.10.0
  • nbconvert-core 7.16.4
  • nbformat 5.10.4
  • ncurses 6.5
  • nest-asyncio 1.6.0
  • nettle 3.9.1
  • networkx 3.2.1
  • nltk 3.9.1
  • notebook 7.2.2
  • notebook-shim 0.2.4
  • numpy 1.26.4
  • ocl-icd 2.3.2
  • openh264 2.3.1
  • openjpeg 2.5.2
  • openldap 2.6.8
  • openssl 3.4.0
  • optuna 4.0.0
  • orc 2.0.2
  • overrides 7.7.0
  • p11-kit 0.24.1
  • packaging 24.1
  • pandas 2.2.3
  • pandocfilters 1.5.0
  • parso 0.8.4
  • patsy 0.5.6
  • pcre2 10.44
  • pexpect 4.9.0
  • pickleshare 0.7.5
  • pillow 10.4.0
  • pip 24.2
  • pixman 0.43.2
  • pkgutil-resolve-name 1.3.10
  • platformdirs 4.3.6
  • plotly 5.24.1
  • prometheus_client 0.21.0
  • prompt-toolkit 3.0.48
  • prompt_toolkit 3.0.48
  • protobuf 5.27.5
  • psutil 6.0.0
  • pthread-stubs 0.4
  • ptyprocess 0.7.0
  • pure_eval 0.2.3
  • pyarrow 17.0.0
  • pyarrow-core 17.0.0
  • pycparser 2.22
  • pygments 2.18.0
  • pynvml 11.5.3
  • pyparsing 3.1.4
  • pyside6 6.7.3
  • pysocks 1.7.1
  • python 3.9.19
  • python-dateutil 2.9.0
  • python-fastjsonschema 2.20.0
  • python-json-logger 2.0.7
  • python-tzdata 2024.2
  • python-xxhash 3.5.0
  • python_abi 3.9
  • pytorch 2.4.1
  • pytorch-cuda 12.4
  • pytorch-lightning 2.4.0
  • pytorch-mutex 1.0
  • pytz 2024.1
  • pyyaml 6.0.2
  • pyzmq 26.2.0
  • qhull 2020.2
  • qt6-main 6.7.3
  • re2 2023.11.01
  • readline 8.2
  • referencing 0.35.1
  • regex 2024.9.11
  • requests 2.32.3
  • rfc3339-validator 0.1.4
  • rfc3986-validator 0.1.1
  • rpds-py 0.20.0
  • s2n 1.5.4
  • safetensors 0.4.5
  • scikit-learn 1.5.2
  • scipy 1.13.1
  • seaborn 0.13.2
  • seaborn-base 0.13.2
  • send2trash 1.8.3
  • setuptools 75.1.0
  • six 1.16.0
  • snappy 1.2.1
  • sniffio 1.3.1
  • soupsieve 2.5
  • sqlalchemy 2.0.35
  • stack_data 0.6.2
  • statsmodels 0.14.4
  • svt-av1 1.4.1
  • sympy 1.13.3
  • tabulate 0.9.0
  • tenacity 9.0.0
  • tensorboard 2.18.0
  • tensorboard-data-server 0.7.0
  • terminado 0.18.1
  • threadpoolctl 3.5.0
  • tinycss2 1.3.0
  • tk 8.6.13
  • tokenizers 0.20.0
  • tomli 2.0.2
  • torchaudio 2.4.1
  • torchmetrics 1.5.2
  • torchtriton 3.0.0
  • torchvision 0.19.1
  • tornado 6.4.1
  • tqdm 4.66.5
  • traitlets 5.14.3
  • transformers 4.45.2
  • typeguard 4.3.0
  • types-python-dateutil 2.9.0.20241003
  • typing-extensions 4.12.2
  • typing_extensions 4.12.2
  • typing_utils 0.1.0
  • tzdata 2024b
  • unicodedata2 15.1.0
  • uri-template 1.3.0
  • urllib3 2.2.3
  • wayland 1.23.1
  • wayland-protocols 1.37
  • wcwidth 0.2.13
  • webcolors 24.8.0
  • webencodings 0.5.1
  • websocket-client 1.8.0
  • werkzeug 3.0.4
  • wheel 0.44.0
  • widgetsnbextension 4.0.13
  • x264 1!164.3095
  • x265 3.5
  • xcb-util 0.4.1
  • xcb-util-cursor 0.1.5
  • xcb-util-image 0.4.0
  • xcb-util-keysyms 0.4.1
  • xcb-util-renderutil 0.3.10
  • xcb-util-wm 0.4.2
  • xkeyboard-config 2.43
  • xorg-libice 1.1.1
  • xorg-libsm 1.2.4
  • xorg-libx11 1.8.10
  • xorg-libxau 1.0.11
  • xorg-libxcomposite 0.4.6
  • xorg-libxcursor 1.2.2
  • xorg-libxdamage 1.1.6
  • xorg-libxdmcp 1.1.5
  • xorg-libxext 1.3.6
  • xorg-libxfixes 6.0.1
  • xorg-libxi 1.8.2
  • xorg-libxrandr 1.5.4
  • xorg-libxrender 0.9.11
  • xorg-libxtst 1.2.5
  • xorg-libxxf86vm 1.1.5
  • xorg-xorgproto 2024.1
  • xxhash 0.8.2
  • xz 5.2.6
  • yaml 0.2.5
  • yarl 1.13.1
  • zeromq 4.3.5
  • zipp 3.20.2
  • zlib 1.3.1
  • zstandard 0.23.0
  • zstd 1.5.6