cito

Building and Training Neural Networks with an R interface

https://github.com/citoverse/cito

Science Score: 49.0%

This score indicates how likely this project is to be science-related based on various indicators:

  • CITATION.cff file
  • codemeta.json file
    Found codemeta.json file
  • .zenodo.json file
    Found .zenodo.json file
  • DOI references
    Found 3 DOI reference(s) in README
  • Academic publication links
  • Committers with academic emails
    2 of 4 committers (50.0%) from academic institutions
  • Institutional organization owner
  • JOSS paper metadata
  • Scientific vocabulary similarity
    Low similarity (20.2%) to scientific vocabulary

Keywords

machine-learning neural-network r r-package
Last synced: 7 months ago

Repository

Statistics
  • Stars: 48
  • Watchers: 3
  • Forks: 6
  • Open Issues: 50
  • Releases: 1
Created about 4 years ago · Last pushed 8 months ago
Metadata Files
Readme Changelog License

README.Rmd

---
output: github_document
---



```{r, include = FALSE, echo=FALSE}
options("progress_enabled" = TRUE)
knitr::opts_chunk$set(fig.path = "man/figures/README-", collapse=TRUE)
```

# cito

[![Project Status: Active -- The project has reached a stable, usable state and is being actively developed.](https://www.repostatus.org/badges/latest/active.svg)](https://www.repostatus.org/#active) [![License: GPL v3](https://img.shields.io/badge/License-GPL%20v3-blue.svg)](https://www.gnu.org/licenses/gpl-3.0) [![CRAN_Status_Badge](http://www.r-pkg.org/badges/version/cito)](https://cran.r-project.org/package=cito) [![R-CMD-check](https://github.com/citoverse/cito/actions/workflows/R-CMD-check.yaml/badge.svg)](https://github.com/citoverse/cito/actions/workflows/R-CMD-check.yaml)
[![Publication](https://img.shields.io/badge/Publication-10.1111/ecog.07143-green.svg)](https://doi.org/10.1111/ecog.07143)



The 'cito' package provides a user-friendly interface for training and interpreting deep neural networks (DNNs). 'cito' simplifies the fitting of DNNs by supporting the familiar formula syntax and hyperparameter tuning under cross-validation, and by helping to detect and handle convergence problems. DNNs can be trained on CPUs, GPUs, and macOS GPUs. In addition, 'cito' offers many downstream functionalities, such as various explainable AI (xAI) metrics (e.g. variable importance, partial dependence plots, accumulated local effect plots, and effect estimates) for interpreting trained DNNs. 'cito' optionally provides confidence intervals (and p-values) for all xAI metrics and predictions. At the same time, 'cito' is computationally efficient because it is based on the deep learning framework 'torch'. The 'torch' package is native to R, so no Python installation or other API is required for this package.


## Installation

Before installing 'cito', make sure that 'torch' is installed. See the code chunk below if you are unsure how to check this.

```{r, eval = FALSE}
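# Install the 'torch' R package if it is missing, then load it
if(!require('torch', quietly = TRUE)) install.packages('torch')
library('torch')

# Download and install the torch backend libraries if they are missing
if(!torch_is_installed()) install_torch()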
```

If you have trouble installing 'torch', please [visit the website of the 'torch' package](https://torch.mlverse.org/docs/articles/installation.html) or create an issue in [our GitHub issue tracker](https://github.com/citoverse/cito/issues). We are happy to help you.

A stable version of cito from CRAN can be installed with:

```{r, eval=FALSE}
install.packages("cito")

```

The development version from [GitHub](https://github.com/) can be installed by:

```{r, eval=FALSE}
if(!require('devtools', quietly = TRUE)) install.packages('devtools')
devtools::install_github('citoverse/cito')
```

## Example

Once installed, the main function `dnn()` can be used, as in the example below. A more in-depth explanation can be found in the vignettes or [under Articles on the package website](https://citoverse.github.io/cito/).

1.  Fit the model with bootstrapping (to obtain confidence intervals). All methods work with and without bootstrapping.

```{r}
library(cito)
nn.fit <- dnn(Sepal.Length~., data = datasets::iris, bootstrap = 30L)
```

2.  Check whether the models have converged by comparing the training loss against the baseline loss (the loss of an intercept-only model):

```{r, eval=FALSE}
analyze_training(nn.fit)
# At first glance, the networks have converged: the training loss is lower than the baseline loss and has reached a plateau by the end of training.
```

3.  Plot model architecture

```{r}
plot(nn.fit)
```

4.  'cito' supports many advanced functionalities, such as common explainable AI metrics that can be used for inference (i.e. to interpret the models). Variable importance (similar to variation partitioning) and linear effects are returned directly by the `summary()` function:

```{r}
summary(nn.fit)
```

5.  Predict (with confidence intervals):

```{r}
dim(predict(nn.fit, newdata = datasets::iris))
```
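The baseline loss used in the convergence check of step 2 can be reproduced by hand. The following is a minimal base-R sketch, assuming a mean squared error loss; the linear model is only a stand-in for a trained network and is not part of 'cito'.

```r
# Baseline (intercept-only) loss under mean squared error, in base R.
y <- datasets::iris$Sepal.Length
baseline_pred <- mean(y)                      # an intercept-only model predicts the mean
baseline_loss <- mean((y - baseline_pred)^2)

# Stand-in for a trained model (assumption): an ordinary linear model
# on all remaining columns of iris.
lm_fit <- lm(Sepal.Length ~ ., data = datasets::iris)
model_loss <- mean(residuals(lm_fit)^2)

# A model that has learned something should beat the baseline.
model_loss < baseline_loss
```

A training loss that stays at or above `baseline_loss` suggests that the model has not learned anything beyond the mean of the response.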

### Hyperparameter tuning

Certain arguments/parameters, such as the architecture, the activation function, and the learning rate, can be tuned automatically under cross-validation (for a full list, see `?dnn`). Parameters that should be tuned can be flagged by using the function `tune()` instead of a hyperparameter value:

```{r}
nn.fit <- dnn(Sepal.Length~., data = datasets::iris, lr = tune(0.0001, 0.1))
nn.fit$tuning
```

The tuning can be configured with `tuning=config_tuning()`. After tuning, a final model trained with the best hyperparameters is returned. Hyperparameter combinations that do not achieve a loss below the baseline loss will be aborted early and not fully cross-validated. These runs are given a test loss of infinity.
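The early-abort rule can be illustrated without 'cito'. The sketch below is not the package's implementation: it assumes a mean squared error loss, uses a hypothetical helper `fit_gd()` (linear regression fitted by plain gradient descent) as a stand-in for network training, and picks an arbitrary set of candidate learning rates. A candidate whose test loss is not below the intercept-only baseline in any fold is abandoned and given a loss of infinity.

```r
set.seed(1)
X <- scale(as.matrix(datasets::iris[, 2:4]))      # standardized predictors
y <- datasets::iris$Sepal.Length
folds <- sample(rep(1:5, length.out = nrow(X)))   # 5-fold assignment

# Hypothetical stand-in for model training: linear regression fitted by
# plain gradient descent with learning rate `lr`.
fit_gd <- function(X, y, lr, steps = 200) {
  Xd <- cbind(1, X)
  w <- rep(0, ncol(Xd))
  for (i in seq_len(steps)) {
    grad <- 2 * crossprod(Xd, Xd %*% w - y) / nrow(Xd)
    w <- w - lr * as.vector(grad)
  }
  w
}

cv_loss <- function(lr) {
  losses <- numeric(5)
  for (k in 1:5) {
    train <- folds != k
    w <- fit_gd(X[train, , drop = FALSE], y[train], lr)
    pred <- as.vector(cbind(1, X[!train, , drop = FALSE]) %*% w)
    loss <- mean((y[!train] - pred)^2)
    baseline <- mean((y[!train] - mean(y[train]))^2)  # intercept-only loss
    # Abort early: skip the remaining folds and report an infinite loss.
    if (!is.finite(loss) || loss >= baseline) return(Inf)
    losses[k] <- loss
  }
  mean(losses)
}

candidates <- c(1e-4, 1e-2, 1e-1, 5)
results <- vapply(candidates, cv_loss, numeric(1))
best_lr <- candidates[which.min(results)]
```

In this sketch, the smallest learning rate barely moves away from the starting values and the largest one diverges, so both are discarded after the first fold; only the remaining candidates are fully cross-validated.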

## Advanced

We can pass custom loss functions to 'cito', optionally with additional parameters that should be fitted. The only requirement is that all calculations must be written using the 'torch' package (cito automatically converts the initial values of the custom parameters to 'torch' objects).

In this example, we use a multivariate normal distribution as the likelihood function, and we want to parameterize/fit the covariance matrix of the multivariate normal distribution:

1.  We need one helper function, `create_cov()`, that builds the covariance matrix from a low-rank factor matrix and a vector of log-scale diagonal terms (a low-rank approximation of the covariance matrix).

2.  We need our custom likelihood function which uses the `distr_multivariate_normal(…)` function from the torch package:

```{r}
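# Build the covariance matrix: L %*% t(L) is the low-rank part, and
# exp(Diag) + 0.001 keeps the diagonal strictly positive
create_cov = function(L, Diag) {
  return(torch::torch_matmul(L, L$t()) + torch::torch_diag(Diag$exp()+0.001))
}

# Negative mean log-likelihood under a multivariate normal distribution;
# SigmaPar and SigmaDiag are the custom parameters passed to dnn()
custom_loss_MVN = function(true, pred) {
  Sigma = create_cov(SigmaPar, SigmaDiag)
  logLik = torch::distr_multivariate_normal(pred,
                                            covariance_matrix = Sigma)$
    log_prob(true)
  return(-logLik$mean())
}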

```

3.  We use `SigmaPar` and `SigmaDiag` as parameters that we want to optimize along with the DNN. We pass a named list with starting values to 'cito', and 'cito' automatically infers the shape of the parameters from the shape of the R objects:

```{r, warning=FALSE}
nn.fit<- dnn(cbind(Sepal.Length, Sepal.Width, Petal.Length)~.,
             data = datasets::iris,
             lr = 0.01,
             epochs = 200L,
             loss = custom_loss_MVN,
             verbose = FALSE,
             plot = FALSE,
             custom_parameters =
               list(SigmaDiag =  rep(0, 3), # Log-scale diagonal terms, starting at 0
                    SigmaPar = matrix(rnorm(6, sd = 0.001), 3, 2)) # Low-rank factor, small random starting values
)
```
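The parameterization behind `create_cov()` can be checked in base R. This is a sketch with a hypothetical helper `create_cov_r()` that mirrors the torch version using ordinary matrices; the starting values are arbitrary and chosen only for illustration.

```r
# Base-R analogue of create_cov(): L %*% t(L) is positive semi-definite with
# rank <= ncol(L); the strictly positive diagonal makes the sum positive definite.
create_cov_r <- function(L, Diag) {
  L %*% t(L) + diag(exp(Diag) + 0.001)
}

set.seed(42)
L <- matrix(rnorm(6, sd = 0.5), nrow = 3, ncol = 2)  # 3 x 2 low-rank factor
Sigma <- create_cov_r(L, Diag = rep(0, 3))

isSymmetric(Sigma)                                   # a valid covariance matrix is symmetric
min(eigen(Sigma, symmetric = TRUE)$values) > 0       # and positive definite
```

Because the diagonal enters through `exp()`, the optimizer can move `SigmaDiag` freely over the real line without ever producing an invalid covariance matrix.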

Estimated covariance matrix:

```{r}
as.matrix(create_cov(nn.fit$loss$parameter$SigmaPar,
                     nn.fit$loss$parameter$SigmaDiag))
```

Empirical covariance matrix of the residuals:

```{r}
cov(predict(nn.fit) - nn.fit$data$Y)
```


Owner

  • Name: citoverse
  • Login: citoverse
  • Kind: organization

GitHub Events

Total
  • Issues event: 27
  • Watch event: 13
  • Issue comment event: 12
  • Push event: 11
  • Pull request review event: 1
  • Pull request event: 6
  • Create event: 1
Last Year
  • Issues event: 27
  • Watch event: 13
  • Issue comment event: 12
  • Push event: 11
  • Pull request review event: 1
  • Pull request event: 6
  • Create event: 1

Committers

Last synced: 10 months ago

All Time
  • Total Commits: 340
  • Total Committers: 4
  • Avg Commits per committer: 85.0
  • Development Distribution Score (DDS): 0.509
Past Year
  • Commits: 27
  • Committers: 3
  • Avg Commits per committer: 9.0
  • Development Distribution Score (DDS): 0.481
Top Committers
  • Christian Amesöder (c****r@s****e): 167 commits
  • dwjak123lkdmaKOP (M****r@b****e): 96 commits
  • Armin Schenk (a****9@g****m): 73 commits
  • Florian Hartig (f****g): 4 commits

Issues and Pull Requests

Last synced: 8 months ago

All Time
  • Total issues: 96
  • Total pull requests: 14
  • Average time to close issues: 4 months
  • Average time to close pull requests: 11 days
  • Total issue authors: 8
  • Total pull request authors: 4
  • Average comments per issue: 0.57
  • Average comments per pull request: 0.29
  • Merged pull requests: 12
  • Bot issues: 0
  • Bot pull requests: 0
Past Year
  • Issues: 23
  • Pull requests: 3
  • Average time to close issues: 5 months
  • Average time to close pull requests: 14 days
  • Issue authors: 4
  • Pull request authors: 2
  • Average comments per issue: 0.96
  • Average comments per pull request: 0.33
  • Merged pull requests: 2
  • Bot issues: 0
  • Bot pull requests: 0
Top Authors
Issue Authors
  • MaximilianPi (43)
  • florianhartig (16)
  • ArminSchenk (15)
  • ChristianAmes (8)
  • ChrKoenig (4)
  • melinakienitz (2)
  • rajoma-02 (1)
  • D-Maar (1)
  • hakimabdi (1)
  • tripartio (1)
Pull Request Authors
  • ArminSchenk (13)
  • MaximilianPi (5)
  • ChristianAmes (2)
  • florianhartig (2)
Top Labels
Issue Labels
  • enhancement (11)
  • bug (9)

Packages

  • Total packages: 1
  • Total downloads:
    • cran 354 last-month
  • Total dependent packages: 0
  • Total dependent repositories: 0
  • Total versions: 4
  • Total maintainers: 1
cran.r-project.org: cito

Building and Training Neural Networks

  • Versions: 4
  • Dependent Packages: 0
  • Dependent Repositories: 0
  • Downloads: 354 Last month
Rankings
Stargazers count: 17.0%
Forks count: 17.8%
Average: 27.0%
Dependent packages count: 29.8%
Downloads: 34.8%
Dependent repos count: 35.5%
Last synced: 8 months ago

Dependencies

.github/workflows/R-CMD-check.yaml actions
  • actions/cache v2 composite
  • actions/checkout v2 composite
  • actions/upload-artifact main composite
  • r-lib/actions/setup-pandoc v2 composite
  • r-lib/actions/setup-r v2 composite
  • r-lib/actions/setup-tinytex v2 composite
DESCRIPTION cran
  • R >= 3.5 depends
  • checkmate * imports
  • coro * imports
  • gridExtra * imports
  • torch * imports
  • ggplot2 * suggests
  • ggraph * suggests
  • igraph * suggests
  • knitr * suggests
  • plotly * suggests
  • rmarkdown * suggests
  • stats * suggests
  • testthat * suggests