impartial-multi-task-learning
PyTorch implementation of "Towards Impartial Multi-Task Learning"
https://github.com/johnlamaster/impartial-multi-task-learning
Science Score: 44.0%
This score indicates how likely this project is to be science-related based on various indicators:
- ✓ CITATION.cff file: Found CITATION.cff file
- ✓ codemeta.json file: Found codemeta.json file
- ✓ .zenodo.json file: Found .zenodo.json file
- ○ DOI references
- ○ Academic publication links
- ○ Academic email domains
- ○ Institutional organization owner
- ○ JOSS paper metadata
- ○ Scientific vocabulary similarity: Low similarity (8.7%) to scientific vocabulary
Repository
PyTorch implementation of "Towards Impartial Multi-Task Learning"
Basic Info
- Host: GitHub
- Owner: JohnLaMaster
- License: gpl-3.0
- Language: Python
- Default Branch: main
- Size: 36.1 KB
Statistics
- Stars: 12
- Watchers: 1
- Forks: 0
- Open Issues: 1
- Releases: 1
Metadata Files
README.md
Impartial Multi-Task Learning
PyTorch implementation of "Towards Impartial Multi-Task Learning"
"Towards Impartial Multi-task Learning". Liyang Liu, Yi Li, Zhanghui Kuang, Jing-Hao Xue, Yimin Chen, Wenming Yang, Qingmin Liao, Wayne Zhang
OpenReview: https://openreview.net/forum?id=IMPnRXEWpvr
Source code written by: Ing. John T LaMaster
Implementation
- Instantiate the module and send it to the GPU. As described in the paper, the possible methods are 'gradient', 'loss', and 'hybrid'.
Use "itertools.chain()" to combine the parameters of the NN and the IMTL module when defining the optimizer:

    import itertools
    import torch

    num_losses = 5  # for example
    init_values = None  # or a list of initial scaling values
    self.IMTL = IMTL(method='hybrid', num_losses=num_losses, init_loss=init_values).to(device)
    if self.opt.IMTL:
        # Optimize the network and the IMTL scaling parameters together
        parameters = itertools.chain(self.network.parameters(), self.IMTL.parameters())
    else:
        parameters = self.network.parameters()
    self.optimizer = torch.optim.Adam(parameters, lr=opt.lr, betas=(opt.beta1, opt.beta2))

You will need to modify the NN code to return the intermediate feature z. This is the output of the encoder before the global pooling, flattening, and linear layers.
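The optimizer setup above passes an "itertools.chain()" object, which is a one-shot iterator. The following is a minimal sketch of that behaviour using plain strings as stand-ins for parameters; the helper name `combined` and the parameter names are hypothetical and not part of this repository.

```python
import itertools


def combined(*groups):
    """Chain several parameter iterables into one iterator (hypothetical helper)."""
    return itertools.chain(*groups)


# Stand-ins for network.parameters() and IMTL.parameters(), which are generators too.
network_params = (name for name in ["enc.weight", "enc.bias"])
imtl_params = (name for name in ["imtl.scale"])

params = combined(network_params, imtl_params)
first_pass = list(params)   # consumed once, e.g. when an optimizer collects its parameters
second_pass = list(params)  # the chain is exhausted after the first pass
```

If the combined parameters are needed again later (for example for gradient clipping), it is probably safest to rebuild the chain or to store the result as a list first.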
During training, append each task's loss to a list loss. Note: There are circumstances in which not all losses should be included in IMTL. In such cases, these values can be popped from the loss list before calling IMTL. As shown below, do not forget to call backward on them separately.
To evaluate the effect of each objective function, I often turn some off. To handle this, the IMTL code will only use tensors with requires_grad=True. Losses that are turned off this way do NOT need to be removed from the list. The list and the intermediate feature z can now be used to call self.IMTL().
    # For tracking the losses that bypass IMTL
    other_loss = 0
    ind = []
    for i, (cond, l) in enumerate(zip(exclude, loss)):
        if cond:
            other_loss += l
            ind.append(i)
    for i in reversed(ind):  # pop from the end so the remaining indices stay valid
        loss.pop(i)

    # For using IMTL
    if self.IMTL and self.opt.phase == 'train':
        # shared, specific = grad_loss, scaled_losses
        grad_loss, scaled_losses = self.IMTL.backward(shared_parameters=[model.parameters()], losses=loss)
        if ind:
            other_loss.backward()  # only needed when some losses bypass IMTL

As long as IMTL.backward() is used instead of the forward pass, nothing more needs to be done with shared or specific, because their .backward() calls are made inside self.IMTL.
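The bookkeeping above can also be written without mutating the list in place. This is a minimal sketch, assuming plain floats as stand-ins for loss tensors; the function name `split_losses` is hypothetical and not part of this repository.

```python
def split_losses(losses, exclude):
    """Return (kept, bypass_total): losses meant for IMTL and the sum of excluded ones."""
    kept, bypass_total = [], 0.0
    for value, skip in zip(losses, exclude):
        if skip:
            bypass_total += value  # excluded from IMTL, handled separately
        else:
            kept.append(value)     # passed on to IMTL
    return kept, bypass_total


# Floats stand in for loss tensors in this illustration.
losses = [0.5, 1.25, 2.0, 0.25]
exclude = [False, True, False, True]
kept, bypass_total = split_losses(losses, exclude)
```

With real tensors, the excluded sum would still need its own backward call, as in the snippet above.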
Finally, the optimizer can take its step.

    self.optimizer.step()
    self.optimizer.zero_grad()
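For intuition about the 'gradient' method listed at the start of this section: the paper's abstract describes scaling factors chosen so that the aggregated gradient has equal projections onto the individual tasks. The sketch below works this out for two tasks only, in pure Python, with weights constrained to sum to one. It is derived from that condition for illustration; the function names are hypothetical and the repository's own implementation may differ.

```python
import math


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def unit(v):
    norm = math.sqrt(dot(v, v))
    return [x / norm for x in v]


def two_task_weights(g1, g2):
    """Weights (a1, a2) with a1 + a2 = 1 such that a1*g1 + a2*g2
    has equal projections onto unit(g1) and unit(g2).

    Undefined when g1 and g2 point in the same direction (zero denominator).
    """
    u_diff = [p - q for p, q in zip(unit(g1), unit(g2))]
    g_diff = [p - q for p, q in zip(g1, g2)]
    a2 = dot(g1, u_diff) / dot(g_diff, u_diff)
    return 1.0 - a2, a2


# Two toy task gradients with different magnitudes.
g1, g2 = [1.0, 0.0], [0.0, 2.0]
a1, a2 = two_task_weights(g1, g2)
combined = [a1 * p + a2 * q for p, q in zip(g1, g2)]
proj1 = dot(combined, unit(g1))
proj2 = dot(combined, unit(g2))
```

In this toy case the task with the larger gradient receives the smaller weight, so neither task dominates the shared update.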
Owner
- Name: John LaMaster
- Login: JohnLaMaster
- Kind: user
- Location: Munich, Germany
- Company: Technical University Munich
- Website: http://campar.in.tum.de/Main/JohnLaMaster
- Repositories: 2
- Profile: https://github.com/JohnLaMaster
I am a Biomedical Engineer currently working on my doctorate in Medical Imaging-Data Analysis using Deep Learning.
Citation (CITATION.cff)
cff-version: 1.0.0
message: "If you use this software, please cite it as below."
authors:
  - family-names: LaMaster
    given-names: John
    orcid: https://orcid.org/0000-0002-2149-771X
title: "PyTorch Implementation of 'Towards Impartial Multi-Task Learning'"
version: 1.0.0
doi: 10.5281/zenodo.6794042
date-released: 2022-04-29
GitHub Events
Total
- Watch event: 3
Last Year
- Watch event: 3