# GROOT forest

The `GrootRandomForestClassifier` class uses bootstrap aggregation and partially random feature selection to train an ensemble of `GrootTreeClassifier`s. On datasets with many features, a `GrootRandomForestClassifier` might perform better than a single `GrootTreeClassifier`. A single tree's maximum size limits how many features it can use, but the forest is not limited in this way.
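
To illustrate the bootstrap half of bootstrap aggregation, the sketch below draws one bootstrap sample of row indices per tree using only the Python standard library. The helper `bootstrap_samples` is hypothetical and is not part of GROOT. It shows the general technique, not the library's internal code.

```python
import random

def bootstrap_samples(n_samples, n_estimators, seed=None):
    """Return one list of row indices per tree, each drawn with replacement.

    Hypothetical helper: illustrates bootstrap sampling in general,
    not GROOT's internal implementation.
    """
    rng = random.Random(seed)
    return [
        [rng.randrange(n_samples) for _ in range(n_samples)]
        for _ in range(n_estimators)
    ]

samples = bootstrap_samples(n_samples=8, n_estimators=3, seed=1)
for tree_id, rows in enumerate(samples):
    # Rows can repeat within one tree's sample, and some rows are left out.
    print(tree_id, sorted(rows))
```

Because each tree sees a different resample of the data, the trees disagree in useful ways, and averaging their outputs reduces variance.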

**Example:**

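```python
from sklearn.datasets import make_moons
from groot.model import GrootRandomForestClassifier

# Toy dataset with two features
X, y = make_moons(random_state=1)

# Attack model: a perturbation setting of 0.1 for each of the two features
forest = GrootRandomForestClassifier(attack_model=[0.1, 0.1], random_state=1)
forest.fit(X, y)
print(forest.score(X, y))  # mean accuracy on the training data
```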

```
1.0
```

## `groot.model.GrootRandomForestClassifier` (BaseGrootRandomForest, ClassifierMixin)

A robust random forest for binary classification.

### `__init__(self, n_estimators=100, max_depth=None, max_features='sqrt', min_samples_split=2, min_samples_leaf=1, robust_weight=1.0, attack_model=None, one_adversarial_class=False, verbose=False, chen_heuristic=False, max_samples=None, n_jobs=None, compile=True, random_state=None)`

**Parameters:**

| Name | Type | Description | Default |
|---|---|---|---|
| `n_estimators` | `int` | The number of decision trees to fit in the forest. | `100` |
| `max_depth` | `int` | The maximum depth of the fitted decision trees. | `None` |
| `max_features` | `int or {"sqrt", "log2", None}` | The number of features to consider when making each split. If None, all features are considered. | `'sqrt'` |
| `min_samples_split` | `int` | The minimum number of samples required to split a tree node. | `2` |
| `min_samples_leaf` | `int` | The minimum number of samples required to make a tree leaf. | `1` |
| `robust_weight` | `float` | The ratio of samples that are actually moved by an adversary. | `1.0` |
| `attack_model` | `array-like of shape (n_features,)` | Attacker capabilities for perturbing X. The attack model must describe, for every feature, how that feature can be perturbed. | `None` |
| `one_adversarial_class` | `bool` | Whether only one class (malicious, 1) perturbs its samples, or both classes (benign and malicious, 0 and 1) do so. | `False` |
| `verbose` | `bool` | Whether to print fitting progress to the screen. | `False` |
| `chen_heuristic` | `bool` | Whether to use the heuristic for the adversarial Gini impurity from Chen et al. (2019) instead of GROOT's adversarial Gini impurity. | `False` |
| `max_samples` | `float` | The fraction of samples to draw from X to train each decision tree. If None (default), draw X.shape[0] samples. | `None` |
| `n_jobs` | `int` | The number of jobs to run in parallel when fitting trees. See joblib. | `None` |
| `compile` | `bool` | Whether to compile decision trees for faster predictions. | `True` |
| `random_state` | `int` | Controls the sampling of the features to consider when looking for the best split at each node. | `None` |
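
The `max_features` and `max_samples` settings above each resolve to a concrete count. The sketch below shows one plausible mapping, assuming scikit-learn-style conventions (square root or base-2 logarithm rounded down, at least one feature). The helpers `resolve_max_features` and `resolve_max_samples` are hypothetical, and GROOT's own rounding may differ.

```python
import math

def resolve_max_features(max_features, n_features):
    """Map a max_features setting to a number of candidate features per split.

    Hypothetical helper following scikit-learn-style conventions;
    GROOT's exact rounding is an assumption here.
    """
    if max_features is None:
        return n_features  # consider all features
    if max_features == "sqrt":
        return max(1, int(math.sqrt(n_features)))
    if max_features == "log2":
        return max(1, int(math.log2(n_features)))
    if isinstance(max_features, int):
        return min(max_features, n_features)
    raise ValueError(f"unsupported max_features: {max_features!r}")

def resolve_max_samples(max_samples, n_samples):
    """Map a max_samples fraction to the bootstrap sample size per tree."""
    if max_samples is None:
        return n_samples  # default: draw X.shape[0] samples
    return max(1, int(max_samples * n_samples))

print(resolve_max_features("sqrt", 30))  # int(sqrt(30)) = 5
print(resolve_max_features("log2", 30))  # int(log2(30)) = 4
print(resolve_max_samples(0.5, 100))     # 50
```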

**Attributes:**

| Name | Type | Description |
|---|---|---|
| `estimators_` | `list of GrootTree` | The collection of fitted sub-estimators. |
| `n_samples_` | `int` | The number of samples when `fit` is performed. |
| `n_features_` | `int` | The number of features when `fit` is performed. |

### `predict(self, X)`

Predict the classes of the input samples X.

The predicted class is the rounded average of the class labels in each predicted leaf.
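
For a single sample, one reading of this rule can be sketched in plain Python: average the class labels stored in the leaf the sample reaches in each tree, average those leaf means across trees, and round. The helper `forest_predict` and its input format are hypothetical, not GROOT's API, and the tie handling is an assumption (Python's `round` sends an exact 0.5 to the even value, 0).

```python
from statistics import mean

def forest_predict(leaf_labels_per_tree):
    """Rounded average of the class labels in the leaves one sample reaches.

    leaf_labels_per_tree: one list of 0/1 training labels per tree, namely
    the labels stored in the leaf the sample falls into.
    Hypothetical helper; not GROOT's API.
    """
    leaf_means = [mean(labels) for labels in leaf_labels_per_tree]
    # round() maps an exact tie of 0.5 to the even value, 0.
    return round(mean(leaf_means))

# Three trees: the sample lands in leaves holding these training labels.
print(forest_predict([[1, 1, 0], [1, 1], [0, 1, 1, 1]]))
```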

**Parameters:**

| Name | Type | Description | Default |
|---|---|---|---|
| `X` | `array-like of shape (n_samples, n_features)` | The input samples to predict. | required |

**Returns:**

| Type | Description |
|---|---|
| `array-like of shape (n_samples,)` | The predicted class labels. |

### `predict_proba(self, X)`

Predict class probabilities of the input samples X.

The class probability is the average of the probabilities predicted by each decision tree. The probability prediction of each tree is the fraction of samples of the same class in the leaf.
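
The averaging step can be sketched without GROOT. The example assumes each tree has already reported, for every sample, the fraction of malicious (class 1) training samples in the leaf that sample reaches. The helper `forest_predict_proba` and the nested-list input layout are hypothetical.

```python
from statistics import mean

def forest_predict_proba(tree_leaf_fractions):
    """Average the per-tree probabilities of class 1 (malicious) per sample.

    tree_leaf_fractions[t][i] is the fraction of class-1 training samples in
    the leaf that sample i reaches in tree t. Hypothetical helper.
    """
    n_samples = len(tree_leaf_fractions[0])
    return [
        mean(tree[i] for tree in tree_leaf_fractions)
        for i in range(n_samples)
    ]

# Two trees, three samples.
fractions = [
    [1.0, 0.25, 0.0],  # tree 0
    [0.5, 0.75, 0.0],  # tree 1
]
print(forest_predict_proba(fractions))  # [0.75, 0.5, 0.0]
```

Averaging probabilities rather than hard votes keeps information about how confident each tree's leaf is.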

**Parameters:**

| Name | Type | Description | Default |
|---|---|---|---|
| `X` | `array-like of shape (n_samples, n_features)` | The input samples to predict. | required |

**Returns:**

| Type | Description |
|---|---|
| `array of shape (n_samples,)` | The probability for each input sample of being malicious. |