Decision Trees — Applied Machine Learning in Python


Fri Nov 03 2023 07:34:59 GMT+0000 (Coordinated Universal Time)

Saved by @elham469

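import pandas as pd
import matplotlib.pyplot as plt
from sklearn.model_selection import GridSearchCV, StratifiedShuffleSplit
from sklearn.tree import DecisionTreeClassifier

# X_train, y_train: training split defined earlier in the notes (not shown here)
# Search over the maximum number of leaves the tree may grow.
param_grid = {'max_leaf_nodes': range(2, 20)}
grid = GridSearchCV(DecisionTreeClassifier(random_state=0), param_grid=param_grid,
                    # 100 random stratified train/validation splits
                    cv=StratifiedShuffleSplit(n_splits=100, random_state=1),
                    return_train_score=True)  # also record training scores
grid.fit(X_train, y_train)

# Plot mean training vs. validation accuracy against the number of leaves.
scores = pd.DataFrame(grid.cv_results_)
scores.plot(x='param_max_leaf_nodes', y=['mean_train_score', 'mean_test_score'], ax=plt.gca())
plt.legend(loc=(1, 0))  # anchor the legend just outside the axes, to the right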
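The snippet depends on X_train and y_train from earlier in the notes. Below is a self-contained sketch of the same grid search. It assumes scikit-learn's bundled breast cancer dataset as a stand-in for that data (the original notes' dataset is not shown here) and uses 20 splits instead of 100 to run quickly. Instead of plotting, it prints the train/validation gap per tree size, the selected max_leaf_nodes, and the accuracy on a held-out test split.

```python
import pandas as pd
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import (GridSearchCV, StratifiedShuffleSplit,
                                     train_test_split)
from sklearn.tree import DecisionTreeClassifier

# Assumption: the breast cancer data stands in for the notes' X_train / y_train.
X, y = load_breast_cancer(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, stratify=y, random_state=0)

param_grid = {'max_leaf_nodes': range(2, 20)}
grid = GridSearchCV(DecisionTreeClassifier(random_state=0),
                    param_grid=param_grid,
                    # 20 random stratified splits (fewer than 100, for speed)
                    cv=StratifiedShuffleSplit(n_splits=20, random_state=1),
                    return_train_score=True)
grid.fit(X_train, y_train)

scores = pd.DataFrame(grid.cv_results_)
# Difference between mean training and mean validation accuracy per tree size.
scores['gap'] = scores['mean_train_score'] - scores['mean_test_score']
print(scores[['param_max_leaf_nodes', 'mean_train_score',
              'mean_test_score', 'gap']].round(3).to_string(index=False))

# GridSearchCV refits the best setting on all of X_train (refit=True by default).
print("best max_leaf_nodes:", grid.best_params_['max_leaf_nodes'])
print("held-out test accuracy:", round(grid.score(X_test, y_test), 3))
```

Rows where `gap` keeps growing while `mean_test_score` stops improving are the tree sizes where extra leaves mostly fit the training data rather than generalize.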

https://amueller.github.io/aml/02-supervised-learning/08-decision-trees.html