GROOT tree
The main class in the GROOT repository is GrootTreeClassifier, which implements GROOT as a scikit-learn compatible classifier. You initialize it with the relevant hyperparameters, fit it with .fit(X, y), and predict with .predict(X) or .predict_proba(X). GrootTreeClassifier is also used within GrootRandomForestClassifier.
Example:
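from sklearn.datasets import make_moons
from groot.model import GrootTreeClassifier

# Generate a small two-class toy dataset with two features.
X, y = make_moons(random_state=1)

# attack_model has one entry per feature.
tree = GrootTreeClassifier(max_depth=3, attack_model=[0.1, 0.1])
tree.fit(X, y)

# Accuracy on the training data.
print(tree.score(X, y))
0.9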
groot.model.GrootTreeClassifier (BaseGrootTree, ClassifierMixin)
A robust decision tree for binary classification.
__init__(self, max_depth=5, min_samples_split=2, min_samples_leaf=1, max_features=None, robust_weight=1.0, attack_model=None, one_adversarial_class=False, chen_heuristic=False, compile=True, random_state=None)
Parameters:
Name | Type | Description | Default |
---|---|---|---|
max_depth | int | The maximum depth for the decision tree once fitted. | 5 |
min_samples_split | int | The minimum number of samples required to split a node. | 2 |
min_samples_leaf | int | The minimum number of samples required to make a leaf. | 1 |
max_features | int or {"sqrt", "log2"} | The number of features to consider while making each split, if None then all features are considered. | None |
robust_weight | float | The ratio of samples that are actually moved by an adversary. | 1.0 |
attack_model | array-like of shape (n_features,) | Attacker capabilities for perturbing X. By default, all features are considered not perturbable. | None |
one_adversarial_class | bool | Whether one class (malicious, 1) perturbs their samples or if both classes (benign and malicious, 0 and 1) do so. | False |
chen_heuristic | bool | Whether to use the heuristic for the adversarial Gini impurity from Chen et al. (2019) instead of GROOT's adversarial Gini impurity. | False |
compile | bool | Whether to compile the tree for faster predictions. | True |
random_state | int | Controls the sampling of the features to consider when looking for the best split at each node. | None |
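As a rough illustration of attack_model and one_adversarial_class, the sketch below assumes that a numeric attack_model entry is the maximum distance an adversary can shift that feature in either direction. The parameter table does not spell this out, so treat it as an assumption. The helper can_cross_split is hypothetical and not part of GROOT; it only checks whether a perturbed sample could end up on the other side of a split threshold.

```python
def can_cross_split(x, label, feature, threshold, attack_model,
                    one_adversarial_class=False):
    """Return True if an adversary could move sample x across the threshold.

    Hypothetical helper, not part of GROOT. Assumes attack_model[feature]
    is a perturbation radius for that feature.
    """
    # With one_adversarial_class, only malicious samples (label 1) move.
    if one_adversarial_class and label != 1:
        return False
    radius = attack_model[feature]
    return abs(x[feature] - threshold) <= radius


attack_model = [0.1, 0.1]
near = [0.45, 0.80]  # feature 0 lies 0.05 from the threshold 0.5
far = [0.20, 0.80]   # feature 0 lies 0.30 from the threshold 0.5

print(can_cross_split(near, 1, 0, 0.5, attack_model))  # True
print(can_cross_split(far, 1, 0, 0.5, attack_model))   # False
print(can_cross_split(near, 0, 0, 0.5, attack_model,
                      one_adversarial_class=True))     # False
```

Samples that can cross a candidate split are the ones a robust splitting criterion has to treat as uncertain, which is why the attack model is supplied at construction time rather than at prediction time.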
Attributes:
Name | Type | Description |
---|---|---|
classes_ | ndarray of shape (n_classes,) | The class labels. |
max_features_ | int | The inferred value of max_features. |
n_samples_ | int | The number of samples when fit is performed. |
n_features_ | int | The number of features when fit is performed. |
root_ | Node | The root node of the tree after fitting. |
compiled_root_ | CompiledTree | The compiled root node of the tree after fitting. |
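How max_features_ is inferred from max_features is not described here. The sketch below follows the convention scikit-learn uses for the same option names and should be read as an assumption about GROOT, not as its actual code. The function resolve_max_features is a hypothetical name.

```python
import math


def resolve_max_features(max_features, n_features):
    """Hypothetical helper: map a max_features setting to a feature count."""
    if max_features is None:
        # None means every feature is considered at each split.
        return n_features
    if max_features == "sqrt":
        return max(1, math.isqrt(n_features))
    if max_features == "log2":
        return max(1, int(math.log2(n_features)))
    # Otherwise the setting is taken as an explicit integer count.
    return int(max_features)


print(resolve_max_features(None, 16))    # 16
print(resolve_max_features("sqrt", 16))  # 4
print(resolve_max_features("log2", 16))  # 4
print(resolve_max_features(3, 16))       # 3
```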
predict(self, X)
Predict the classes of the input samples X.
The predicted class is the most frequently occurring class label in a leaf.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
X | array-like of shape (n_samples, n_features) | The input samples to predict. | required |
Returns:
Type | Description |
---|---|
array-like of shape (n_samples,) | The predicted class labels. |
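The majority rule described for predict can be sketched in a few lines of plain Python. This is a conceptual illustration with a hypothetical helper, majority_class; how GROOT stores label counts in a leaf or breaks ties is not stated here.

```python
from collections import Counter


def majority_class(leaf_labels):
    """Return the most frequent label among the samples in a leaf."""
    return Counter(leaf_labels).most_common(1)[0][0]


# Training labels of the samples that ended up in one leaf.
leaf_labels = [0, 1, 1, 1, 0]
print(majority_class(leaf_labels))  # 1
```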
predict_proba(self, X)
Predict class probabilities of the input samples X.
The class probability is the fraction of samples of the same class in the leaf.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
X | array-like of shape (n_samples, n_features) | The input samples to predict. | required |
Returns:
Type | Description |
---|---|
array of shape (n_samples,) | The probability for each input sample of being malicious. |
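For a single leaf, the probability described above reduces to a simple fraction. The sketch below uses a hypothetical helper, malicious_fraction, and assumes the malicious class is labeled 1, as in the one_adversarial_class description. For two classes, thresholding this fraction at 0.5 gives the same answer as the majority rule, apart from ties.

```python
def malicious_fraction(leaf_labels):
    """Fraction of samples in a leaf that carry the malicious label (1)."""
    return sum(1 for label in leaf_labels if label == 1) / len(leaf_labels)


# Training labels of the samples that ended up in one leaf.
leaf_labels = [0, 1, 1, 1, 0]
p = malicious_fraction(leaf_labels)
print(p)             # 0.6
print(int(p > 0.5))  # 1
```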