GROOT forest

The GrootRandomForestClassifier class uses bootstrap aggregation and partially random feature selection to train an ensemble of GrootTreeClassifiers. On datasets with many features, a GrootRandomForestClassifier might perform better than a single GrootTreeClassifier: a single tree's maximum size limits how many features it can split on, whereas the ensemble as a whole is not bound by that limit.
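To make the two sources of randomness concrete, here is a small standard-library sketch that does not use groot: each tree gets a bootstrap sample of rows drawn with replacement, and each split only evaluates a random subset of features (here the "sqrt" rule, which is the default value of max_features in the signature below). The helpers bootstrap_indices and split_candidates are hypothetical names for this illustration; GROOT's internal sampling may differ in detail.

```python
import math
import random


def bootstrap_indices(n_samples, rng):
    """Draw n_samples row indices with replacement: one tree's training rows."""
    return [rng.randrange(n_samples) for _ in range(n_samples)]


def split_candidates(n_features, k, rng):
    """Pick k distinct feature indices to evaluate at a single split."""
    return sorted(rng.sample(range(n_features), k))


rng = random.Random(1)
n_samples, n_features = 100, 16
k = int(math.sqrt(n_features))  # "sqrt" rule: 4 of the 16 features per split

for tree_id in range(3):
    rows = bootstrap_indices(n_samples, rng)
    # Repeated indices mean each tree trains on only part of the distinct rows.
    print(f"tree {tree_id}: {len(set(rows))} distinct rows of {n_samples}, "
          f"root split tries features {split_candidates(n_features, k, rng)}")
```

In a real forest the feature subset is redrawn at every node, not only at the root; the loop above shows a single split per tree to keep the output short.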


from sklearn.datasets import make_moons
X, y = make_moons(random_state=1)

from groot.model import GrootRandomForestClassifier
forest = GrootRandomForestClassifier(attack_model=[0.1, 0.1], random_state=1)
forest.fit(X, y)
print(forest.score(X, y))
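The attack_model=[0.1, 0.1] argument above has one entry per feature of the two-dimensional moons data. The sketch below assumes that a plain number means the attacker may shift that feature by at most that distance in either direction, so each sample can be moved anywhere inside a small box around it; other entry formats an attack model may accept are not covered. perturbation_box and within_box are hypothetical helpers, not part of groot.

```python
def perturbation_box(x, attack_model):
    """Return (low, high) bounds per feature, assuming numeric +/- radii.

    Hypothetical helper: only plain numbers are handled, each read as the
    maximum distance a feature may be moved in either direction.
    """
    if len(x) != len(attack_model):
        raise ValueError("attack_model needs one entry per feature")
    return [(xi - r, xi + r) for xi, r in zip(x, attack_model)]


def within_box(x_adv, box):
    """True if every perturbed feature stays inside its allowed interval."""
    return all(low <= v <= high for v, (low, high) in zip(x_adv, box))


x = [0.5, -0.25]
box = perturbation_box(x, [0.1, 0.1])
print(box)
print(within_box([0.55, -0.3], box))  # both shifts are at most 0.1
print(within_box([0.7, -0.25], box))  # first feature moved by 0.2
```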

groot.model.GrootRandomForestClassifier (BaseGrootRandomForest, ClassifierMixin)

A robust random forest for binary classification.

__init__(self, n_estimators=100, max_depth=None, max_features='sqrt', min_samples_split=2, min_samples_leaf=1, robust_weight=1.0, attack_model=None, one_adversarial_class=False, verbose=False, chen_heuristic=False, max_samples=None, n_jobs=None, compile=True, random_state=None)


Parameters:
n_estimators int

The number of decision trees to fit in the forest.

max_depth int

The maximum depth for the decision trees once fitted.

max_features int or {"sqrt", "log2", None}

The number of features to consider when looking for the best split. If None, all features are considered.

min_samples_split int

The minimum number of samples required to split a tree node.

min_samples_leaf int

The minimum number of samples required to make a tree leaf.

robust_weight float

The ratio of samples that are actually moved by an adversary.

attack_model array-like of shape (n_features,)

Attacker capabilities for perturbing X. The attack model must describe, for every feature, how that feature can be perturbed.

one_adversarial_class bool

Whether only one class (malicious, 1) perturbs its samples, or both classes (benign and malicious, 0 and 1) do so.

verbose bool

Whether to print fitting progress on screen.

chen_heuristic bool

Whether to use the heuristic for the adversarial Gini impurity from Chen et al. (2019) instead of GROOT's adversarial Gini impurity.

max_samples float

The fraction of samples to draw from X to train each decision tree. If None (default), then draw X.shape[0] samples.

n_jobs int

The number of jobs to run in parallel when fitting trees. See joblib.

compile bool

Whether to compile decision trees for faster predictions.

random_state int

Controls the sampling of the features to consider when looking for the best split at each node.
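The max_features and max_samples settings above each reduce to a concrete count before fitting. The following sketch shows one plausible resolution, assuming scikit-learn-style conventions (integer truncation of the square root or base-2 logarithm with a floor of 1, and a fraction of the training rows for max_samples). The function names resolve_max_features and resolve_max_samples are hypothetical, and GROOT may round differently.

```python
import math


def resolve_max_features(max_features, n_features):
    """Turn a max_features setting into a feature count per split.

    Assumed rules: "sqrt" and "log2" are truncated to an int with a floor
    of 1, an int is used as-is, and None means all features.
    """
    if max_features is None:
        return n_features
    if max_features == "sqrt":
        return max(1, int(math.sqrt(n_features)))
    if max_features == "log2":
        return max(1, int(math.log2(n_features)))
    if isinstance(max_features, int):
        return max_features
    raise ValueError(f"unsupported max_features: {max_features!r}")


def resolve_max_samples(max_samples, n_samples):
    """Turn max_samples (a fraction, or None) into a bootstrap sample size."""
    if max_samples is None:
        return n_samples  # None: draw X.shape[0] samples
    return max(1, int(max_samples * n_samples))


for setting in ("sqrt", "log2", 5, None):
    print(setting, resolve_max_features(setting, 30))
print(resolve_max_samples(None, 200), resolve_max_samples(0.5, 200))
```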



Attributes:
estimators_ list of GrootTree

The collection of fitted sub-estimators.

n_samples_ int

The number of samples when fit is performed.

n_features_ int

The number of features when fit is performed.

predict(self, X)

Predict the classes of the input samples X.

The predicted class is the rounded average of the class labels in each predicted leaf.


Parameters:
X array-like of shape (n_samples, n_features)

The input samples to predict.



Returns:

array-like of shape (n_samples,)

The predicted class labels.
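As a sketch of the rounded-average rule, assume each tree has already routed every sample to a leaf and reports that leaf's class label (0 or 1). The hypothetical forest_predict below averages those labels per sample and rounds the result. It illustrates the rule stated above rather than groot's implementation, and its tie handling (Python's round() sends an exact 0.5 to 0) is a property of this sketch only.

```python
def forest_predict(leaf_labels_per_tree):
    """Combine per-tree leaf labels (0 or 1) per sample by rounding the mean.

    leaf_labels_per_tree[t][i] is the class label of the leaf that tree t
    reaches for sample i (a hypothetical input layout for this sketch).
    """
    n_trees = len(leaf_labels_per_tree)
    n_samples = len(leaf_labels_per_tree[0])
    predictions = []
    for i in range(n_samples):
        mean = sum(tree[i] for tree in leaf_labels_per_tree) / n_trees
        # Python's round() sends an exact 0.5 to the even value 0.
        predictions.append(round(mean))
    return predictions


labels = [
    [0, 1, 1, 0],  # tree 0
    [0, 1, 0, 1],  # tree 1
    [1, 1, 1, 0],  # tree 2
]
print(forest_predict(labels))  # [0, 1, 1, 0]
```

With an odd number of trees and binary labels, the mean can never be exactly 0.5, so no ties arise.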

predict_proba(self, X)

Predict class probabilities of the input samples X.

The class probability is the average of the probabilities predicted by the individual decision trees. Each tree predicts the fraction of samples in the reached leaf that belong to the malicious class (1).


Name Type Description Default
X array-like of shape (n_samples, n_features)

The input samples to predict.



Returns:
array of shape (n_samples,)

The probability for each input sample of being malicious.
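The averaging described above can be sketched in a few lines. The sketch assumes, hypothetically, that each reached leaf is summarized by its (benign, malicious) training-sample counts, so a tree's probability is the malicious fraction and the forest's probability is the mean over trees. tree_proba and forest_proba are illustrative names, not groot functions.

```python
def tree_proba(leaf_counts):
    """Fraction of malicious (class 1) samples in a reached leaf.

    leaf_counts is a hypothetical (n_benign, n_malicious) pair.
    """
    n_benign, n_malicious = leaf_counts
    return n_malicious / (n_benign + n_malicious)


def forest_proba(leaves_per_tree):
    """Average the per-tree probabilities for each sample.

    leaves_per_tree[t][i] holds the leaf counts tree t reaches for sample i.
    """
    n_trees = len(leaves_per_tree)
    n_samples = len(leaves_per_tree[0])
    return [
        sum(tree_proba(tree[i]) for tree in leaves_per_tree) / n_trees
        for i in range(n_samples)
    ]


leaves = [
    [(8, 2), (1, 3)],  # tree 0: leaves reached by samples 0 and 1
    [(5, 5), (0, 4)],  # tree 1
]
print(forest_proba(leaves))  # one probability of being malicious per sample
```

The result has shape (n_samples,), matching the return type documented above: sample 0 averages 0.2 and 0.5, and sample 1 averages 0.75 and 1.0.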