https://github.com/astrazeneca/subtab

The official implementation of the paper, "SubTab: Subsetting Features of Tabular Data for Self-Supervised Representation Learning"


Science Score: 10.0%

This score indicates how likely this project is to be science-related based on various indicators:

  • CITATION.cff file
  • codemeta.json file
  • .zenodo.json file
  • DOI references
  • Academic publication links
    Links to: arxiv.org
  • Committers with academic emails
  • Institutional organization owner
  • JOSS paper metadata
  • Scientific vocabulary similarity
    Low similarity (12.7%) to scientific vocabulary

Keywords

contrastive-learning multi-view-learning representation-learning self-supervised-learning tabular-data
Last synced: 5 months ago

Repository


Basic Info
  • Host: GitHub
  • Owner: AstraZeneca
  • License: apache-2.0
  • Language: Python
  • Default Branch: main
  • Homepage:
  • Size: 42.6 MB
Statistics
  • Stars: 144
  • Watchers: 2
  • Forks: 19
  • Open Issues: 1
  • Releases: 0
Topics
contrastive-learning multi-view-learning representation-learning self-supervised-learning tabular-data
Created over 4 years ago · Last pushed over 3 years ago
Metadata Files
Readme License

README.md

SubTab:

Author: Talip Ucar (ucabtuc@gmail.com)

The official implementation of the paper,

SubTab: Subsetting Features of Tabular Data for Self-Supervised Representation Learning


Note: The extended version of SubTab, with code and pre-processed data for the Adult Income and BlogFeedback datasets, can be found at: https://github.com/talipucar/SubTab_extended

Table of Contents:

  1. Model
  2. Environment
  3. Data
  4. Configuration
  5. Training and Evaluation
  6. Adding New Datasets
  7. Results
  8. Experiment tracking
  9. Citing the paper
  10. Citing this repo

NeurIPS 2021 slides | NeurIPS 2021 poster

Model

Figure: SubTab model overview (animated). A slower version of the animation is available at ./assets/SubTab_slow.gif.
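The paper's title refers to subsetting the features of tabular data. The sketch below gives a rough picture of that step: it splits one row of features into a fixed number of contiguous, overlapping subsets. It makes three assumptions. Subsets are contiguous column blocks. The overlap is expressed as a fraction of the block width. The helper `make_feature_subsets` and its parameters `n_subsets` and `overlap` are hypothetical. The repository's own implementation, including how subsets are formed, masked or noised, may differ.

```python
from typing import List, Sequence


def make_feature_subsets(
    row: Sequence[float], n_subsets: int, overlap: float
) -> List[List[float]]:
    """Split one row of features into n_subsets contiguous, overlapping subsets.

    Hypothetical helper: each subset covers a block of about
    n_features / n_subsets columns, extended by `overlap` (a fraction of the
    block width) into the neighbouring block(s).
    """
    n_features = len(row)
    if not 1 <= n_subsets <= n_features:
        raise ValueError("n_subsets must be between 1 and the number of features")
    block = n_features // n_subsets   # base width of each block
    extra = int(overlap * block)      # columns borrowed from each neighbour
    subsets = []
    for k in range(n_subsets):
        start = k * block
        # The last block absorbs any remainder columns.
        end = n_features if k == n_subsets - 1 else start + block
        lo = max(0, start - extra)
        hi = min(n_features, end + extra)
        subsets.append(list(row[lo:hi]))
    return subsets


print(make_feature_subsets(list(range(8)), n_subsets=4, overlap=0.5))
```

With 8 features, 4 subsets and 50% overlap, each block of 2 columns borrows one column from each neighbour. The result is `[[0, 1, 2], [1, 2, 3, 4], [3, 4, 5, 6], [5, 6, 7]]`.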

Environment

We used Python 3.7 for our experiments. The environment can be set up by following these three steps:

pip install pipenv             # To install pipenv if you don't have it already
pipenv install --skip-lock     # To install required packages
pipenv shell                   # To activate virtual env

If the second step fails, you can install the packages listed in the Pipfile individually with pip, e.g. "pip install package_name".

Data

The MNIST dataset is already provided to demo the framework. For your own dataset, follow the instructions in Adding New Datasets.

Configuration

There are two types of configuration files: (1) runtime.yaml and (2) dataset-specific files such as mnist.yaml.

  1. runtime.yaml is a high-level configuration file used by all datasets to:
  • define the random seed
  • turn on/off mlflow (Default: False)
  • turn on/off python profiler (Default: False)
  • set data directory
  • set results directory
  2. The second configuration file is dataset-specific and is used to configure the model architecture, loss functions, and so on.
  • For example, we set up a configuration file for the MNIST dataset with the same name (mnist.yaml). Please note that the name of the configuration file should be the same as the name of the dataset, with all letters in lowercase.
  • We can have configuration files for other datasets, such as tcga.yaml and income.yaml for the tcga and income datasets, respectively.
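The naming rule above can be sketched as follows. The dataset-specific file is looked up as `config/<dataset name in lowercase>.yaml`, and its settings are layered on top of the runtime settings. This is a minimal sketch under several assumptions. The function names `config_path_for` and `merge_configs` are hypothetical, as are the key names (`seed`, `mlflow`, `profile`, `paths`, `epochs`) and the choice that dataset-specific values override runtime values. The YAML files are shown as already-parsed dictionaries. In practice you would load them with a YAML parser such as PyYAML, which is listed among the dependencies.

```python
from pathlib import Path
from typing import Any, Dict


def config_path_for(dataset_name: str, config_dir: str = "config") -> Path:
    """Return the dataset-specific config path, e.g. 'MNIST' -> config/mnist.yaml."""
    return Path(config_dir) / f"{dataset_name.lower()}.yaml"


def merge_configs(runtime: Dict[str, Any], dataset: Dict[str, Any]) -> Dict[str, Any]:
    """Recursively merge two config dicts; dataset-specific values win (assumption)."""
    merged = dict(runtime)  # shallow copy so the runtime dict is not mutated
    for key, value in dataset.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = merge_configs(merged[key], value)
        else:
            merged[key] = value
    return merged


# Hypothetical contents of runtime.yaml, already parsed (placeholder values).
runtime_cfg = {
    "seed": 1,
    "mlflow": False,    # experiment tracking off by default
    "profile": False,   # python profiler off by default
    "paths": {"data": "./data/", "results": "./results/"},
}
# Hypothetical dataset-specific settings, e.g. parsed from mnist.yaml.
mnist_cfg = {"epochs": 10, "batch_size": 32}

config = merge_configs(runtime_cfg, mnist_cfg)
print(config_path_for("MNIST"), config["epochs"], config["paths"]["data"])
```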

Training and Evaluation

You can train and evaluate the model by using:

python train.py    # For training
python eval.py     # For evaluation

  • train.py will also run evaluation at the end of the training.
  • You can also run evaluation separately by using eval.py.
  • For a list of arguments, please see ./utils/arguments.py
    • Use -h argument to get help when running scripts.
    • Use -d dataset_name to run scripts on new datasets
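The command-line interface can be approximated with `argparse`. This is a hypothetical sketch, not the contents of ./utils/arguments.py. Only the `-d dataset_name` option and `-h` help come from the text above. The long option name `--dataset` and the default of `mnist` are assumptions.

```python
import argparse
from typing import List, Optional


def build_parser() -> argparse.ArgumentParser:
    """Hypothetical parser mirroring the `-d dataset_name` option of train.py/eval.py."""
    # argparse adds -h/--help automatically.
    parser = argparse.ArgumentParser(description="Train or evaluate SubTab.")
    parser.add_argument(
        "-d",
        "--dataset",          # assumption: long form of -d
        type=str,
        default="mnist",      # assumption: MNIST is the demo dataset
        help="Dataset name; its config is expected at config/<name>.yaml",
    )
    return parser


def parse_args(argv: Optional[List[str]] = None) -> argparse.Namespace:
    """Parse argv (or sys.argv[1:] when argv is None)."""
    return build_parser().parse_args(argv)


if __name__ == "__main__":
    args = parse_args()
    print(f"Running on dataset: {args.dataset}")
```

With this sketch, `python train.py -d income` would set `args.dataset` to `"income"`.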

Adding New Datasets

For each new dataset, you can use the following steps:

  1. Provide a _load_dataset_name() function, similar to the MNIST load function.
    • For example, you can add _load_tcga() for the tcga dataset, or _load_income() for the income dataset.
    • The function should return (xtrain, ytrain, xtest, ytest).
  2. Add a separate elif condition for the new dataset within the _load_data() method of the TabularDataset() class in utils/load_data.py.
  3. Create a new config file with the same name as the dataset.
    • For example, tcga.yaml for the tcga dataset, or income.yaml for the income dataset.
    • You can also duplicate one of the existing configuration files (e.g. mnist.yaml) and rename it.
    • Make sure that the new config file is under the config/ directory.
  4. Provide a data folder with the pre-processed training and test sets, and place it under the ./data/ directory. You can also do the train-test split and pre-processing within your custom _load_dataset_name() function.
  5. (Optional) If you want to place the new dataset in a directory other than the local "./data/", then:
    • Place the dataset folder anywhere, and set the root directory to it in the data-directory setting of /config/runtime.yaml.
    • For example, if the path to the tcga dataset is /home/.../data/tcga/, you only need to include /home/.../data/ in runtime.yaml. The code will fill in the tcga folder name from the name given in the command-line argument (e.g. -d dataset_name; in this case, dataset_name would be tcga).
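Steps 1, 2 and 5 above can be sketched roughly as follows. This is a minimal, stdlib-only sketch, and several details are hypothetical:

- the CSV layout (`<data root>/<dataset name>/<dataset name>.csv`, label in the last column);
- the 80/20 split;
- the helper names `dataset_dir` and `load_data`.

Three details do come from the instructions above: the `_load_<dataset_name>()` naming pattern, the `(xtrain, ytrain, xtest, ytest)` return value and the elif-based dispatch. In the repository, the dispatch lives in the `_load_data()` method of the `TabularDataset` class in utils/load_data.py.

```python
import csv
import os
import random
from typing import List, Tuple

Split = Tuple[List[List[float]], List[str], List[List[float]], List[str]]


def dataset_dir(data_root: str, dataset_name: str) -> str:
    """Join the data root from runtime.yaml with the name passed via -d (step 5)."""
    return os.path.join(data_root, dataset_name)


def _load_income(data_root: str, test_ratio: float = 0.2, seed: int = 0) -> Split:
    """Hypothetical loader: reads income.csv and does its own train-test split."""
    path = os.path.join(dataset_dir(data_root, "income"), "income.csv")
    with open(path, newline="") as f:
        rows = list(csv.reader(f))
    random.Random(seed).shuffle(rows)  # reproducible shuffle before splitting
    n_test = int(len(rows) * test_ratio)
    test, train = rows[:n_test], rows[n_test:]
    # Assumption: all columns but the last are numeric features; the last is the label.
    xtrain = [[float(v) for v in r[:-1]] for r in train]
    ytrain = [r[-1] for r in train]
    xtest = [[float(v) for v in r[:-1]] for r in test]
    ytest = [r[-1] for r in test]
    return xtrain, ytrain, xtest, ytest


def load_data(dataset_name: str, data_root: str) -> Split:
    """Dispatch on the dataset name, one elif branch per dataset (step 2)."""
    if dataset_name == "income":
        return _load_income(data_root)
    # elif dataset_name == "tcga":
    #     return _load_tcga(data_root)
    else:
        raise ValueError(f"Unknown dataset: {dataset_name}")
```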

Structure of the repo

- train.py
- eval.py

- src
    |-model.py
    
- config
    |-runtime.yaml
    |-mnist.yaml
    
- utils
    |-load_data.py
    |-arguments.py
    |-model_utils.py
    |-loss_functions.py
    ...
    
- data
    |-mnist
    ...
    
- results
    |
    ...

Results

Results at the end of training are saved under the ./results directory. The results directory is structured as follows:

- results
    |-dataset name
            |-evaluation
                |-clusters (for plotting t-SNE and PCA plots of embeddings)
                |-reconstructions (not used)
            |-training
                |-model_mode (e.g. ae for autoencoder)   
                     |-model
                     |-plots
                     |-loss

You can save the results of evaluations under the "evaluation" folder.
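The results layout above can be created with `pathlib`. This is a sketch under stated assumptions. The helper name `make_results_dirs` is hypothetical. `ae` is used as the default `model_mode` only because the text gives it as an example.

```python
from pathlib import Path
from typing import Dict


def make_results_dirs(
    results_root: str, dataset_name: str, model_mode: str = "ae"
) -> Dict[str, Path]:
    """Create results/<dataset>/{evaluation,training/<model_mode>}/... and return the paths."""
    base = Path(results_root) / dataset_name
    dirs = {
        "clusters": base / "evaluation" / "clusters",
        "reconstructions": base / "evaluation" / "reconstructions",
        "model": base / "training" / model_mode / "model",
        "plots": base / "training" / model_mode / "plots",
        "loss": base / "training" / model_mode / "loss",
    }
    for path in dirs.values():
        path.mkdir(parents=True, exist_ok=True)  # idempotent: safe to call on every run
    return dirs
```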

Experiment tracking

MLflow is used to track experiments. It is turned off by default, but can be turned on by changing the corresponding option in the runtime config file, ./config/runtime.yaml.
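A run script might honour the MLflow switch roughly as below. This is a hypothetical sketch. The `config` dictionary and its `mlflow` key stand in for the setting in runtime.yaml, and `tracking_context` and `log_metric` are illustrative helpers. `mlflow.start_run()` and `mlflow.log_metric()` are real MLflow functions. MLflow is only imported when tracking is enabled, so the sketch runs without it installed when the switch is off.

```python
import contextlib
from typing import Any, ContextManager, Dict


def tracking_context(config: Dict[str, Any]) -> ContextManager:
    """Return an MLflow run if tracking is enabled, otherwise a no-op context."""
    if config.get("mlflow", False):
        import mlflow  # imported lazily so mlflow is optional when tracking is off
        return mlflow.start_run()
    return contextlib.nullcontext()


def log_metric(config: Dict[str, Any], name: str, value: float, step: int) -> None:
    """Log a metric to MLflow only when tracking is enabled."""
    if config.get("mlflow", False):
        import mlflow
        mlflow.log_metric(name, value, step=step)


config = {"mlflow": False}  # off by default, as in runtime.yaml
with tracking_context(config):
    for epoch in range(3):
        loss = 1.0 / (epoch + 1)  # placeholder loss value
        log_metric(config, "train_loss", loss, step=epoch)
```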

Citing the paper

@article{ucar2021subtab,
  title={SubTab: Subsetting Features of Tabular Data for Self-Supervised Representation Learning},
  author={Ucar, Talip and Hajiramezanali, Ehsan and Edwards, Lindsay},
  journal={Advances in Neural Information Processing Systems},
  volume={34},
  year={2021}
}

Citing this repo

If you use the SubTab framework in your own studies and work, please cite it as follows:

@Misc{talip_ucar_2021_SubTab,
  author = {Talip Ucar},
  title = {{SubTab: Subsetting Features of Tabular Data for Self-Supervised Representation Learning}},
  howpublished = {\url{https://github.com/AstraZeneca/SubTab}},
  month = June,
  year = {since 2021}
}

Owner

  • Name: AstraZeneca
  • Login: AstraZeneca
  • Kind: organization
  • Location: Global

Data and AI: Unlocking new science insights

GitHub Events

Total
  • Issues event: 1
  • Watch event: 4
  • Fork event: 3
Last Year
  • Issues event: 1
  • Watch event: 4
  • Fork event: 3

Committers

Last synced: 9 months ago

All Time
  • Total Commits: 18
  • Total Committers: 1
  • Avg Commits per committer: 18.0
  • Development Distribution Score (DDS): 0.0
Past Year
  • Commits: 0
  • Committers: 0
  • Avg Commits per committer: 0.0
  • Development Distribution Score (DDS): 0.0
Top Committers
Name Email Commits
Talip Uçar t****r@g****m 18

Dependencies

environment.yml conda
  • black
  • ipympl
  • isort
  • jupyter
  • jupyterlab
  • matplotlib >=3.1
  • nodejs
  • numpy >=1.16
  • pandas >=1
  • pip >=19
  • pytest
  • python 3.7.*
  • pyyaml
  • scipy
  • seaborn
  • tqdm
  • xlrd
Pipfile pypi
  • Pillow *
  • PyYAML *
  • datatable *
  • h5py *
  • imageio *
  • imbalanced-learn *
  • imblearn *
  • ipykernel *
  • isort *
  • jupyter *
  • matplotlib *
  • mlflow *
  • numpy *
  • opencv-python *
  • pandas *
  • pytest *
  • python-mnist *
  • pytorch-lightning *
  • scikit-image *
  • scikit-learn *
  • scipy *
  • seaborn *
  • sklearn *
  • tabulate *
  • tensorboard *
  • texttable *
  • torch *
  • torchvision *
  • tqdm *
  • wget *
  • xlrd >=1.0.0
requirements.txt pypi
  • Flask ==1.1.2
  • GitPython ==3.1.12
  • Jinja2 ==2.11.2
  • Keras ==2.4.3
  • Keras-Applications ==1.0.8
  • Keras-Preprocessing ==1.1.2
  • Mako ==1.1.4
  • Markdown ==3.3.3
  • MarkupSafe ==1.1.1
  • Pillow ==8.1.0
  • PyWavelets ==1.1.1
  • PyYAML ==5.4.1
  • Pygments ==2.7.4
  • QtPy ==1.9.0
  • SQLAlchemy ==1.3.22
  • Send2Trash ==1.5.0
  • Werkzeug ==1.0.1
  • absl-py ==0.11.0
  • aiohttp ==3.7.3
  • alembic ==1.4.1
  • argon2-cffi ==20.1.0
  • ase ==3.21.1
  • astor ==0.8.1
  • astunparse ==1.6.3
  • async-generator ==1.10
  • async-timeout ==3.0.1
  • attrs ==20.3.0
  • azure-core ==1.10.0
  • azure-storage-blob ==12.7.1
  • backcall ==0.2.0
  • bleach ==3.2.3
  • cached-property ==1.5.2
  • cachetools ==4.2.1
  • certifi ==2020.12.5
  • cffi ==1.14.4
  • chardet ==4.0.0
  • click ==7.1.2
  • cloudpickle ==1.6.0
  • cryptography ==3.3.1
  • cycler ==0.10.0
  • databricks-cli ==0.14.1
  • datatable ==0.11.1
  • decorator ==4.4.2
  • defusedxml ==0.6.0
  • docker ==4.4.1
  • entrypoints ==0.3
  • flatbuffers ==1.12
  • fsspec ==0.8.5
  • future ==0.18.2
  • gast ==0.3.3
  • gitdb ==4.0.5
  • google-auth ==1.24.0
  • google-auth-oauthlib ==0.4.2
  • google-pasta ==0.2.0
  • googledrivedownloader ==0.4
  • grpcio ==1.32.0
  • gunicorn ==20.0.4
  • h5py ==2.10.0
  • idna ==2.10
  • imageio ==2.9.0
  • imbalanced-learn ==0.8.0
  • imblearn ==0.0
  • importlib-metadata ==3.4.0
  • iniconfig ==1.1.1
  • ipykernel ==5.4.3
  • ipython ==7.19.0
  • ipython-genutils ==0.2.0
  • ipywidgets ==7.6.3
  • isodate ==0.6.0
  • isort ==5.7.0
  • itsdangerous ==1.1.0
  • jedi ==0.18.0
  • joblib ==1.0.0
  • jsonschema ==3.2.0
  • jupyter ==1.0.0
  • jupyter-client ==6.1.11
  • jupyter-console ==6.2.0
  • jupyter-core ==4.7.0
  • jupyterlab-pygments ==0.1.2
  • jupyterlab-widgets ==1.0.0
  • kiwisolver ==1.3.1
  • llvmlite ==0.35.0
  • matplotlib ==3.3.4
  • metis ==0.2a5
  • mistune ==0.8.4
  • mlflow ==1.13.1
  • msrest ==0.6.21
  • multidict ==5.1.0
  • nbclient ==0.5.1
  • nbconvert ==6.0.7
  • nbformat ==5.1.2
  • nest-asyncio ==1.5.1
  • networkx ==2.5
  • notebook ==6.2.0
  • numba ==0.52.0
  • numpy ==1.20.0
  • oauthlib ==3.1.0
  • opencv-python ==4.5.1.48
  • opt-einsum ==3.3.0
  • packaging ==20.9
  • pandas ==1.2.1
  • pandocfilters ==1.4.3
  • parso ==0.8.1
  • pexpect ==4.8.0
  • pickleshare ==0.7.5
  • pluggy ==0.13.1
  • prometheus-client ==0.9.0
  • prometheus-flask-exporter ==0.18.1
  • prompt-toolkit ==3.0.14
  • protobuf ==3.14.0
  • ptyprocess ==0.7.0
  • py ==1.10.0
  • pyasn1 ==0.4.8
  • pyasn1-modules ==0.2.8
  • pycparser ==2.20
  • pynndescent ==0.5.2
  • pyparsing ==2.4.7
  • pyrsistent ==0.17.3
  • pytest ==6.2.4
  • python-dateutil ==2.8.1
  • python-editor ==1.0.4
  • python-louvain ==0.15
  • python-mnist ==0.7
  • pytorch-lightning ==1.1.6
  • pytz ==2020.5
  • pyzmq ==22.0.2
  • qtconsole ==5.0.2
  • querystring-parser ==1.2.4
  • rdflib ==5.0.0
  • requests ==2.25.1
  • requests-oauthlib ==1.3.0
  • rsa ==4.7
  • scikit-image ==0.18.1
  • scikit-learn ==0.24.1
  • scipy ==1.6.0
  • seaborn ==0.11.1
  • six ==1.15.0
  • sklearn ==0.0
  • smmap ==3.0.5
  • sqlparse ==0.4.1
  • tabulate ==0.8.7
  • tensorboard ==2.5.0
  • tensorboard-data-server ==0.6.0
  • tensorboard-plugin-wit ==1.8.0
  • tensorflow ==2.4.1
  • tensorflow-estimator ==2.4.0
  • tensorflow-gpu ==1.15.3
  • termcolor ==1.1.0
  • terminado ==0.9.2
  • testpath ==0.4.4
  • texttable ==1.6.3
  • threadpoolctl ==2.1.0
  • tifffile ==2021.1.14
  • toml ==0.10.2
  • torch ==1.5.0
  • torch-geometric ==1.6.3
  • torchvision ==0.6.0
  • tornado ==6.1
  • tqdm ==4.56.0
  • traitlets ==5.0.5
  • typing-extensions ==3.7.4.3
  • umap-learn ==0.5.1
  • urllib3 ==1.26.3
  • wcwidth ==0.2.5
  • webencodings ==0.5.1
  • websocket-client ==0.57.0
  • wget ==3.2
  • widgetsnbextension ==3.5.1
  • wrapt ==1.12.1
  • xlrd ==2.0.1
  • yarl ==1.6.3
  • zipp ==3.4.0