multimodal-trajectory-modeling

Code supplement for "Unsupervised multimodal modeling of cognitive and brain health trajectories for early dementia prediction"

https://github.com/burkh4rt/multimodal-trajectory-modeling

Science Score: 67.0%

This score indicates how likely this project is to be science-related based on various indicators:

  • CITATION.cff file
    Found CITATION.cff file
  • codemeta.json file
    Found codemeta.json file
  • .zenodo.json file
    Found .zenodo.json file
  • DOI references
    Found 5 DOI reference(s) in README
  • Academic publication links
    Links to: zenodo.org
  • Academic email domains
  • Institutional organization owner
  • JOSS paper metadata
  • Scientific vocabulary similarity
    Low similarity (7.9%) to scientific vocabulary

Keywords

brain-health dementia-detection mixture-models state-space-models unsupervised-machine-learning
Last synced: 6 months ago

Repository


Basic Info
Statistics
  • Stars: 3
  • Watchers: 1
  • Forks: 1
  • Open Issues: 0
  • Releases: 1
Topics
brain-health dementia-detection mixture-models state-space-models unsupervised-machine-learning
Created almost 2 years ago · Last pushed almost 2 years ago
Metadata Files
Readme License Citation

README.md


Multimodal Trajectory Modeling

This repository provides code and anonymized data to accompany our paper "Unsupervised multimodal modeling of cognitive and brain health trajectories for early dementia prediction" [^1]. In it, we propose and validate a mixture of state space models to perform unsupervised clustering of short trajectories. Within the state space framework, we let expensive-to-gather biomarkers correspond to hidden states and readily obtainable cognitive metrics correspond to measurements. Upon training with expectation maximization, we find that our clusters stratify persons according to clinical outcome. Furthermore, we can effectively predict on held-out trajectories using cognitive metrics alone. Our approach accommodates missing data through model marginalization and generalizes across research and clinical cohorts.

Data format

We consider a training dataset

$$ \mathcal{D} = \{ (x_{1:T}^{i}, z_{1:T}^{i}) \}_{1\leq i \leq n_d} $$

consisting of $n_d$ sequences of states and observations paired in time. We denote the states $z_{1:T}^{i} = (z_1^i, z_2^i, \dotsc, z_T^i)$ where $z_t^i \in \mathbb{R}^d$ corresponds to the state at time $t$ for the $i$th instance and measurements $x_{1:T}^{i} = (x_1^i, x_2^i, \dotsc, x_T^i)$ where $x_t^i \in \mathbb{R}^\ell$ corresponds to the observation at time $t$ for the $i$th instance. For the purposes of this code, we adopt the convention that collections of time-delineated sequences of vectors will be stored as 3-tensors, where the first dimension spans time $1\leq t \leq T$, the second dimension spans instances $1\leq i \leq n_d$ (these will almost always correspond to an individual or participant), and the third dimension spans the components of each state or observation vector (and so will have dimension either $d$ or $\ell$). We accommodate trajectories of differing lengths by standardising to the longest available trajectory in a dataset and appending np.nan's to shorter trajectories.
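The storage convention above (time × instance × component, with NaN padding for shorter trajectories) can be illustrated with a minimal sketch. The helper name `pad_trajectories` is hypothetical and not part of the published package. It assumes each trajectory arrives as a NumPy array of shape (T_i, d).

```python
import numpy as np


def pad_trajectories(trajectories):
    """Stack variable-length trajectories into a (T, n_d, d) tensor.

    Hypothetical helper: `trajectories` is a list of arrays of shape (T_i, d).
    Shorter trajectories are padded with np.nan up to T = max_i T_i, so the
    first axis spans time, the second spans instances, the third spans components.
    """
    n_d = len(trajectories)
    d = trajectories[0].shape[1]
    T = max(traj.shape[0] for traj in trajectories)
    tensor = np.full((T, n_d, d), np.nan)
    for i, traj in enumerate(trajectories):
        tensor[: traj.shape[0], i, :] = traj
    return tensor


# Two participants with d = 2: one observed at 3 visits, one at only 2.
z = pad_trajectories([
    np.array([[0.1, 1.0], [0.2, 1.1], [0.3, 1.2]]),
    np.array([[0.5, 0.9], [0.4, 0.8]]),
])
print(z.shape)  # (3, 2, 2)
print(np.isnan(z[2, 1]).all())  # True: the second participant has no third visit
```

The same layout applies to the measurements $x_{1:T}^{i}$, with $\ell$ in place of $d$ on the last axis.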

Model specification

We adopt a mixture of state space models for the data:

$$ p(z^i_{1:T}, x^i_{1:T}) = \sum_{c=1}^{n_c} \pi_{c} \delta_{\{c=c^i\}} \big( p(z_1^i| c) \prod_{t=2}^T p(z_t^i | z_{t-1}^i, c) \prod_{t=1}^T p(x_t^i | z_t^i, c) \big) $$

Each individual $i$ is independently assigned to some cluster $c^i$ with probability $\pi_{c}$, and then conditional on this cluster assignment, their initial state $z_1^i$ is drawn according to $p(z_1^i| c)$, with each subsequent state $z_t^i, 2\leq t \leq T$ being drawn in turn using the cluster-specific state model $p(z_t^i | z_{t-1}^i, c)$, depending on the previous state. At each point in time, we obtain an observation $x_t^i$ from the cluster-specific measurement model $p(x_t^i | z_t^i, c)$, depending on the current state. In what follows, we assume both the state and measurement models are stationary for each cluster, i.e. they are independent of $t$. In particular, for a given individual, the relationship between the state and measurement should not change over time.

In our main framework, we additionally assume that the cluster-specific state initialisation is Gaussian, i.e. $p(z_1^i| c) = \eta_d(z_1^i; m_c, S_c)$, and that the cluster-specific state and measurement models are linear Gaussian, i.e. $p(z_t^i | z_{t-1}^i, c) = \eta_d(z_t^i; z_{t-1}^iA_c, \Gamma_c)$ and $p(x_t^i | z_t^i, c) = \eta_\ell(x_t^i; z_t^iH_c, \Lambda_c)$. Here states and observations are treated as row vectors, and $\eta_d(\cdot; \mu, \Sigma)$ denotes the $d$-dimensional multivariate Gaussian density with mean $\mu$ and covariance $\Sigma$. This yields:

$$ p(z^i_{1:T}, x^i_{1:T}) = \sum_{c=1}^{n_c} \pi_{c} \delta_{\{c=c^i\}} \big( \eta_d(z_1^i; m_c, S_c) \prod_{t=2}^T \eta_d(z_t^i; z_{t-1}^iA_c, \Gamma_c) \prod_{t=1}^T \eta_\ell(x_t^i; z_t^iH_c, \Lambda_c) \big). $$
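The generative process for the linear Gaussian case can be made concrete with a minimal sampling sketch. It is not code from the repository. The function name `sample_mixture_lgssm` and the parameter dictionary keys (`m`, `S`, `A`, `Gamma`, `H`, `Lambda`, mirroring the symbols above) are hypothetical. It follows the row-vector convention, so the next state mean is `z @ A`.

```python
import numpy as np


def sample_mixture_lgssm(n_d, T, pi, params, rng):
    """Draw n_d trajectories from a mixture of linear Gaussian state space models.

    Hypothetical helper: params[c] is a dict with keys m, S (initial state),
    A, Gamma (state model) and H, Lambda (measurement model). States and
    observations are row vectors: z_t ~ N(z_{t-1} A_c, Gamma_c) and
    x_t ~ N(z_t H_c, Lambda_c). Returns z (T, n_d, d), x (T, n_d, ell), labels.
    """
    d = params[0]["A"].shape[0]
    ell = params[0]["H"].shape[1]
    labels = rng.choice(len(pi), size=n_d, p=pi)  # cluster c^i drawn with prob. pi_c
    z = np.empty((T, n_d, d))
    x = np.empty((T, n_d, ell))
    for i, c in enumerate(labels):
        p = params[c]
        z[0, i] = rng.multivariate_normal(p["m"], p["S"])  # initial state
        for t in range(1, T):  # stationary, cluster-specific state model
            z[t, i] = rng.multivariate_normal(z[t - 1, i] @ p["A"], p["Gamma"])
        for t in range(T):  # cluster-specific measurement model
            x[t, i] = rng.multivariate_normal(z[t, i] @ p["H"], p["Lambda"])
    return z, x, labels


rng = np.random.default_rng(0)
I2 = np.eye(2)
params = [
    dict(m=np.zeros(2), S=I2, A=0.9 * I2, Gamma=0.1 * I2,
         H=np.ones((2, 3)), Lambda=0.1 * np.eye(3)),
    dict(m=np.ones(2), S=I2, A=-0.5 * I2, Gamma=0.1 * I2,
         H=np.eye(2, 3), Lambda=0.1 * np.eye(3)),
]
z, x, labels = sample_mixture_lgssm(50, 4, [0.3, 0.7], params, rng)
```

The outputs use the same time × instance × component layout described under Data format.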

In particular, we assume that the variables we are modeling are continuous and changing over time. When we train a model like the above, we take a dataset $\mathcal{D}$ and an arbitrary initial set of cluster assignments $c^i$ (as these are also latent, i.e. hidden from us) and iteratively alternate E and M steps (from which EM gets its name):

  • (E) Expectation step: given the current model, we assign each data instance $(z^i_{1:T}, x^i_{1:T})$ to the cluster to which it most likely belongs under the current model
  • (M) Maximization step: given the current cluster assignments, we compute the sample-level cluster assignment probabilities (the $\pi_c$) and optimal cluster-specific parameters

Optimization completes after a fixed (large) number of steps or when no data instances change their cluster assignment at a given iteration.
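The hard-assignment EM loop described above can be sketched in a few dozen lines. This is an illustrative sketch, not the package's implementation, and it makes several simplifying assumptions:

- All states and measurements are fully observed, so NaN padding and the marginalisation used by the actual code are not handled.
- The M step uses least squares, which gives the maximum likelihood estimates for a linear Gaussian model when inputs and targets are both observed.
- The E step assigns each instance to the cluster maximising $\log \pi_c + \log p(z^i_{1:T}, x^i_{1:T} \mid c)$.
- Training stops with an error if any cluster falls below 3 members, mirroring the behaviour noted under Caveats.

All function names are hypothetical.

```python
import numpy as np
from scipy.stats import multivariate_normal


def fit_linear_gaussian(inputs, targets):
    """Least-squares fit of targets ~ inputs @ W; returns W and residual covariance."""
    W, *_ = np.linalg.lstsq(inputs, targets, rcond=None)
    resid = targets - inputs @ W
    return W, np.cov(resid, rowvar=False, bias=True)


def m_step(z, x, labels, n_c):
    """Estimate cluster weights pi_c and per-cluster parameters from hard labels."""
    pi = np.bincount(labels, minlength=n_c) / labels.size
    params = []
    for c in range(n_c):
        zc, xc = z[:, labels == c], x[:, labels == c]  # (T, members, d), (T, members, ell)
        d, ell = zc.shape[-1], xc.shape[-1]
        # State model: regress z_t on z_{t-1}, pooling all time steps (stationarity).
        A, Gamma = fit_linear_gaussian(zc[:-1].reshape(-1, d), zc[1:].reshape(-1, d))
        # Measurement model: regress x_t on z_t.
        H, Lam = fit_linear_gaussian(zc.reshape(-1, d), xc.reshape(-1, ell))
        m, S = zc[0].mean(axis=0), np.cov(zc[0], rowvar=False, bias=True)
        params.append(dict(m=m, S=S, A=A, Gamma=Gamma, H=H, Lambda=Lam))
    return pi, params


def log_lik(z, x, p):
    """Complete-data log likelihood of every instance under one cluster; shape (n_d,)."""
    ll = multivariate_normal.logpdf(z[0], p["m"], p["S"])
    for t in range(1, z.shape[0]):
        ll = ll + multivariate_normal.logpdf(z[t] - z[t - 1] @ p["A"], cov=p["Gamma"])
    for t in range(z.shape[0]):
        ll = ll + multivariate_normal.logpdf(x[t] - z[t] @ p["H"], cov=p["Lambda"])
    return ll


def hard_em(z, x, n_c, max_iter=100, min_size=3, rng=None):
    """Hard-assignment EM for a mixture of linear Gaussian state space models."""
    rng = np.random.default_rng(rng)
    labels = rng.integers(n_c, size=z.shape[1])  # arbitrary initial assignments
    for _ in range(max_iter):
        if np.bincount(labels, minlength=n_c).min() < min_size:
            raise RuntimeError("a cluster has fewer than min_size members")
        pi, params = m_step(z, x, labels, n_c)  # M step
        scores = np.stack([np.log(pi[c]) + log_lik(z, x, params[c]) for c in range(n_c)])
        new_labels = scores.argmax(axis=0)  # E step: most likely cluster per instance
        if np.array_equal(new_labels, labels):
            break  # no instance changed its cluster assignment
        labels = new_labels
    # If max_iter is reached without convergence, params lag labels by one step.
    return labels, pi, params
```

A singular covariance estimate (for example, from a feature that is constant within a cluster) makes `multivariate_normal.logpdf` raise an error here. This is a simple form of the numerical instability described under Caveats.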

The approach (in framework) uses mixtures of the above model formulation for training and inference. A Bayesian version of this approach was pioneered by Chiappa and Barber[^2]. The EM[^3] approach we take, with hard cluster assignment, allows us to extend the above model to nonlinear specifications (see framework_extended), with the tradeoff that certain types of model marginalisation are currently unsupported.

To run the code

Model training and predictions are done in Python. Statistical tests on the results are mostly performed in R (with survival models being a notable exception). An environment suitable for running this code can be created using either:

  1. a conda environment as described in environment.yml and an renv environment as in renv.lock

```sh
conda env create -f environment.yml
conda activate hydrangea
R -e "install.packages('remotes', repos = c(CRAN = 'https://cloud.r-project.org'))"
R -e "remotes::install_github('rstudio/renv@v1.0.0')"
R -e "renv::restore()"
```

or

  2. docker

```sh
docker build -t thistly-cross .
```

After choosing one of the above, on a machine with Make installed, and with all required data in the data folder, one can then

```sh
make all
```

to reproduce the figures and analyses related to ADNI. (Due to patient privacy, we provide the code used to produce the results for MACC here, but not the data.) If using docker, uncomment line #9 of the Makefile to define docker prior to running.

Adapting the code for your own use

We provide a publicly available package, unsupervised-multimodal-trajectory-modeling, on PyPI. A template repository containing starter code is available here: https://github.com/burkh4rt/Unsupervised-Trajectory-Clustering-Starter

Caveats & Troubleshooting

Some efforts have been made to automatically handle edge cases. For a given training run, if any cluster becomes too small (fewer than 3 members), training terminates. In order to learn a model, we make assumptions about our training data as described above. While our approach seems to be robust to some types of model misspecification, we have encountered training issues with the following problems:

  1. Extreme outliers. An extreme outlier tends to form its own cluster, which is problematic. In many cases this may be due to a typo or failed data cleaning (i.e. an upstream problem). Generating histograms of each feature is one way to recognize this problem.
  2. Discrete / static features. Including discrete data violates our Gaussian assumptions. If we learn a cluster where each trajectory has the same value for one of the states or observations at a given time step, then we are prone to estimating a singular covariance structure for this cluster which yields numerical instabilities. Adding a small bit of noise to discrete features may remediate numerical instability to some extent.
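Both remedies above can be sketched for data stored in the (T, n_d, d) layout. The README suggests histograms for spotting outliers; as a swapped-in, automatable alternative, this sketch flags entries that lie many median absolute deviations (MADs) from each feature's median. The second helper adds small Gaussian noise to features with only a few distinct values. The function names and the defaults `threshold`, `max_levels` and `scale` are hypothetical and should be tuned to the scale of your own features.

```python
import numpy as np


def flag_outliers(tensor, threshold=6.0):
    """Flag entries more than `threshold` MADs from their feature's median.

    tensor has shape (T, n_d, d) and may contain np.nan padding; the result is
    a boolean array of the same shape, and NaN entries are never flagged.
    """
    flat = tensor.reshape(-1, tensor.shape[-1])
    median = np.nanmedian(flat, axis=0)
    mad = np.nanmedian(np.abs(flat - median), axis=0)
    mad = np.where(mad > 0, mad, 1.0)  # avoid dividing by zero for (near-)constant features
    with np.errstate(invalid="ignore"):
        return np.abs(tensor - median) / mad > threshold


def jitter_discrete(tensor, max_levels=5, scale=1e-2, rng=None):
    """Add small Gaussian noise to features with at most `max_levels` distinct values."""
    rng = np.random.default_rng(rng)
    out = tensor.copy()
    for k in range(tensor.shape[-1]):
        values = tensor[..., k]
        levels = np.unique(values[~np.isnan(values)])
        if levels.size <= max_levels:
            out[..., k] += scale * rng.standard_normal(values.shape)  # NaN stays NaN
    return out


# Feature 0 is continuous with one corrupted entry; feature 1 is binary.
data = np.zeros((3, 4, 2))
data[..., 0] = np.linspace(0.0, 1.0, 12).reshape(3, 4)
data[1, 2, 0] = 100.0  # e.g. a data-entry typo
data[..., 1] = np.tile([0.0, 1.0], 6).reshape(3, 4)
data[2, 3, :] = np.nan  # padding for a shorter trajectory
flags = flag_outliers(data)
jittered = jitter_discrete(data, rng=0)
```

Flagged entries are worth tracing back to their source before training, since a corrected typo is preferable to a dropped instance.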

[^1]: M. Burkhart, L. Lee, D. Vaghari, A. Toh, E. Chong, C. Chen, P. Tiňo, & Z. Kourtzi, Unsupervised multimodal modeling of cognitive and brain health trajectories for early dementia prediction, Scientific Reports 14 (2024)

[^2]: S. Chiappa & D. Barber, Dirichlet Mixtures of Bayesian Linear Gaussian State-Space Models: a Variational Approach, Tech. rep. 161, Max Planck Institute for Biological Cybernetics, 2007

[^3]: A. Dempster, N. Laird, & D. Rubin, Maximum Likelihood from Incomplete Data via the EM Algorithm, Journal of the Royal Statistical Society: Series B (Methodological) 39 (1977)

Owner

  • Name: Michael Burkhart
  • Login: burkh4rt
  • Kind: user
  • Company: University of Cambridge

research associate—machine learning for neuroscience

Citation (CITATION.cff)

cff-version: 1.2.0
message: "Please cite the following work when using this software."
license:
  - MIT
preferred-citation:
  type: article
  authors:
    - family-names: "Burkhart"
      given-names: "Michael C."
      orcid: "https://orcid.org/0000-0002-2772-5840"
    - family-names: "Lee"
      given-names: "Liz Y."
    - family-names: "Vaghari"
      given-names: "Delshad"
    - family-names: "Toh"
      given-names: "An Qi"
    - family-names: "Chong"
      given-names: "Eddie"
    - family-names: "Chen"
      given-names: "Christopher"
    - family-names: "Tiňo"
      given-names: "Peter"
    - family-names: "Kourtzi"
      given-names: "Zoe"
  doi: "10.1038/s41598-024-60914-w"
  journal: "Scientific Reports"
  title:
    "Unsupervised multimodal modeling of cognitive and brain health
    trajectories for early dementia prediction"
  volume: 14
  year: 2024


Dependencies

Dockerfile docker
  • ubuntu 22.04 build
pyproject.toml pypi
requirements-docker.txt pypi
  • Babel ==2.14.0
  • Bottleneck ==1.3.8
  • Jinja2 ==3.1.3
  • MarkupSafe ==2.1.5
  • PyQt5-sip ==12.13.0
  • PySocks ==1.7.1
  • PyYAML ==6.0.1
  • Pygments ==2.17.2
  • QtPy ==2.4.1
  • Send2Trash ==1.8.3
  • aiofiles ==23.2.1
  • aiosqlite ==0.20.0
  • anyio ==4.3.0
  • appnope ==0.1.4
  • argon2-cffi ==23.1.0
  • argon2-cffi-bindings ==21.2.0
  • astor ==0.8.1
  • asttokens ==2.4.1
  • async-lru ==2.0.4
  • attrs ==23.2.0
  • autograd ==1.6.2
  • autograd-gamma ==0.5.0
  • backcall ==0.2.0
  • beautifulsoup4 ==4.12.3
  • black ==24.4.2
  • bleach ==6.1.0
  • brotlipy ==0.7.0
  • certifi ==2024.2.2
  • cffi ==1.16.0
  • charset-normalizer ==3.3.2
  • click ==8.1.7
  • comm ==0.2.2
  • contourpy ==1.2.1
  • cryptography ==42.0.5
  • cycler ==0.12.1
  • debugpy ==1.8.1
  • decorator ==5.1.1
  • defusedxml ==0.7.1
  • dill ==0.3.8
  • ecos ==2.0.13
  • entrypoints ==0.4
  • et-xmlfile ==1.1.0
  • executing ==2.0.1
  • fastjsonschema ==2.19.1
  • fonttools ==4.51.0
  • formulaic ==1.0.1
  • future ==1.0.0
  • idna ==3.7
  • importlib_metadata ==7.1.0
  • importlib_resources ==6.4.0
  • interface-meta ==1.3.0
  • ipykernel ==6.29.4
  • ipython ==8.24.0
  • ipython-genutils ==0.2.0
  • ipywidgets ==8.1.2
  • isort ==5.13.2
  • jedi ==0.19.1
  • joblib ==1.4.0
  • json5 ==0.9.25
  • jsonschema ==4.21.1
  • jsonschema-specifications ==2023.12.1
  • jupyter_client ==8.6.1
  • jupyter_core ==5.7.2
  • jupyterlab_widgets ==3.0.10
  • kiwisolver ==1.4.5
  • lifelines ==0.28.0
  • llvmlite ==0.42.0
  • lxml ==5.2.1
  • matplotlib ==3.8.4
  • matplotlib-inline ==0.1.7
  • mistune ==3.0.2
  • mpld3 ==0.5.10
  • multiprocess ==0.70.16
  • munkres ==1.1.4
  • mypy-extensions ==1.0.0
  • nest-asyncio ==1.6.0
  • numba ==0.59.1
  • numexpr ==2.10.0
  • numpy ==1.26.4
  • openpyxl ==3.1.2
  • osqp ==0.6.5
  • packaging ==24.0
  • pandas ==2.2.2
  • pandocfilters ==1.5.1
  • parso ==0.8.4
  • pathspec ==0.12.1
  • patsy ==0.5.6
  • pexpect ==4.9.0
  • pickleshare ==0.7.5
  • pillow ==10.3.0
  • pip ==24.0
  • platformdirs ==4.2.1
  • ply ==3.11
  • pox ==0.3.4
  • ppft ==1.7.6.8
  • prometheus_client ==0.20.0
  • prompt-toolkit ==3.0.43
  • psutil ==5.9.8
  • ptyprocess ==0.7.0
  • pure-eval ==0.2.2
  • pyOpenSSL ==24.1.0
  • pybind11 ==2.12.0
  • pycparser ==2.22
  • pyparsing ==3.1.2
  • pyrsistent ==0.20.0
  • python-dateutil ==2.9.0.post0
  • python-json-logger ==2.0.7
  • pytz ==2024.1
  • pyzmq ==26.0.2
  • qdldl ==0.1.7.post2
  • qtconsole ==5.5.1
  • referencing ==0.35.0
  • requests ==2.31.0
  • rfc3339-validator ==0.1.4
  • rfc3986-validator ==0.1.1
  • rpds-py ==0.18.0
  • scikit-learn ==1.4.2
  • scipy ==1.11.4
  • seaborn ==0.13.2
  • setuptools ==69.5.1
  • sip ==6.8.3
  • six ==1.16.0
  • sniffio ==1.3.1
  • soupsieve ==2.5
  • stack-data ==0.6.3
  • statsmodels ==0.14.2
  • tabulate ==0.9.0
  • terminado ==0.18.1
  • threadpoolctl ==3.4.0
  • tinycss2 ==1.3.0
  • toml ==0.10.2
  • tomli ==2.0.1
  • tornado ==6.4
  • tqdm ==4.66.2
  • traitlets ==5.14.3
  • typing_extensions ==4.11.0
  • tzdata ==2024.1
  • urllib3 ==2.2.1
  • wcwidth ==0.2.13
  • webencodings ==0.5.1
  • websocket-client ==1.8.0
  • wheel ==0.43.0
  • widgetsnbextension ==4.0.10
  • wrapt ==1.16.0
  • y-py ==0.6.2
  • ypy-websocket ==0.12.4
  • zipp ==3.18.1
environment.yml conda
  • aiofiles 22.1.0.*
  • aiosqlite 0.18.0.*
  • anyio 3.5.0.*
  • appnope 0.1.2.*
  • argon2-cffi 21.3.0.*
  • argon2-cffi-bindings 21.2.0.*
  • astor 0.8.1.*
  • asttokens 2.0.5.*
  • async-lru 2.0.4.*
  • attrs 23.1.0.*
  • autograd-gamma 0.5.0.*
  • babel 2.11.0.*
  • backcall 0.2.0.*
  • beautifulsoup4 4.12.2.*
  • black 23.11.0.*
  • blas 1.0.*
  • bleach 4.1.0.*
  • bottleneck 1.3.5.*
  • brotli 1.0.9.*
  • brotli-bin 1.0.9.*
  • brotlipy 0.7.0.*
  • bzip2 1.0.8.*
  • ca-certificates 2023.12.12.*
  • certifi 2024.2.2.*
  • cffi 1.15.1.*
  • charset-normalizer 2.0.4.*
  • click 8.1.7.*
  • comm 0.1.2.*
  • contourpy 1.0.5.*
  • cryptography 41.0.2.*
  • cycler 0.11.0.*
  • cyrus-sasl 2.1.28.*
  • debugpy 1.6.7.*
  • decorator 5.1.1.*
  • defusedxml 0.7.1.*
  • entrypoints 0.4.*
  • et_xmlfile 1.1.0.*
  • executing 0.8.3.*
  • fftw 3.3.9.*
  • fonttools 4.25.0.*
  • freetype 2.12.1.*
  • gettext 0.21.0.*
  • giflib 5.2.1.*
  • glib 2.69.1.*
  • gst-plugins-base 1.14.1.*
  • gstreamer 1.14.1.*
  • icu 58.2.*
  • idna 3.4.*
  • importlib-metadata 6.0.0.*
  • importlib_metadata 6.0.0.*
  • importlib_resources 5.2.0.*
  • intel-openmp 2021.4.0.*
  • interface_meta 1.3.0.*
  • ipykernel 6.25.0.*
  • ipython 8.3.0.*
  • ipython_genutils 0.2.0.*
  • ipywidgets 8.0.4.*
  • jedi 0.18.1.*
  • jinja2 3.1.2.*
  • joblib 1.2.0.*
  • jpeg 9e.*
  • json5 0.9.6.*
  • jsonschema 4.17.3.*
  • jupyter 1.0.0.*
  • jupyter-lsp 2.2.0.*
  • jupyter_client 7.4.9.*
  • jupyter_console 6.6.3.*
  • jupyter_core 5.3.0.*
  • jupyter_events 0.6.3.*
  • jupyter_server 1.23.4.*
  • jupyter_server_fileid 0.9.0.*
  • jupyter_server_terminals 0.4.4.*
  • jupyter_server_ydoc 0.8.0.*
  • jupyter_ydoc 0.2.4.*
  • jupyterlab 3.6.3.*
  • jupyterlab_pygments 0.1.2.*
  • jupyterlab_server 2.24.0.*
  • jupyterlab_widgets 3.0.5.*
  • kiwisolver 1.4.4.*
  • krb5 1.20.1.*
  • lcms2 2.12.*
  • lerc 3.0.*
  • libbrotlicommon 1.0.9.*
  • libbrotlidec 1.0.9.*
  • libbrotlienc 1.0.9.*
  • libclang 14.0.6.*
  • libclang13 14.0.6.*
  • libcxx 14.0.6.*
  • libdeflate 1.17.*
  • libedit 3.1.20221030.*
  • libffi 3.4.4.*
  • libgfortran 5.0.0.*
  • libgfortran5 11.3.0.*
  • libiconv 1.16.*
  • libllvm11 11.1.0.*
  • libllvm12 12.0.0.*
  • libllvm14 14.0.6.*
  • libpng 1.6.39.*
  • libpq 12.15.*
  • libsodium 1.0.18.*
  • libtiff 4.5.1.*
  • libwebp 1.3.2.*
  • libwebp-base 1.3.2.*
  • libxml2 2.10.4.*
  • libxslt 1.1.37.*
  • llvm-openmp 14.0.6.*
  • lxml 4.9.3.*
  • lz4-c 1.9.4.*
  • make 4.2.1.*
  • markupsafe 2.1.1.*
  • matplotlib 3.7.1.*
  • matplotlib-base 3.7.1.*
  • matplotlib-inline 0.1.6.*
  • mistune 0.8.4.*
  • mkl 2021.4.0.*
  • mkl-service 2.4.0.*
  • mkl_fft 1.3.1.*
  • mkl_random 1.2.2.*
  • munkres 1.1.4.*
  • mypy_extensions 1.0.0.*
  • mysql 5.7.24.*
  • nbclassic 0.5.5.*
  • nbclient 0.5.13.*
  • nbconvert 6.5.4.*
  • nbformat 5.9.2.*
  • ncurses 6.4.*
  • nest-asyncio 1.5.6.*
  • notebook 6.5.4.*
  • notebook-shim 0.2.2.*
  • nspr 4.35.*
  • nss 3.89.1.*
  • numexpr 2.8.4.*
  • numpy-base 1.24.3.*
  • openjpeg 2.4.0.*
  • openpyxl 3.0.9.*
  • openssl 3.0.13.*
  • packaging 23.1.*
  • pandas 1.4.2.*
  • pandocfilters 1.5.0.*
  • parso 0.8.3.*
  • pathspec 0.10.3.*
  • patsy 0.5.3.*
  • pcre 8.45.*
  • pexpect 4.8.0.*
  • pickleshare 0.7.5.*
  • pillow 10.0.1.*
  • pip 23.1.2.*
  • platformdirs 3.10.0.*
  • ply 3.11.*
  • prometheus_client 0.14.1.*
  • prompt-toolkit 3.0.36.*
  • prompt_toolkit 3.0.36.*
  • psutil 5.9.0.*
  • ptyprocess 0.7.0.*
  • pure_eval 0.2.2.*
  • pycparser 2.21.*
  • pygments 2.15.1.*
  • pyopenssl 23.2.0.*
  • pyparsing 3.0.9.*
  • pyqt 5.15.7.*
  • pyrsistent 0.18.0.*
  • pysocks 1.7.1.*
  • python 3.10.11.*
  • python-dateutil 2.8.2.*
  • python-fastjsonschema 2.16.2.*
  • python-json-logger 2.0.7.*
  • pytz 2023.3.post1.*
  • pyyaml 6.0.*
  • pyzmq 23.2.0.*
  • qt-main 5.15.2.*
  • qt-webengine 5.15.9.*
  • qtconsole 5.4.2.*
  • qtpy 2.2.0.*
  • qtwebkit 5.212.*
  • readline 8.2.*
  • requests 2.31.0.*
  • rfc3339-validator 0.1.4.*
  • rfc3986-validator 0.1.1.*
  • scipy 1.9.3.*
  • seaborn 0.11.2.*
  • send2trash 1.8.0.*
  • setuptools 68.0.0.*
  • sip 6.6.2.*
  • six 1.16.0.*
  • sniffio 1.2.0.*
  • soupsieve 2.5.*
  • sqlite 3.41.2.*
  • stack_data 0.2.0.*
  • statsmodels 0.13.2.*
  • tabulate 0.8.9.*
  • tbb 2021.8.0.*
  • terminado 0.17.1.*
  • threadpoolctl 2.2.0.*
  • tinycss2 1.2.1.*
  • tk 8.6.12.*
  • toml 0.10.2.*
  • tomli 2.0.1.*
  • tornado 6.3.3.*
  • traitlets 5.7.1.*
  • typing-extensions 4.7.1.*
  • typing_extensions 4.7.1.*
  • tzdata 2023c.*
  • urllib3 1.26.16.*
  • wcwidth 0.2.5.*
  • webencodings 0.5.1.*
  • websocket-client 0.58.0.*
  • wheel 0.41.2.*
  • widgetsnbextension 4.0.5.*
  • xz 5.4.2.*
  • y-py 0.5.9.*
  • yaml 0.2.5.*
  • ypy-websocket 0.8.2.*
  • zeromq 4.3.4.*
  • zipp 3.11.0.*
  • zlib 1.2.13.*
  • zstd 1.5.5.*