transfertab

Transfer Learning for Tabular Data

https://github.com/manikyabard/transfertab

Science Score: 44.0%

This score indicates how likely this project is to be science-related based on various indicators:

  • CITATION.cff file
    Found CITATION.cff file
  • codemeta.json file
    Found codemeta.json file
  • .zenodo.json file
    Found .zenodo.json file
  • DOI references
  • Academic publication links
  • Committers with academic emails
  • Institutional organization owner
  • JOSS paper metadata
  • Scientific vocabulary similarity
    Low similarity (7.6%) to scientific vocabulary

Keywords

deep-learning structured-data tabular-data transfer-learning
Last synced: 6 months ago

Repository

Transfer Learning for Tabular Data

Basic Info
Statistics
  • Stars: 3
  • Watchers: 1
  • Forks: 0
  • Open Issues: 0
  • Releases: 0
Topics
deep-learning structured-data tabular-data transfer-learning
Created almost 5 years ago · Last pushed about 4 years ago
Metadata Files
Readme Contributing License Citation

README.md

transfertab

Allows transfer learning using structured data.

Install

```bash
pip install transfertab
```

How to use

TransferTab enables effective transfer learning from models trained on tabular data.

To make use of transfertab, you'll need:
* A PyTorch model which contains embeddings in a layer group.
* Another model to transfer these embeddings to, along with metadata about the dataset on which this model will be trained.

The whole process takes place in two main steps:
1. Extraction
2. Transfer

Extraction

This involves storing the embeddings present in the model in a JSON structure. The JSON contains the embeddings related to the categorical variables, and can later be transferred to another model which can also benefit from these categories. It is also possible to construct multiple JSON files from various models with different categorical variables and then use them together.
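For instance, combining extractions from several models could be as simple as merging the JSON dictionaries. The `merge_embdicts` helper below is a hypothetical sketch, not part of the library, assuming each file covers a disjoint set of categories:

```python
def merge_embdicts(*embdicts):
    """Merge several extracted-embedding dicts (e.g. loaded from
    separate JSON files) into one, refusing duplicate categories."""
    merged = {}
    for d in embdicts:
        for cat, info in d.items():
            if cat in merged:
                raise ValueError(f"category {cat!r} appears in more than one file")
            merged[cat] = info
    return merged

# Two toy extraction results, as they would appear after json.load:
emb_a = {"colour": {"classes": ["red", "blue"],
                    "embeddings": [[0.1, 0.2], [0.3, 0.4]]}}
emb_b = {"size": {"classes": ["S", "M"],
                  "embeddings": [[1.0, 0.0], [0.0, 1.0]]}}

combined = merge_embdicts(emb_a, emb_b)
print(sorted(combined))  # ['colour', 'size']
```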

Here we'll quickly construct a ModuleList with a couple of Embedding layers, and see how to transfer its embeddings.

```python
import torch.nn as nn

emb_szs1 = ((3, 10), (2, 8))
emb_szs2 = ((2, 10), (2, 8))

embed1 = nn.ModuleList([nn.Embedding(ni, nf) for ni, nf in emb_szs1])
embed2 = nn.ModuleList([nn.Embedding(ni, nf) for ni, nf in emb_szs2])
```

```python
embed1
```

```
ModuleList(
  (0): Embedding(3, 10)
  (1): Embedding(2, 8)
)
```

We can call the extractembeds function to extract the embeddings. Take a look at the documentation for other dispatch methods and details on the parameters.

```python
import pandas as pd
# extractembeds is provided by transfertab; see its documentation
# for the exact import path.

df = pd.DataFrame({"old_cat1": [1, 2, 3, 4, 5],
                   "old_cat2": ['a', 'b', 'b', 'b', 'a'],
                   "old_cat3": ['A', 'B', 'B', 'B', 'A']})
cats = ("old_cat2", "old_cat3")
embdict = extractembeds(embed2, df, transfercats=cats, allcats=cats)
```

```python
embdict
```

```
{'old_cat2': {'classes': ['a', 'b'],
  'embeddings': [[-0.28762340545654297,
    -0.142189621925354,
    0.2027226686477661,
    1.1096185445785522,
    -0.4540262520313263,
    -1.346120834350586,
    0.048871781677007675,
    0.1740419715642929,
    0.002095407573506236,
    0.721653163433075],
   [-0.9072648882865906,
    2.674738645553589,
    -0.8560850024223328,
    -1.119917869567871,
    -0.19618849456310272,
    1.1431224346160889,
    -0.5177133679389954,
    -0.6497849822044373,
    -0.9011525511741638,
    0.9314191341400146]]},
 'old_cat3': {'classes': ['A', 'B'],
  'embeddings': [[2.5755045413970947,
    -1.3670053482055664,
    -0.3207620680332184,
    -1.1824427843093872,
    0.07631386071443558,
    0.501422107219696,
    0.8510317802429199,
    -0.6687257289886475],
   [-1.3658113479614258,
    -0.27968257665634155,
    0.26537612080574036,
    0.36773681640625,
    -0.9940593242645264,
    0.9408144354820251,
    0.5295664668083191,
    -0.5038257241249084]]}}
```
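Since the extracted values are plain Python lists, the result can be persisted and reloaded with the standard json module (a sketch with toy values; the library may also provide its own helpers):

```python
import json
import os
import tempfile

# Toy extraction result with made-up values, shaped like embdict above
embdict = {"old_cat2": {"classes": ["a", "b"],
                        "embeddings": [[0.1, 0.2], [0.3, 0.4]]}}

path = os.path.join(tempfile.gettempdir(), "embdict.json")
with open(path, "w") as f:
    json.dump(embdict, f)

with open(path) as f:
    restored = json.load(f)

print(restored == embdict)  # True
```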

Transfer

The transfer process involves taking the extracted weights (or a model directly) and reusing the trained parameters. We define how this happens with the metadict, which maps each category in the current dataset to the category it corresponds to in the previous dataset (the one used to train the old model), and specifies how the new classes map to the old classes. Multiple old classes can even be mapped to a single new one, in which case the aggfn parameter is used to aggregate the embedding vectors.

```python
import json

jsonfilepath = "../data/jsons/metadict.json"

with open(jsonfilepath, 'r') as j:
    metadict = json.loads(j.read())
```

```python
metadict
```

```
{'new_cat1': {'mapped_cat': 'old_cat2',
  'classes_info': {'new_class1': ['a', 'b'],
   'new_class2': ['b'],
   'new_class3': []}},
 'new_cat2': {'mapped_cat': 'old_cat3',
  'classes_info': {'new_class1': ['A'], 'new_class2': []}}}
```
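The effect of each classes_info entry can be sketched in plain Python. This assumes aggfn defaults to an element-wise mean, and relies on the observation (from the transferred weights further down) that a new class with an empty mapping, such as new_class3, ends up with the mean over all old classes; both points are inferences from the outputs, not documented behaviour:

```python
# Toy old-class embeddings (exact binary fractions, for illustration only)
old_embeddings = {"a": [0.25, 0.5], "b": [0.75, 0.0]}

def aggregate(mapped_classes, old_embeddings):
    """Mean-aggregate the old-class embeddings a new class maps to.
    An empty mapping falls back to the mean over all old classes."""
    classes = mapped_classes or list(old_embeddings)
    rows = [old_embeddings[c] for c in classes]
    return [sum(col) / len(rows) for col in zip(*rows)]

print(aggregate(["a", "b"], old_embeddings))  # [0.5, 0.25]
print(aggregate(["b"], old_embeddings))       # [0.75, 0.0]
print(aggregate([], old_embeddings))          # [0.5, 0.25]
```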

We take a look at the layer parameters before and after transferring to see if it worked as expected.

```python
embed1.state_dict()
```

```
OrderedDict([('0.weight',
              tensor([[-0.6940, -0.0337,  0.9491, -1.0520,  0.7804,  2.0246,  0.4242, -1.8351,
                        0.4660,  1.7667],
                      [-0.2802,  0.6081, -0.8459, -0.3288, -1.1264,  0.7621,  0.9347,  1.8096,
                       -0.1998, -0.2541],
                      [ 0.5706, -0.5213, -0.1398, -0.3742, -1.1951,  1.9640,  0.4132,  2.0365,
                        0.0655,  0.5189]])),
             ('1.weight',
              tensor([[ 0.9506, -0.0057,  0.2754,  0.8276,  0.8675,  1.2238, -1.5603,  1.0301],
                      [-0.7315, -0.3735,  0.6059,  0.2659, -0.4918,  1.5501,  0.0221, -0.6199]]))])
```

```python
embed2.state_dict()
```

```
OrderedDict([('0.weight',
              tensor([[-2.8762e-01, -1.4219e-01,  2.0272e-01,  1.1096e+00, -4.5403e-01,
                       -1.3461e+00,  4.8872e-02,  1.7404e-01,  2.0954e-03,  7.2165e-01],
                      [-9.0726e-01,  2.6747e+00, -8.5609e-01, -1.1199e+00, -1.9619e-01,
                        1.1431e+00, -5.1771e-01, -6.4978e-01, -9.0115e-01,  9.3142e-01]])),
             ('1.weight',
              tensor([[ 2.5755, -1.3670, -0.3208, -1.1824,  0.0763,  0.5014,  0.8510, -0.6687],
                      [-1.3658, -0.2797,  0.2654,  0.3677, -0.9941,  0.9408,  0.5296, -0.5038]]))])
```

```python
transfercats = ("new_cat1", "new_cat2")
newcatcols = ("new_cat1", "new_cat2")
oldcatcols = ("old_cat2", "old_cat3")

newcatdict = {"new_cat1": ["new_class1", "new_class2", "new_class3"],
              "new_cat2": ["new_class1", "new_class2"]}
oldcatdict = {"old_cat2": ["a", "b"], "old_cat3": ["A", "B"]}

transferembeds(embed1, embdict, metadict, transfercats,
               newcatcols=newcatcols, oldcatcols=oldcatcols, newcatdict=newcatdict)
```

```python
embed1.state_dict()
```

```
OrderedDict([('0.weight',
              tensor([[-0.5974,  1.2663, -0.3267, -0.0051, -0.3251, -0.1015, -0.2344, -0.2379,
                       -0.4495,  0.8265],
                      [-0.9073,  2.6747, -0.8561, -1.1199, -0.1962,  1.1431, -0.5177, -0.6498,
                       -0.9012,  0.9314],
                      [-0.5974,  1.2663, -0.3267, -0.0051, -0.3251, -0.1015, -0.2344, -0.2379,
                       -0.4495,  0.8265]])),
             ('1.weight',
              tensor([[ 2.5755, -1.3670, -0.3208, -1.1824,  0.0763,  0.5014,  0.8510, -0.6687],
                      [ 0.6048, -0.8233, -0.0277, -0.4074, -0.4589,  0.7211,  0.6903, -0.5863]]))])
```

As we can see, the embeddings have been transferred over.
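We can double-check one row by hand: new_class1 maps to both 'a' and 'b', so the first row of the transferred '0.weight' should be the element-wise mean of the two old_cat2 embeddings from embdict. A standalone check using the first two components printed above:

```python
# First two components of the 'a' and 'b' embeddings from embdict above
a = [-0.28762340545654297, -0.142189621925354]
b = [-0.9072648882865906, 2.674738645553589]

mean = [(x + y) / 2 for x, y in zip(a, b)]
print([round(v, 4) for v in mean])  # [-0.5974, 1.2663]
```

This matches the first two entries of the transferred row.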

Owner

  • Name: Manikya Bardhan
  • Login: manikyabard
  • Kind: user

Citation (CITATION.cff)

cff-version: 1.1.0
message: If you use this software, please cite it as below.
authors:
  - family-names: Bardhan
    given-names: Manikya
  - family-names: "Rishon Manoj"
    given-names: Joe
  - family-names: Acharya
    given-names: Rakshith
  - family-names: Datta
    given-names: Ishita
title: transfertab
version: 1.0.0
repository-code: "https://github.com/manikyabard/transfertab"

Committers

Last synced: almost 3 years ago

All Time
  • Total Commits: 39
  • Total Committers: 5
  • Avg Commits per committer: 7.8
  • Development Distribution Score (DDS): 0.256
Top Committers
Name Email Commits
Manikya m****d@g****m 29
RakshithRAcharya r****0@g****m 6
JoeRishon j****n@g****m 2
JoeRishon 4****n@u****m 1
Ishita Datta i****0@g****m 1

Issues and Pull Requests

Last synced: over 2 years ago

All Time
  • Total issues: 0
  • Total pull requests: 1
  • Average time to close issues: N/A
  • Average time to close pull requests: about 2 hours
  • Total issue authors: 0
  • Total pull request authors: 1
  • Average comments per issue: 0
  • Average comments per pull request: 1.0
  • Merged pull requests: 1
  • Bot issues: 0
  • Bot pull requests: 0
Past Year
  • Issues: 0
  • Pull requests: 0
  • Average time to close issues: N/A
  • Average time to close pull requests: N/A
  • Issue authors: 0
  • Pull request authors: 0
  • Average comments per issue: 0
  • Average comments per pull request: 0
  • Merged pull requests: 0
  • Bot issues: 0
  • Bot pull requests: 0
Top Authors
Issue Authors
Pull Request Authors
  • RakshithRAcharya (1)

Packages

  • Total packages: 1
  • Total downloads:
    • pypi 4 last-month
  • Total dependent packages: 0
  • Total dependent repositories: 1
  • Total versions: 1
  • Total maintainers: 1
pypi.org: transfertab

A library to help transfer learn for structured data.

  • Versions: 1
  • Dependent Packages: 0
  • Dependent Repositories: 1
  • Downloads: 4 Last month
Rankings
Dependent packages count: 10.1%
Dependent repos count: 21.6%
Forks count: 29.8%
Stargazers count: 31.9%
Average: 34.5%
Downloads: 79.0%
Maintainers (1)
Last synced: 6 months ago

Dependencies

docs/Gemfile.lock rubygems
  • activesupport 6.0.3.4
  • addressable 2.7.0
  • coffee-script 2.4.1
  • coffee-script-source 1.11.1
  • colorator 1.1.0
  • commonmarker 0.17.13
  • concurrent-ruby 1.1.7
  • dnsruby 1.61.5
  • em-websocket 0.5.2
  • ethon 0.12.0
  • eventmachine 1.2.7
  • execjs 2.7.0
  • faraday 1.3.0
  • faraday-net_http 1.0.1
  • ffi 1.14.2
  • forwardable-extended 2.6.0
  • gemoji 3.0.1
  • github-pages 209
  • github-pages-health-check 1.16.1
  • html-pipeline 2.14.0
  • http_parser.rb 0.6.0
  • i18n 0.9.5
  • jekyll 3.9.0
  • jekyll-avatar 0.7.0
  • jekyll-coffeescript 1.1.1
  • jekyll-commonmark 1.3.1
  • jekyll-commonmark-ghpages 0.1.6
  • jekyll-default-layout 0.1.4
  • jekyll-feed 0.15.1
  • jekyll-gist 1.5.0
  • jekyll-github-metadata 2.13.0
  • jekyll-mentions 1.6.0
  • jekyll-optional-front-matter 0.3.2
  • jekyll-paginate 1.1.0
  • jekyll-readme-index 0.3.0
  • jekyll-redirect-from 0.16.0
  • jekyll-relative-links 0.6.1
  • jekyll-remote-theme 0.4.2
  • jekyll-sass-converter 1.5.2
  • jekyll-seo-tag 2.6.1
  • jekyll-sitemap 1.4.0
  • jekyll-swiss 1.0.0
  • jekyll-theme-architect 0.1.1
  • jekyll-theme-cayman 0.1.1
  • jekyll-theme-dinky 0.1.1
  • jekyll-theme-hacker 0.1.2
  • jekyll-theme-leap-day 0.1.1
  • jekyll-theme-merlot 0.1.1
  • jekyll-theme-midnight 0.1.1
  • jekyll-theme-minimal 0.1.1
  • jekyll-theme-modernist 0.1.1
  • jekyll-theme-primer 0.5.4
  • jekyll-theme-slate 0.1.1
  • jekyll-theme-tactile 0.1.1
  • jekyll-theme-time-machine 0.1.1
  • jekyll-titles-from-headings 0.5.3
  • jekyll-watch 2.2.1
  • jemoji 0.12.0
  • kramdown 2.3.0
  • kramdown-parser-gfm 1.1.0
  • liquid 4.0.3
  • listen 3.4.0
  • mercenary 0.3.6
  • mini_portile2 2.5.0
  • minima 2.5.1
  • minitest 5.14.3
  • multipart-post 2.1.1
  • nokogiri 1.11.0
  • octokit 4.20.0
  • pathutil 0.16.2
  • public_suffix 3.1.1
  • racc 1.5.2
  • rb-fsevent 0.10.4
  • rb-inotify 0.10.1
  • rexml 3.2.4
  • rouge 3.23.0
  • ruby-enum 0.8.0
  • ruby2_keywords 0.0.2
  • rubyzip 2.3.0
  • safe_yaml 1.0.5
  • sass 3.7.4
  • sass-listen 4.0.0
  • sawyer 0.8.2
  • simpleidn 0.1.1
  • terminal-table 1.8.0
  • thread_safe 0.3.6
  • typhoeus 1.4.0
  • tzinfo 1.2.9
  • unf 0.1.4
  • unf_ext 0.0.7.7
  • unicode-display_width 1.7.0
  • zeitwerk 2.4.2