rjaf

rjaf: Regularized Joint Assignment Forest with Treatment Arm Clustering - Published in JOSS (2025)

https://github.com/wustat/rjaf

Science Score: 95.0%

This score indicates how likely this project is to be science-related based on various indicators:

  • CITATION.cff file
  • codemeta.json file
    Found codemeta.json file
  • .zenodo.json file
    Found .zenodo.json file
  • DOI references
    Found 3 DOI reference(s) in README and JOSS metadata
  • Academic publication links
    Links to: joss.theoj.org
  • Committers with academic emails
    4 of 12 committers (33.3%) from academic institutions
  • Institutional organization owner
  • JOSS paper metadata
    Published in Journal of Open Source Software

Keywords

causal-inference machine-learning
Last synced: 6 months ago

Repository

Regularized Joint Assignment Forest with Treatment Arm Clustering

Basic Info
  • Host: GitHub
  • Owner: wustat
  • License: agpl-3.0
  • Language: R
  • Default Branch: main
  • Homepage:
  • Size: 9.09 MB
Statistics
  • Stars: 0
  • Watchers: 4
  • Forks: 0
  • Open Issues: 0
  • Releases: 4
Topics
causal-inference machine-learning
Created over 4 years ago · Last pushed 11 months ago
Metadata Files
Readme Contributing License

README.Rmd

---
output: github_document
---



```{r, include = FALSE}
knitr::opts_chunk$set(
  collapse = TRUE,
  comment = "#>",
  fig.path = "man/figures/README-",
  out.width = "100%"
)
```

# rjaf


[![CRANstatus](https://www.r-pkg.org/badges/version/rjaf)](https://cran.r-project.org/package=rjaf)
[![](https://cranlogs.r-pkg.org/badges/grand-total/rjaf)](https://cran.r-project.org/package=rjaf)
[![License: GPL
v3](https://img.shields.io/badge/License-GPLv3-blue.svg)](https://www.gnu.org/licenses/gpl-3.0)
[![Project Status: Active – The project has reached a stable, usable
state and is being actively
developed.](https://www.repostatus.org/badges/latest/active.svg)](https://www.repostatus.org/#active)
[![status](https://joss.theoj.org/papers/ff8fa725cc40d0247158bd244f1117be/status.svg)](https://joss.theoj.org/papers/ff8fa725cc40d0247158bd244f1117be)


> Regularized Joint Assignment Forests with Treatment Arm Clustering

Wenbo Wu, Xinyi Zhang, Jann Spiess, Rahul Ladhania

---
## Introduction

`rjaf` is an `R` package that implements a regularized and clustered joint assignment forest, which targets joint assignment to one of many treatment arms as described in Ladhania, Spiess, Ungar, and Wu (2023). It uses a regularized, forest-based, greedy recursive algorithm to shrink effect estimates across arms, and a clustering approach to combine treatment arms with similar outcomes. The optimal treatment assignment is estimated by pooling information across treatment arms. In this tutorial, we introduce the use of `rjaf` through an example data set.



## Installation
`rjaf` can be installed from CRAN with
```
install.packages("rjaf")
```

The development version can be installed from GitHub with
```
# install.packages("remotes")  # run once if remotes is not installed
remotes::install_github("wustat/rjaf", subdir = "r-package/rjaf")
```

## What is a regularized and clustered joint assignment forest (rjaf)?
The algorithm aims to train a joint forest model to estimate the optimal treatment assignment by pooling information across treatment arms.

It first obtains an assignment forest by bagging trees as described in Kallus (2017), with covariate and treatment arm randomization for each tree. Then, following Wager and Athey (2018), it generates "honest" and regularized estimates of the treatment-specific counterfactual outcomes on the training sample.
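The idea of honest, regularized estimation can be illustrated with a minimal base-R sketch. This is not the `rjaf` implementation: the simulated data, the single squared-error split (a stand-in for the assignment-based criterion), the shrinkage weight `n_k / (n_k + lambda)`, and the names `shrunk_means` and `lambda` are all assumptions made for illustration. One half of the sample chooses the split, and the other half estimates arm-specific means within each leaf, shrunk toward the pooled leaf mean.

```r
# Illustrative sketch only; not the rjaf implementation.
set.seed(1)
n <- 200
dat <- data.frame(
  x   = runif(n),                        # a single covariate
  trt = sample(0:2, n, replace = TRUE),  # three arms; 0 is control
  y   = rnorm(n)
)
# Arm 2 helps only when x > 0.5 (simulated heterogeneity).
dat$y <- dat$y + ifelse(dat$x > 0.5 & dat$trt == 2, 1, 0)

# Honesty: one half decides where to split, the other half estimates.
idx_split <- sample(n, n / 2)
d_split <- dat[idx_split, ]
d_est   <- dat[-idx_split, ]

# Assumed split rule: the candidate cut with the smallest within-leaf
# squared error of y, computed on the splitting half only.
cands <- unname(quantile(d_split$x, probs = seq(0.2, 0.8, by = 0.1)))
sse <- sapply(cands, function(s) {
  left  <- d_split$y[d_split$x <= s]
  right <- d_split$y[d_split$x > s]
  sum((left - mean(left))^2) + sum((right - mean(right))^2)
})
cut_pt <- cands[which.min(sse)]

# Leaf membership for the estimation half uses the cut chosen above.
d_est$leaf <- ifelse(d_est$x <= cut_pt, "left", "right")

# Shrink each arm's mean toward the pooled mean of the leaf.
# lambda is a hypothetical tuning parameter: 0 means no shrinkage.
shrunk_means <- function(d, lambda = 5) {
  pooled <- mean(d$y)
  n_k <- tapply(d$y, d$trt, length)
  m_k <- tapply(d$y, d$trt, mean)
  w <- n_k / (n_k + lambda)
  w * m_k + (1 - w) * pooled
}

est <- lapply(split(d_est, d_est$leaf), shrunk_means)
est
```

With `lambda = 0` the function returns the raw arm means, and as `lambda` grows every arm's estimate approaches the pooled mean, which is the sense in which information is shared across arms.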

Like Bonhomme and Manresa (2015), it uses a clustering of treatment arms when constructing the assignment trees. It employs a k-means algorithm for clustering the K treatment arms into M treatment groups based on the K predictions for each of the n units in the training sample. After clustering, it then repeats the assignment-forest algorithm on the full training data with M+1 (including control) "arms" (where data from the original arms are combined by groups) to obtain an ensemble of trees.
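The clustering step can be sketched with `stats::kmeans()` in base R. The prediction matrix below is simulated, and the names `pred`, `arm_shift`, and `xwalk` are chosen for illustration; the sketch only shows the shape of the computation, not the package's internal code. Each treatment arm is treated as one observation described by its n predicted outcomes, so the n x K matrix is transposed before clustering.

```r
# Illustrative sketch only; not the rjaf implementation.
set.seed(1)
n <- 50  # units in the training sample
K <- 6   # treatment arms (control excluded)
M <- 2   # number of treatment groups

# Hypothetical n x K matrix of predicted outcomes, one column per arm.
# Arms 1-3 and arms 4-6 are simulated to have similar outcomes.
arm_shift <- c(0, 0.1, -0.1, 2, 2.1, 1.9)
pred <- sapply(arm_shift, function(s) s + rnorm(n, sd = 0.2))
colnames(pred) <- paste0("arm", seq_len(K))

# Cluster the arms (rows of the transposed matrix) into M groups.
km <- kmeans(t(pred), centers = M, nstart = 10)

# Crosswalk from each original arm to its cluster.
xwalk <- data.frame(arm = colnames(pred), cluster = unname(km$cluster))
xwalk
```

Arms that share a cluster label would then be pooled into a single "arm" when the forest is grown again on the full training data.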

The following scripts demonstrate the function ``rjaf()``, which constructs a joint forest model to estimate the optimal treatment assignment by pooling information across treatment arms using a clustering scheme. By inputting training, estimation, and held-out data, we can obtain final regularized predictions and assignments in ``forest.reg``, where the algorithm estimates regularized averages separately by the original treatment arms $k \in \{0,\ldots,K\}$ and obtains the corresponding assignment.


## Example

```{r, message = FALSE}
library(rjaf)
```


We use a dataset simulated by `sim.data()`, which appears in the example section of `rjaf.R`. The dataset has 100 rows, 5 treatment arms, and 12 covariates, as documented in `data.R`. After dividing `Example_data` into training-estimation and held-out sets, we can obtain regularized averages for each of the 5 treatment arms, along with the optimal treatment assignments.

The function returns a list, stored here as `forest.reg`, which includes two tibbles named `fitted` and `counterfactuals`. The tibble `fitted` contains individual IDs from the held-out set, the optimal treatment arms identified (`trt.rjaf`), the predicted optimal outcomes (`Y.rjaf`), and the treatment arm clusters (`clus.rjaf`). In this example, the counterfactual outcomes are known, so `fitted` also includes the counterfactual outcome under the identified optimal arm `trt.rjaf` as `Y.cf`. The tibble `counterfactuals` contains estimated counterfactual outcomes for every treatment arm. When clustering is performed, the tibble `xwalk` is also returned; it lists the treatments and their assigned cluster memberships (based on the k-means algorithm).

```{r, message = FALSE}
library(magrittr)
library(dplyr)

# prepare training, estimation, and heldout data
data("Example_data")
set.seed(1)
# training and estimation
data.trainest <- Example_data %>% 
                  slice_sample(n=floor(0.5*nrow(Example_data)))
# held-out
data.heldout <- Example_data %>% 
                  filter(!id %in% data.trainest$id)

# names of the columns holding the ID, outcome, treatment arm,
# covariates, and treatment assignment probabilities
id <- "id"; y <- "Y"; trt <- "trt"
vars <- paste0("X", 1:3); prob <- "prob"

# call rjaf() with treatment arm clustering enabled
forest.reg <- rjaf(data.trainest, data.heldout, y, id, trt, vars, 
                   prob, clus.max=3, clus.tree.growing=TRUE)

```

```{r}
head(forest.reg$fitted)
head(forest.reg$counterfactuals)
head(forest.reg$xwalk)
```
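The relationship between per-arm counterfactual estimates and the assigned arm can be sketched on a toy table. The data frame `cf` and its column names (`Y_0`, `Y_1`, `Y_2`) are hypothetical and are not the column names returned by `rjaf()`. The sketch only shows that, for each unit, the assigned arm is the one with the largest estimated outcome.

```r
# Illustrative sketch only; column names are hypothetical.
cf <- data.frame(
  id  = 1:4,
  Y_0 = c(1.0, 0.2, 0.5, 0.9),
  Y_1 = c(0.8, 0.6, 0.5, 1.1),
  Y_2 = c(1.2, 0.1, 0.4, 1.0)
)

arm_cols <- setdiff(names(cf), "id")
cf_mat <- as.matrix(cf[arm_cols])

# Column index of the best arm per row; ties go to the first arm.
best <- max.col(cf_mat, ties.method = "first")

assignment <- data.frame(
  id       = cf$id,
  trt_best = arm_cols[best],
  y_best   = cf_mat[cbind(seq_len(nrow(cf_mat)), best)]
)
assignment
```

Averaging `y_best` over held-out units gives a simple summary of the estimated value of the assignment rule.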


## References

Bonhomme, Stéphane and Elena Manresa (2015). Grouped Patterns of Heterogeneity in Panel Data. *Econometrica*, 83: 1147-1184.

Kallus, Nathan (2017). Recursive Partitioning for Personalization using Observational Data. In Precup, Doina and Yee Whye Teh, editors, *Proceedings of the 34th International Conference on Machine Learning*, PMLR 70:1789-1798.

Ladhania, Rahul, Jann Spiess, Lyle Ungar, and Wenbo Wu (2023). Personalized Assignment to One of Many Treatment Arms via Regularized and Clustered Joint Assignment Forests. https://doi.org/10.48550/arXiv.2311.00577.

Wager, Stefan and Susan Athey (2018). Estimation and Inference of Heterogeneous Treatment Effects using Random Forests. *Journal of the American Statistical Association*, 113(523): 1228-1242.

Owner

  • Name: Wenbo Wu
  • Login: wustat
  • Kind: user
  • Location: New York, NY
  • Company: New York University

Assistant Professor of Population Health (Biostatistics), Medicine (Nephrology), and Data Science at New York University

JOSS Publication

rjaf: Regularized Joint Assignment Forest with Treatment Arm Clustering
Published
April 09, 2025
Volume 10, Issue 108, Page 7843
Authors
Wenbo Wu ORCID
Department of Population Health, NYU Grossman School of Medicine, USA
Xinyi Zhang ORCID
Department of Population Health, NYU Grossman School of Medicine, USA
Jann Spiess ORCID
Graduate School of Business, Stanford University, USA
Rahul Ladhania ORCID
Departments of Health Management and Policy and Biostatistics, University of Michigan School of Public Health, USA
Editor
Vissarion Fisikopoulos ORCID
Tags
machine learning causal inference multi-arm randomized controlled trial heterogeneous treatment effects personalized treatment rules optimal assignment

GitHub Events

Total
  • Create event: 3
  • Release event: 4
  • Issues event: 7
  • Delete event: 1
  • Push event: 73
Last Year
  • Create event: 3
  • Release event: 4
  • Issues event: 7
  • Delete event: 1
  • Push event: 73

Committers

Last synced: 7 months ago

All Time
  • Total Commits: 230
  • Total Committers: 12
  • Avg Commits per committer: 19.167
  • Development Distribution Score (DDS): 0.652
Past Year
  • Commits: 111
  • Committers: 10
  • Avg Commits per committer: 11.1
  • Development Distribution Score (DDS): 0.468
Top Committers
Name Email Commits
Xinyi Zhang 5****g 80
Rahul Ladhania r****a@g****m 30
wustat w****u@u****u 30
Wenbo Wu w****u@w****g 30
Xinyi Zhang z****i@u****u 20
Wenbo Wu w****u@w****m 20
Wenbo Wu w****u@w****m 7
Wenbo Wu w****u@m****u 7
Wenbo Wu w****u@w****m 3
jspiess j****s@s****u 1
Zhang x****5@b****r 1
Wenbo Wu w****u@W****t 1

Issues and Pull Requests

Last synced: 6 months ago

All Time
  • Total issues: 7
  • Total pull requests: 0
  • Average time to close issues: 3 months
  • Average time to close pull requests: N/A
  • Total issue authors: 4
  • Total pull request authors: 0
  • Average comments per issue: 0.57
  • Average comments per pull request: 0
  • Merged pull requests: 0
  • Bot issues: 0
  • Bot pull requests: 0
Past Year
  • Issues: 4
  • Pull requests: 0
  • Average time to close issues: about 1 month
  • Average time to close pull requests: N/A
  • Issue authors: 3
  • Pull request authors: 0
  • Average comments per issue: 0.25
  • Average comments per pull request: 0
  • Merged pull requests: 0
  • Bot issues: 0
  • Bot pull requests: 0
Top Authors
Issue Authors
  • rladhania (4)
  • wenjie2wang (1)
  • limengbinggz (1)
  • XinyiEmilyZhang (1)
Pull Request Authors
Top Labels
Issue Labels
Pull Request Labels

Packages

  • Total packages: 1
  • Total downloads:
    • cran 493 last-month
  • Total dependent packages: 0
  • Total dependent repositories: 0
  • Total versions: 4
  • Total maintainers: 1
cran.r-project.org: rjaf

Regularized Joint Assignment Forest with Treatment Arm Clustering

  • Versions: 4
  • Dependent Packages: 0
  • Dependent Repositories: 0
  • Downloads: 493 Last month
Rankings
Dependent packages count: 27.8%
Dependent repos count: 34.3%
Average: 49.7%
Downloads: 87.1%
Maintainers (1)
Last synced: 6 months ago

Dependencies

r-package/rjaf/DESCRIPTION cran
  • R >= 3.0 depends
  • Rcpp * imports
  • dplyr * imports
  • forcats * imports
  • magrittr * imports
  • randomForest * imports
  • ranger * imports
  • readr * imports
  • rlang >= 1.1.0 imports
  • stringr * imports
  • tibble * imports
  • tidyr * imports
  • knitr * suggests
  • rmarkdown * suggests
  • testthat >= 3.0.0 suggests