GROOT tree

The main class in the GROOT repository is GrootTreeClassifier, which implements GROOT as a scikit-learn compatible classifier. You initialize it with its hyperparameters, fit it with .fit(X, y), and predict with .predict(X) or .predict_proba(X). GrootTreeClassifier is also used within GrootRandomForestClassifier.

Example:

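from sklearn.datasets import make_moons
from groot.model import GrootTreeClassifier

# Generate a small two-class toy dataset with two features.
X, y = make_moons(random_state=1)

# attack_model gives the attacker's perturbation capability for each feature.
tree = GrootTreeClassifier(max_depth=3, attack_model=[0.1, 0.1])
tree.fit(X, y)

# Accuracy on the training data.
print(tree.score(X, y))
# Output: 0.9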

groot.model.GrootTreeClassifier (BaseGrootTree, ClassifierMixin)

A robust decision tree for binary classification.

__init__(self, max_depth=5, min_samples_split=2, min_samples_leaf=1, max_features=None, robust_weight=1.0, attack_model=None, one_adversarial_class=False, chen_heuristic=False, compile=True, random_state=None)

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| max_depth | int | The maximum depth for the decision tree once fitted. | 5 |
| min_samples_split | int | The minimum number of samples required to split a node. | 2 |
| min_samples_leaf | int | The minimum number of samples required to make a leaf. | 1 |
| max_features | int or {"sqrt", "log2"} | The number of features to consider while making each split, if None then all features are considered. | None |
| robust_weight | float | The ratio of samples that are actually moved by an adversary. | 1.0 |
| attack_model | array-like of shape (n_features,) | Attacker capabilities for perturbing X. By default, all features are considered not perturbable. | None |
| one_adversarial_class | bool | Whether one class (malicious, 1) perturbs their samples or if both classes (benign and malicious, 0 and 1) do so. | False |
| chen_heuristic | bool | Whether to use the heuristic for the adversarial Gini impurity from Chen et al. (2019) instead of GROOT's adversarial Gini impurity. | False |
| compile | bool | Whether to compile the tree for faster predictions. | True |
| random_state | int | Controls the sampling of the features to consider when looking for the best split at each node. | None |
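The attack_model, one_adversarial_class, and chen_heuristic parameters all affect how a candidate split is scored when an attacker may move samples across the threshold. The sketch below illustrates that idea by brute force, assuming the adversarial Gini impurity is the worst-case weighted Gini impurity over all attacker moves. The function names and the count-based interface are hypothetical and not part of the groot API, and an actual implementation would presumably compute this far more efficiently than by enumeration.

```python
def gini(n0, n1):
    """Gini impurity of a node holding n0 class-0 and n1 class-1 samples."""
    n = n0 + n1
    if n == 0:
        return 0.0
    p1 = n1 / n
    return 2.0 * p1 * (1.0 - p1)


def worst_case_weighted_gini(left, right, movable):
    """Largest weighted Gini impurity an attacker can force on one split.

    left, right: (n0, n1) counts that stay on their side whatever the attacker does.
    movable:     (m0, m1) counts close enough to the threshold to end up on either side.
    Hypothetical helper for illustration only.
    """
    (l0, l1), (r0, r1), (m0, m1) = left, right, movable
    total = l0 + l1 + r0 + r1 + m0 + m1
    if total == 0:
        return 0.0
    worst = 0.0
    # Try every way of sending a0 movable class-0 and a1 movable class-1 samples left.
    for a0 in range(m0 + 1):
        for a1 in range(m1 + 1):
            n_left = (l0 + a0, l1 + a1)
            n_right = (r0 + m0 - a0, r1 + m1 - a1)
            score = (sum(n_left) * gini(*n_left) + sum(n_right) * gini(*n_right)) / total
            worst = max(worst, score)
    return worst


# No movable samples: a perfect split stays perfect.
clean = worst_case_weighted_gini((5, 0), (0, 5), (0, 0))
# One movable sample per class: the attacker can contaminate both sides.
attacked = worst_case_weighted_gini((4, 0), (0, 4), (1, 1))
# Only class 1 (malicious) moves, in the spirit of one_adversarial_class=True.
only_malicious = worst_case_weighted_gini((5, 0), (0, 4), (0, 1))
print(clean, attacked, only_malicious)
```

In this toy setting, restricting movement to class 1 gives a worst case that lies between the clean score and the score when both classes can move.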

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| classes_ | ndarray of shape (n_classes,) | The class labels. |
| max_features_ | int | The inferred value of max_features. |
| n_samples_ | int | The number of samples when fit is performed. |
| n_features_ | int | The number of features when fit is performed. |
| root_ | Node | The root node of the tree after fitting. |
| compiled_root_ | CompiledTree | The compiled root node of the tree after fitting. |
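To make "the inferred value of max_features" concrete, here is a minimal sketch of how a max_features setting could be turned into a feature count. The helper resolve_max_features is hypothetical, and the rounding (truncate, then use at least one feature) is an assumption borrowed from common scikit-learn practice rather than confirmed GROOT behaviour.

```python
import math


def resolve_max_features(max_features, n_features):
    """Map a max_features setting to a concrete number of features (hypothetical helper)."""
    if max_features is None:
        # None means every feature is considered at each split.
        return n_features
    if max_features == "sqrt":
        return max(1, int(math.sqrt(n_features)))
    if max_features == "log2":
        return max(1, int(math.log2(n_features)))
    if isinstance(max_features, int) and 1 <= max_features <= n_features:
        return max_features
    raise ValueError(f"Unsupported max_features value: {max_features!r}")


for setting in (None, "sqrt", "log2", 3):
    print(setting, resolve_max_features(setting, n_features=16))
```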

predict(self, X)

Predict the classes of the input samples X.

The predicted class is the most frequently occurring class label in a leaf.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| X | array-like of shape (n_samples, n_features) | The input samples to predict. | required |

Returns:

| Type | Description |
| --- | --- |
| array-like of shape (n_samples,) | The predicted class labels. |
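The majority-vote rule described for predict can be sketched on a one-split tree. Everything below is illustrative: leaf_prediction and predict_stump are hypothetical names, a leaf is represented simply by the training labels that reached it, and sending samples left when the feature value is at most the threshold is an assumption, not a statement about GROOT's internals.

```python
from collections import Counter


def leaf_prediction(leaf_labels):
    """Return the most frequently occurring label among a leaf's training labels.

    Ties go to the label seen first, which is how Counter.most_common orders them.
    """
    return Counter(leaf_labels).most_common(1)[0][0]


def predict_stump(X, feature, threshold, left_labels, right_labels):
    """Predict with a single split: rows with X[feature] <= threshold go to the left leaf."""
    left_class = leaf_prediction(left_labels)
    right_class = leaf_prediction(right_labels)
    return [left_class if row[feature] <= threshold else right_class for row in X]


# Left leaf saw labels [0, 0, 1] during training, right leaf saw [1, 1, 1, 0].
samples = [[0.2, 1.0], [0.9, 1.0]]
predictions = predict_stump(samples, feature=0, threshold=0.5,
                            left_labels=[0, 0, 1], right_labels=[1, 1, 1, 0])
print(predictions)
```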

predict_proba(self, X)

Predict class probabilities of the input samples X.

The class probability is the fraction of samples of the same class in the leaf.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| X | array-like of shape (n_samples, n_features) | The input samples to predict. | required |

Returns:

| Type | Description |
| --- | --- |
| array of shape (n_samples,) | The probability for each input sample of being malicious. |
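As a small illustration of the leaf-fraction rule, the sketch below computes the probability of the malicious class (label 1) from the training labels that ended up in a leaf. The helper leaf_probability is hypothetical and only mirrors the description above; it is not taken from the groot source.

```python
def leaf_probability(leaf_labels, positive_label=1):
    """Fraction of a leaf's training samples that carry positive_label."""
    if not leaf_labels:
        raise ValueError("A leaf must contain at least one sample.")
    n_positive = sum(1 for label in leaf_labels if label == positive_label)
    return n_positive / len(leaf_labels)


# A leaf that received three malicious and two benign training samples.
mixed_leaf = [0, 0, 1, 1, 1]
# A pure benign leaf.
benign_leaf = [0, 0, 0]

print(leaf_probability(mixed_leaf))
print(leaf_probability(benign_leaf))
```

Every input sample that reaches the same leaf receives the same probability, so a shallow tree produces only a handful of distinct probability values.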