decision-jungles
Python implementation of the Decision Jungles paper by Microsoft (compatible with scikit-learn)
Science Score: 44.0%
This score estimates how likely the project is to be science-related, based on the following indicators:
- ✓ CITATION.cff file: found
- ✓ codemeta.json file: found
- ✓ .zenodo.json file: found
- ○ DOI references
- ○ Academic publication links
- ○ Academic email domains
- ○ Institutional organization owner
- ○ JOSS paper metadata
- ○ Scientific vocabulary similarity: low similarity (11.9%) to scientific vocabulary
Repository
Basic Info
- Host: GitHub
- Owner: mendelevium
- License: mit
- Language: Python
- Default Branch: main
- Size: 765 KB
Statistics
- Stars: 0
- Watchers: 1
- Forks: 0
- Open Issues: 0
- Releases: 0
Metadata Files
README.md
Decision Jungles
A scikit-learn compatible implementation of Decision Jungles as described in the paper "Decision Jungles: Compact and Rich Models for Classification" by Jamie Shotton et al. (NIPS 2013).
Overview
Decision Jungles are ensembles of rooted decision directed acyclic graphs (DAGs) that offer two key advantages over traditional decision trees/forests:
- Reduced memory footprint through node merging
- Improved generalization through regularization effects of the DAG structure
Unlike a conventional decision tree, in which exactly one path leads from the root to each node, the DAGs in a decision jungle allow multiple paths from the root to each leaf. This yields a more compact model with potentially better generalization.
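To make the structural difference concrete, here is a toy sketch in plain Python (illustrative only; the node layout and names are not this package's API) of a depth-2 DAG in which both internal children of the root share a merged leaf, so two distinct root-to-leaf paths reach the same node:

```python
# Toy rooted decision DAG (illustrative only, not this package's API).
# Internal nodes test one feature against a threshold and route to a
# child by id; the "merged" leaf is shared by both internal children
# of the root, which is what node merging in a jungle produces.
nodes = {
    "root": {"feature": 0, "threshold": 0.5, "left": "a", "right": "b"},
    "a": {"feature": 1, "threshold": 0.5, "left": "leaf_lo", "right": "merged"},
    "b": {"feature": 1, "threshold": 0.5, "left": "merged", "right": "leaf_hi"},
    "leaf_lo": {"prediction": 0},
    "merged": {"prediction": 1},  # reachable via root->a and via root->b
    "leaf_hi": {"prediction": 2},
}

def predict(x):
    """Walk from the root to a leaf and return that leaf's prediction."""
    node = nodes["root"]
    while "prediction" not in node:
        side = "left" if x[node["feature"]] <= node["threshold"] else "right"
        node = nodes[node[side]]
    return node["prediction"]

# Two different inputs follow two different paths into the same leaf:
print(predict([0.0, 1.0]))  # 1 (path root -> a -> merged)
print(predict([1.0, 0.0]))  # 1 (path root -> b -> merged)
```

A full binary tree of the same depth would need 7 nodes; sharing the leaf brings this DAG down to 6, and the savings compound with depth.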
Installation
Requirements
- Python 3.8 or higher
- NumPy 1.17.0 or higher
- scikit-learn 0.21.0 or higher
- scipy 1.3.0 or higher
Basic Installation
```bash
pip install decision-jungles
```
or
```bash
pip install git+https://github.com/mendelevium/decision-jungles.git
```
Performance Optimization with Cython
```bash
# Install with Cython dependencies for improved performance
pip install decision-jungles[performance]

# Install with memory profiling and benchmarking tools
pip install decision-jungles[profiling]

# Install with development dependencies (testing, etc.)
pip install decision-jungles[dev]

# Install all optional dependencies
pip install decision-jungles[performance,profiling,dev]
```
For better performance, you can compile the Cython extensions:
```bash
# Install Cython
pip install cython

# Clone the repository (if installing from source)
git clone https://github.com/mendelevium/decision-jungles.git
cd decision-jungles

# Compile the Cython extensions
python setup_cython.py build_ext --inplace
```
This will significantly speed up the training process, especially for large datasets.
Usage
Classification
```python
from decision_jungles import DecisionJungleClassifier
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split

# Load data
X, y = load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42
)

# Train a Decision Jungle
clf = DecisionJungleClassifier(
    n_estimators=10,
    max_width=256,
    max_depth=10,
    random_state=42,
)
clf.fit(X_train, y_train)

# Make predictions
y_pred = clf.predict(X_test)
print(f"Accuracy: {clf.score(X_test, y_test):.4f}")

# Get memory usage
print(f"Memory usage: {clf.get_memory_usage()} bytes")
print(f"Number of nodes: {clf.get_node_count()}")
```
Regression
```python
from decision_jungles import DecisionJungleRegressor
from sklearn.datasets import load_diabetes
from sklearn.model_selection import train_test_split
from sklearn.metrics import r2_score

# Load data
X, y = load_diabetes(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42
)

# Train a Decision Jungle for regression
reg = DecisionJungleRegressor(
    n_estimators=10,
    max_width=256,
    criterion="mse",  # use "mse" or "mae"
    random_state=42,
)
reg.fit(X_train, y_train)

# Make predictions
y_pred = reg.predict(X_test)
print(f"R² score: {r2_score(y_test, y_pred):.4f}")

# Get memory usage
print(f"Memory usage: {reg.get_memory_usage()} bytes")
print(f"Number of nodes: {reg.get_node_count()}")
```
Key Features
- Scikit-learn compatible API for both classification and regression
- Two node merging algorithms: LSearch and ClusterSearch
- Various merging schedules for different applications
- Memory-efficient implementation with significant space savings compared to Random Forests
- Support for both classification (gini, entropy) and regression (MSE, MAE) criteria
- Visualization utilities for model inspection and interpretation
- Performance metrics and memory profiling tools
- Robust model serialization with pickle and joblib
- Direct support for categorical features without preprocessing
- Feature importance calculation
- Early stopping functionality to prevent overfitting
Parameters
Classification
The DecisionJungleClassifier accepts the following parameters:
- `n_estimators` (int, default=10): Number of DAGs in the jungle.
- `max_width` (int, default=256): Maximum width of each level (the M parameter in the paper).
- `max_depth` (int, default=None): Maximum depth of the DAGs.
- `min_samples_split` (int, default=2): Minimum number of samples required to split a node.
- `min_samples_leaf` (int, default=1): Minimum number of samples required at a leaf node.
- `min_impurity_decrease` (float, default=0.0): Minimum impurity decrease required for a split.
- `max_features` (int, float, or str, default="sqrt"): Number of features to consider when searching for the best split.
- `random_state` (int, default=None): Random seed for reproducibility.
- `merging_schedule` (str, default="exponential"): Type of merging schedule to use.
- `n_jobs` (int, default=None): Number of jobs to run in parallel.
- `categorical_features` (array-like or str, default=None): Specifies which features are categorical.
- `early_stopping` (bool, default=False): Whether to use early stopping during training.
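These parameter names plug directly into a scikit-learn style hyperparameter search. A minimal sketch of a search grid (the candidate values below are illustrative, not tuned recommendations):

```python
# Candidate values for a grid search over the classifier's parameters.
# The keys match the parameter names listed above; the values are only
# examples of plausible ranges, not recommendations.
param_grid = {
    "n_estimators": [5, 10, 20],
    "max_width": [64, 128, 256],
    "max_depth": [None, 10],
    "merging_schedule": ["exponential"],
}

# Total number of configurations the grid covers
n_configs = 1
for values in param_grid.values():
    n_configs *= len(values)
print(n_configs)  # 18
```

With scikit-learn available, such a grid can be passed to `GridSearchCV` together with a `DecisionJungleClassifier()` instance.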
Regression
The DecisionJungleRegressor accepts similar parameters with these differences:
- `criterion` (str, default="mse"): Function to measure split quality:
  - "mse": mean squared error minimization
  - "mae": mean absolute error minimization
- `max_features` (int, float, or str, default="auto"): Number of features to consider when searching for the best split.
Comparison with Decision Forests
Decision Jungles offer several advantages over traditional Decision Forests:
- Memory Efficiency: Jungles require dramatically less memory while often improving generalization.
- Improved Generalization: Node merging can lead to better regularization and improved test accuracy.
- Inference Time: For the same memory footprint, jungles can achieve higher accuracy than forests.
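The memory claim can be sketched with a back-of-the-envelope node count in plain Python (independent of this package): a full binary tree has 2^d nodes at level d, while a jungle's `max_width` parameter caps every level at M nodes, so deep models stop growing exponentially. Actual savings depend on how aggressively nodes merge; this only shows the width-cap upper bound.

```python
def tree_nodes(depth):
    """Node count of a full binary tree with levels 0..depth."""
    return sum(2 ** d for d in range(depth + 1))

def jungle_nodes(depth, max_width):
    """Upper bound on node count of a DAG whose level widths are
    capped at max_width (the M parameter described above)."""
    return sum(min(2 ** d, max_width) for d in range(depth + 1))

# At depth 10 with M=64, the cap alone shrinks 2047 nodes to 383.
for depth in (5, 10, 15):
    print(depth, tree_nodes(depth), jungle_nodes(depth, 64))
```

The gap widens with depth: at depth 15 the full tree has 65,535 nodes while the width-capped DAG holds at 703.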
Examples
Check the examples/ directory for various usage examples:
- Basic usage
- Comparison with scikit-learn's Random Forests
- Memory usage analysis
- Performance on different datasets
- Hyperparameter tuning
- Categorical features handling
- Model serialization and loading
- Early stopping for preventing overfitting
- Integration with scikit-learn pipelines
Model Serialization
Decision Jungle models can be easily saved and loaded using standard Python serialization mechanisms:
```python
import pickle
import joblib
from decision_jungles import DecisionJungleClassifier

# Train a model
jungle = DecisionJungleClassifier(n_estimators=10)
jungle.fit(X_train, y_train)

# Method 1: Save the model using pickle
with open("jungle_model.pkl", "wb") as f:
    pickle.dump(jungle, f)

# Method 1: Load the model using pickle
with open("jungle_model.pkl", "rb") as f:
    loaded_jungle = pickle.load(f)

# Method 2: Save the model using joblib (better for large models)
joblib.dump(jungle, "jungle_model.joblib")

# Method 2: Load the model using joblib
loaded_jungle = joblib.load("jungle_model.joblib")

# Make predictions with the loaded model
predictions = loaded_jungle.predict(X_test)
```
Documentation
Regression
For detailed information about the regression functionality, please see the README_REGRESSION.md file, which includes:
- Detailed implementation overview
- Advanced usage examples
- Performance comparisons with Random Forest Regressors
- Implementation details and architecture
Contributing
Contributions are welcome! Please feel free to submit a Pull Request.
License
This project is licensed under the MIT License - see the LICENSE file for details.
Citation
If you use Decision Jungles in your research, please cite the original paper:
@inproceedings{shotton2013decision,
title={Decision jungles: Compact and rich models for classification},
author={Shotton, Jamie and Sharp, Toby and Kohli, Pushmeet and Nowozin, Sebastian and Winn, John and Criminisi, Antonio},
booktitle={Advances in Neural Information Processing Systems},
pages={234--242},
year={2013}
}
Owner
- Name: Martin Dionne
- Login: mendelevium
- Kind: user
- Location: Montreal, Canada
- Website: http://martindionne.com
- Repositories: 70
- Profile: https://github.com/mendelevium
BI Analyst, B.Eng. GradDip Management & Data Science
Citation (CITATION.cff)
cff-version: 1.2.0
message: "If you use this software, please cite it as below."
authors:
- family-names: "Dionne"
email: "fertilis.ai@gmail.com"
affiliation: "Decision Jungle Project"
title: "Decision Jungles: A Scikit-learn Compatible Implementation"
version: 0.1.0
date-released: 2025-05-05
url: "https://github.com/example/decision-jungles"
license: MIT
repository-code: "https://github.com/mendelevium/decision-jungles"
keywords:
- "machine learning"
- "decision jungles"
- "classification"
- "regression"
- "directed acyclic graphs"
references:
- type: article
authors:
- family-names: "Shotton"
given-names: "Jamie"
- family-names: "Sharp"
given-names: "Toby"
- family-names: "Kohli"
given-names: "Pushmeet"
- family-names: "Nowozin"
given-names: "Sebastian"
- family-names: "Winn"
given-names: "John"
- family-names: "Criminisi"
given-names: "Antonio"
title: "Decision Jungles: Compact and Rich Models for Classification"
year: 2013
conference:
name: "Advances in Neural Information Processing Systems (NIPS)"
pages: "234-242"
GitHub Events
Total
- Push event: 2
- Create event: 2
Last Year
- Push event: 2
- Create event: 2
Dependencies
- actions/checkout v3 composite
- actions/setup-python v4 composite
- actions/checkout v3 composite
- actions/setup-python v4 composite
- hypothesis >=6.0.0
- joblib >=0.13.0
- matplotlib >=3.1.0
- memory-profiler >=0.60.0
- networkx >=2.3
- numpy >=1.17.0
- pandas >=1.3.0
- psutil >=5.9.0
- pympler >=1.0.1
- pytest >=5.0.0
- scikit-learn >=0.21.0
- scipy >=1.3.0
- tabulate >=0.8.0
- joblib >=0.13.0
- matplotlib >=3.1.0
- networkx >=2.3
- numpy >=1.17.0
- scikit-learn >=0.21.0
- scipy >=1.3.0