Lecture 20 (5/11/2022)¶
Announcements
Final projects!
Last time we covered:
ROC curves
Today’s agenda:
Common classification models
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.model_selection import train_test_split
Common Classification Models¶
\(k\)-nearest neighbors
Logistic regression
Decision trees
Support Vector Machines (SVMs)
Other: naive Bayes, neural networks, discriminant analysis
Data: Predicting Heart Disease¶
From source:
A retrospective sample of males in a heart-disease high-risk region of the Western Cape, South Africa. There are roughly two controls per case of CHD. Many of the CHD positive men have undergone blood pressure reduction treatment and other programs to reduce their risk factors after their CHD event. In some cases the measurements were made after these treatments. These data are taken from a larger dataset, described in Rousseauw et al, 1983, South African Medical Journal.
sbp: systolic blood pressure
tobacco: cumulative tobacco (kg)
ldl: low density lipoprotein cholesterol
adiposity
famhist: family history of heart disease (Present, Absent)
typea: type-A behavior
obesity
alcohol: current alcohol consumption
age: age at onset
chd: response, coronary heart disease
data = pd.read_csv('https://web.stanford.edu/~hastie/ElemStatLearn/datasets/SAheart.data')
data
|     | row.names | sbp | tobacco | ldl | adiposity | famhist | typea | obesity | alcohol | age | chd |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| 0 | 1 | 160 | 12.00 | 5.73 | 23.11 | Present | 49 | 25.30 | 97.20 | 52 | 1 |
| 1 | 2 | 144 | 0.01 | 4.41 | 28.61 | Absent | 55 | 28.87 | 2.06 | 63 | 1 |
| 2 | 3 | 118 | 0.08 | 3.48 | 32.28 | Present | 52 | 29.14 | 3.81 | 46 | 0 |
| 3 | 4 | 170 | 7.50 | 6.41 | 38.03 | Present | 51 | 31.99 | 24.26 | 58 | 1 |
| 4 | 5 | 134 | 13.60 | 3.50 | 27.78 | Present | 60 | 25.99 | 57.34 | 49 | 1 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 457 | 459 | 214 | 0.40 | 5.98 | 31.72 | Absent | 64 | 28.45 | 0.00 | 58 | 0 |
| 458 | 460 | 182 | 4.20 | 4.41 | 32.10 | Absent | 52 | 28.61 | 18.72 | 52 | 1 |
| 459 | 461 | 108 | 3.00 | 1.59 | 15.23 | Absent | 40 | 20.09 | 26.64 | 55 | 0 |
| 460 | 462 | 118 | 5.40 | 11.61 | 30.79 | Absent | 64 | 27.35 | 23.97 | 40 | 0 |
| 461 | 463 | 132 | 0.00 | 4.82 | 33.41 | Present | 62 | 14.70 | 0.00 | 46 | 1 |
462 rows × 11 columns
Setting up our classifiers:
Let’s stick to just a single feature (age at onset) and see how different methods use this feature to predict the outcome label (CHD).
x_vals = np.array(data['age']).reshape(len(data), 1)
y_vals = np.array(data['chd'])
xtrain, xtest, ytrain, ytest = train_test_split(x_vals, y_vals, random_state = 1)
# sns.scatterplot(x = xtrain[:, 0], y = ytrain, alpha = .5)
sns.scatterplot(x = xtest[:, 0], y = ytest, alpha = .5)
plt.xlabel("Age")
plt.ylabel("CHD")
Text(0, 0.5, 'CHD')
Now, let’s get started!
Logistic Regression¶
How it works:
In linear regression, the relationship between our predictor \(x\) and our response variable \(y\) was:
\(y = \beta_0 + \beta_1 x\)
If our \(y\) values are all 0 or 1 (and assumed to be Bernoulli distributed with probability \(p\)), this approach doesn’t work very well:
It predicts values <0 and >1 for some inputs \(x\)
It doesn’t accommodate the fact that getting closer and closer to 1 gets harder and harder: equal one-unit changes in \(x\) may not produce equal changes in \(p(y = 1)\).
So what do we do about this?
Instead, we postulate the following relationship between \(x\) and \(y\):
\(\log \dfrac{p(y=1)}{p(y=0)} = \beta_0 + \beta_1 x\).
Every unit increase in \(x\) leads to a \(\beta_1\) increase in the log odds of \(y\) (equivalently, every unit increase in \(x\) multiplies the odds of \(y\) by \(e^{\beta_1}\)).
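To see where the odds interpretation comes from, divide the odds at \(x + 1\) by the odds at \(x\):
\(\dfrac{\text{odds}(x+1)}{\text{odds}(x)} = \dfrac{e^{\beta_0 + \beta_1 (x+1)}}{e^{\beta_0 + \beta_1 x}} = e^{\beta_1}\)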
This logit transform of our response variable \(y\) solves both of the problems with linear regression above.
However, the goal today isn’t to get into the nitty-gritty of logistic regression. Instead, let’s talk about how we use it as a classifier!
Classification
When we’ve fit a logistic regression to our data, we can output a probability \(p(y)\) for any given \(x\):
\(p(y) = \dfrac{e^{h(x)}}{1+ e^{h(x)}}\)
for \(h(x) = \beta_0 + \beta_1x\).
\(\dfrac{e^{h(x)}}{1+ e^{h(x)}}\) is the logistic function that maps from our \(x\) variable to \(p(y)\).
We can use this function as the basis for classification, where \(p(y)\) greater than a threshold \(T\) is given a particular label estimate \(\hat{y}\).
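As a minimal sketch of that thresholding step (using made-up values for \(\beta_0\) and \(\beta_1\), not the coefficients we fit below):
import numpy as np

def logistic(h):
    # map the linear predictor h(x) = beta_0 + beta_1 * x to a probability in (0, 1)
    return np.exp(h) / (1 + np.exp(h))

def classify(x, beta0, beta1, threshold = 0.5):
    p = logistic(beta0 + beta1 * x)      # p(y = 1) for each x
    return (p >= threshold).astype(int)  # label 1 whenever p(y) exceeds the threshold T

# With these hypothetical coefficients, p(y) rises with age and the hard label flips where p(y) crosses 0.5
ages = np.array([20, 40, 60])
logistic(-3 + 0.06 * ages), classify(ages, -3, 0.06)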
Fitting parameters
Even though logistic regression produces regression coefficients (intercept + slopes) similar to linear regression, these parameters are not estimated using the Ordinary Least Squares process we saw with linear regression. Instead, they are most often estimated using a more complicated process called Maximum Likelihood Estimation.
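As a rough sketch of what Maximum Likelihood Estimation is doing: pick the coefficients that make the observed labels most probable. A generic optimizer applied to the negative Bernoulli log-likelihood recovers coefficients close to scikit-learn’s below (not identical, since scikit-learn adds some regularization by default):
import numpy as np
from scipy.optimize import minimize

def neg_log_likelihood(beta, x, y):
    # beta = [beta_0, beta_1]; x is a 1-D array of ages, y contains 0/1 labels
    h = beta[0] + beta[1] * x
    # For p = e^h / (1 + e^h), the log-likelihood of each observation is y * h - log(1 + e^h)
    return -np.sum(y * h - np.logaddexp(0, h))

# Fit on the training data we created above
minimize(neg_log_likelihood, x0 = [0.0, 0.0], args = (xtrain[:, 0], ytrain)).x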
Logistic regression in python¶
You can read the scikit-learn documentation here.
# Import the LogisticRegression class
from sklearn.linear_model import LogisticRegression
# Initialize the logistic regression
log_reg = LogisticRegression(random_state = 1)
# Fit the model
log_reg.fit(X = xtrain, y = ytrain)
LogisticRegression(random_state=1)
What attributes do we get from this model fit?
log_reg.classes_
log_reg.intercept_ # What does this mean?
# The intercept is the log odds of CHD at age 0; passing it through the logistic function gives p(CHD) at age 0
np.exp(log_reg.intercept_[0]) / (1 + np.exp(log_reg.intercept_[0]))
log_reg.coef_ # What does this mean?
# Exponentiating the age coefficient gives the factor by which the odds of CHD are multiplied per additional year
np.exp(log_reg.coef_[0][0])
1.06682882157164
What functions does the model class give us?
binary_preds = log_reg.predict(xtest)
binary_preds
soft_preds = log_reg.predict_proba(xtest)
soft_preds
# soft_preds[:, 0] # probability of 0
# Accuracy of hard classification predictions
log_reg.score(X = xtest, y = ytest)
0.6637931034482759
How did we do?
# Let's show the actual test data
g = sns.scatterplot(x = xtest[:, 0], y = ytest, hue = ytest == binary_preds, alpha = .5)
# Now, let's plot our logistic regression curve
sns.lineplot(x = xtest[:, 0], y = soft_preds[:, 1])
# What is the "hard classification" boundary?
sns.lineplot(x = xtest[:, 0], y = binary_preds)
plt.axhline(0.5, linestyle = "--", color = "k") # this is what produces our classification boundary
g.set_xlabel("Age")
g.set_ylabel("CHD probability")
plt.legend(title = "Correct")
plt.show()
What are the true positives/negatives and false positives/negatives above?
…
Understanding the regression
Let’s look at where the blue line above comes from.
Our logistic regression is formalized as follows:
For \(h(x) = \beta_0 + \beta_1x\),
\(p(y) = \dfrac{e^{h(x)}}{1+ e^{h(x)}}\)
# Let's implement the above transformation here
ypreds = np.exp(log_reg.intercept_ + log_reg.coef_*xtest) / (1 + np.exp(log_reg.intercept_ + log_reg.coef_*xtest))
# Now we can confirm that this worked
g = sns.lineplot(x = xtest[:, 0], y = ypreds[:, 0])
g.set_ylim(0, 1)
g.set_xlabel("Age")
g.set_ylabel("p(CHD)")
plt.show()
# Finally, let's look at the "linear" relationship underlying logistic regression
h = sns.lineplot(x = xtest[:, 0], y = np.log(ypreds[:, 0]/(1-ypreds[:, 0])))
h.set_xlabel("Age")
h.set_ylabel("Log odds of CHD")
plt.show()
Understanding the classification
Note that the 50% classification boundary we applied to our logistic function’s \(p(y)\) is somewhat arbitrary.
As with \(k\)-nearest neighbors, we can modify that classification threshold and generate an ROC curve over different thresholds.
from sklearn.metrics import roc_curve
# ROC curve
fpr, tpr, thresholds = roc_curve(
y_true = ytest,
y_score = log_reg.predict_proba(X = xtest)[:, 1],
pos_label = 1
)
sns.lineplot(x = fpr, y = tpr)
plt.axline(xy1 = (0, 0), slope = 1, c = "r")
plt.xlabel("FPR")
plt.ylabel("TPR")
plt.show()
Decision Trees¶
Decision trees are a form of classification that fits a model by generating successive rules based on the input feature values. These rules are optimized to try and classify the data as accurately as possible.
Above, the percentages are the percent of data points in each node and the proportions are the probability of survival (Source).
Take a second to interpret this.
Decision trees have the advantage of being super intuitive (like \(k\)-nearest neighbors, they’re similar to how people often think about classification).
There’s a great article about how they work here and a nice explanation of how the decision boundaries are identified here.
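As a rough sketch of the core idea (this is not scikit-learn’s actual implementation): to choose a rule for a single feature, try candidate cutoffs and keep the one whose two resulting groups are purest on average, as measured by Gini impurity.
import numpy as np

def gini(labels):
    # Gini impurity: 1 - sum over classes of p_class^2 (0 means the group is perfectly pure)
    _, counts = np.unique(labels, return_counts = True)
    p = counts / counts.sum()
    return 1 - np.sum(p ** 2)

def best_split(x, y):
    # Candidate cutoffs: midpoints between consecutive unique feature values
    values = np.sort(np.unique(x))
    cutoffs = (values[:-1] + values[1:]) / 2
    best_cutoff, best_impurity = None, np.inf
    for c in cutoffs:
        left, right = y[x <= c], y[x > c]
        # Weighted average impurity of the two groups this cutoff creates
        impurity = (len(left) * gini(left) + len(right) * gini(right)) / len(y)
        if impurity < best_impurity:
            best_cutoff, best_impurity = c, impurity
    return best_cutoff, best_impurity

# On the training data above, this should land close to the root split scikit-learn chooses below (Age <= 50.5)
best_split(xtrain[:, 0], ytrain)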
Decision tree classifiers in python¶
You can read the decision tree classifier documentation here.
# Import the DecisionTreeClassifier class
from sklearn.tree import DecisionTreeClassifier
# Initialize the decision tree classifier
dtree = DecisionTreeClassifier(random_state = 1)
# Fit the model
dtree.fit(X = xtrain, y = ytrain)
DecisionTreeClassifier(random_state=1)
from sklearn import tree
tree.plot_tree(dtree)
[Text(172.40380434782608, 209.07692307692307, 'X[0] <= 50.5\ngini = 0.453\nsamples = 346\nvalue = [226, 120]'),
Text(52.767391304347825, 192.35076923076923, 'X[0] <= 24.5\ngini = 0.339\nsamples = 217\nvalue = [170, 47]'),
Text(14.556521739130435, 175.62461538461537, 'X[0] <= 19.5\ngini = 0.038\nsamples = 51\nvalue = [50, 1]'),
Text(7.278260869565218, 158.89846153846153, 'gini = 0.0\nsamples = 38\nvalue = [38, 0]'),
Text(21.834782608695654, 158.89846153846153, 'X[0] <= 20.5\ngini = 0.142\nsamples = 13\nvalue = [12, 1]'),
Text(14.556521739130435, 142.1723076923077, 'gini = 0.278\nsamples = 6\nvalue = [5, 1]'),
Text(29.11304347826087, 142.1723076923077, 'gini = 0.0\nsamples = 7\nvalue = [7, 0]'),
Text(90.97826086956522, 175.62461538461537, 'X[0] <= 31.5\ngini = 0.401\nsamples = 166\nvalue = [120, 46]'),
Text(58.22608695652174, 158.89846153846153, 'X[0] <= 28.5\ngini = 0.305\nsamples = 32\nvalue = [26, 6]'),
Text(43.66956521739131, 142.1723076923077, 'X[0] <= 25.5\ngini = 0.415\nsamples = 17\nvalue = [12, 5]'),
Text(36.391304347826086, 125.44615384615385, 'gini = 0.5\nsamples = 2\nvalue = [1, 1]'),
Text(50.947826086956525, 125.44615384615385, 'X[0] <= 27.5\ngini = 0.391\nsamples = 15\nvalue = [11, 4]'),
Text(43.66956521739131, 108.72, 'X[0] <= 26.5\ngini = 0.375\nsamples = 8\nvalue = [6, 2]'),
Text(36.391304347826086, 91.99384615384615, 'gini = 0.375\nsamples = 4\nvalue = [3, 1]'),
Text(50.947826086956525, 91.99384615384615, 'gini = 0.375\nsamples = 4\nvalue = [3, 1]'),
Text(58.22608695652174, 108.72, 'gini = 0.408\nsamples = 7\nvalue = [5, 2]'),
Text(72.78260869565217, 142.1723076923077, 'X[0] <= 30.5\ngini = 0.124\nsamples = 15\nvalue = [14, 1]'),
Text(65.50434782608696, 125.44615384615385, 'gini = 0.0\nsamples = 8\nvalue = [8, 0]'),
Text(80.06086956521739, 125.44615384615385, 'gini = 0.245\nsamples = 7\nvalue = [6, 1]'),
Text(123.7304347826087, 158.89846153846153, 'X[0] <= 32.5\ngini = 0.419\nsamples = 134\nvalue = [94, 40]'),
Text(116.45217391304348, 142.1723076923077, 'gini = 0.494\nsamples = 9\nvalue = [5, 4]'),
Text(131.0086956521739, 142.1723076923077, 'X[0] <= 38.5\ngini = 0.41\nsamples = 125\nvalue = [89, 36]'),
Text(94.61739130434783, 125.44615384615385, 'X[0] <= 35.5\ngini = 0.342\nsamples = 32\nvalue = [25, 7]'),
Text(80.06086956521739, 108.72, 'X[0] <= 34.5\ngini = 0.298\nsamples = 11\nvalue = [9, 2]'),
Text(72.78260869565217, 91.99384615384615, 'X[0] <= 33.5\ngini = 0.32\nsamples = 10\nvalue = [8, 2]'),
Text(65.50434782608696, 75.2676923076923, 'gini = 0.278\nsamples = 6\nvalue = [5, 1]'),
Text(80.06086956521739, 75.2676923076923, 'gini = 0.375\nsamples = 4\nvalue = [3, 1]'),
Text(87.33913043478262, 91.99384615384615, 'gini = 0.0\nsamples = 1\nvalue = [1, 0]'),
Text(109.17391304347827, 108.72, 'X[0] <= 36.5\ngini = 0.363\nsamples = 21\nvalue = [16, 5]'),
Text(101.89565217391305, 91.99384615384615, 'gini = 0.444\nsamples = 3\nvalue = [2, 1]'),
Text(116.45217391304348, 91.99384615384615, 'X[0] <= 37.5\ngini = 0.346\nsamples = 18\nvalue = [14, 4]'),
Text(109.17391304347827, 75.2676923076923, 'gini = 0.278\nsamples = 6\nvalue = [5, 1]'),
Text(123.7304347826087, 75.2676923076923, 'gini = 0.375\nsamples = 12\nvalue = [9, 3]'),
Text(167.4, 125.44615384615385, 'X[0] <= 43.5\ngini = 0.429\nsamples = 93\nvalue = [64, 29]'),
Text(152.84347826086957, 108.72, 'X[0] <= 42.5\ngini = 0.473\nsamples = 39\nvalue = [24, 15]'),
Text(145.56521739130434, 91.99384615384615, 'X[0] <= 39.5\ngini = 0.451\nsamples = 32\nvalue = [21, 11]'),
Text(138.28695652173914, 75.2676923076923, 'gini = 0.49\nsamples = 7\nvalue = [4, 3]'),
Text(152.84347826086957, 75.2676923076923, 'X[0] <= 40.5\ngini = 0.435\nsamples = 25\nvalue = [17, 8]'),
Text(145.56521739130434, 58.541538461538465, 'gini = 0.42\nsamples = 10\nvalue = [7, 3]'),
Text(160.12173913043478, 58.541538461538465, 'X[0] <= 41.5\ngini = 0.444\nsamples = 15\nvalue = [10, 5]'),
Text(152.84347826086957, 41.81538461538463, 'gini = 0.444\nsamples = 6\nvalue = [4, 2]'),
Text(167.4, 41.81538461538463, 'gini = 0.444\nsamples = 9\nvalue = [6, 3]'),
Text(160.12173913043478, 91.99384615384615, 'gini = 0.49\nsamples = 7\nvalue = [3, 4]'),
Text(181.95652173913044, 108.72, 'X[0] <= 44.5\ngini = 0.384\nsamples = 54\nvalue = [40, 14]'),
Text(174.67826086956524, 91.99384615384615, 'gini = 0.0\nsamples = 7\nvalue = [7, 0]'),
Text(189.23478260869567, 91.99384615384615, 'X[0] <= 45.5\ngini = 0.418\nsamples = 47\nvalue = [33, 14]'),
Text(181.95652173913044, 75.2676923076923, 'gini = 0.5\nsamples = 8\nvalue = [4, 4]'),
Text(196.51304347826087, 75.2676923076923, 'X[0] <= 49.5\ngini = 0.381\nsamples = 39\nvalue = [29, 10]'),
Text(189.23478260869567, 58.541538461538465, 'X[0] <= 46.5\ngini = 0.35\nsamples = 31\nvalue = [24, 7]'),
Text(181.95652173913044, 41.81538461538463, 'gini = 0.444\nsamples = 6\nvalue = [4, 2]'),
Text(196.51304347826087, 41.81538461538463, 'X[0] <= 47.5\ngini = 0.32\nsamples = 25\nvalue = [20, 5]'),
Text(189.23478260869567, 25.089230769230767, 'gini = 0.0\nsamples = 1\nvalue = [1, 0]'),
Text(203.7913043478261, 25.089230769230767, 'X[0] <= 48.5\ngini = 0.33\nsamples = 24\nvalue = [19, 5]'),
Text(196.51304347826087, 8.363076923076932, 'gini = 0.355\nsamples = 13\nvalue = [10, 3]'),
Text(211.0695652173913, 8.363076923076932, 'gini = 0.298\nsamples = 11\nvalue = [9, 2]'),
Text(203.7913043478261, 58.541538461538465, 'gini = 0.469\nsamples = 8\nvalue = [5, 3]'),
Text(292.04021739130434, 192.35076923076923, 'X[0] <= 59.5\ngini = 0.491\nsamples = 129\nvalue = [56, 73]'),
Text(271.1152173913043, 175.62461538461537, 'X[0] <= 57.5\ngini = 0.472\nsamples = 81\nvalue = [31, 50]'),
Text(251.1, 158.89846153846153, 'X[0] <= 55.5\ngini = 0.494\nsamples = 56\nvalue = [25, 31]'),
Text(232.90434782608696, 142.1723076923077, 'X[0] <= 53.5\ngini = 0.476\nsamples = 46\nvalue = [18, 28]'),
Text(218.34782608695653, 125.44615384615385, 'X[0] <= 52.5\ngini = 0.497\nsamples = 26\nvalue = [12, 14]'),
Text(211.0695652173913, 108.72, 'X[0] <= 51.5\ngini = 0.48\nsamples = 15\nvalue = [6, 9]'),
Text(203.7913043478261, 91.99384615384615, 'gini = 0.49\nsamples = 7\nvalue = [3, 4]'),
Text(218.34782608695653, 91.99384615384615, 'gini = 0.469\nsamples = 8\nvalue = [3, 5]'),
Text(225.62608695652173, 108.72, 'gini = 0.496\nsamples = 11\nvalue = [6, 5]'),
Text(247.4608695652174, 125.44615384615385, 'X[0] <= 54.5\ngini = 0.42\nsamples = 20\nvalue = [6, 14]'),
Text(240.1826086956522, 108.72, 'gini = 0.278\nsamples = 6\nvalue = [1, 5]'),
Text(254.73913043478262, 108.72, 'gini = 0.459\nsamples = 14\nvalue = [5, 9]'),
Text(269.295652173913, 142.1723076923077, 'X[0] <= 56.5\ngini = 0.42\nsamples = 10\nvalue = [7, 3]'),
Text(262.0173913043478, 125.44615384615385, 'gini = 0.48\nsamples = 5\nvalue = [3, 2]'),
Text(276.5739130434783, 125.44615384615385, 'gini = 0.32\nsamples = 5\nvalue = [4, 1]'),
Text(291.1304347826087, 158.89846153846153, 'X[0] <= 58.5\ngini = 0.365\nsamples = 25\nvalue = [6, 19]'),
Text(283.8521739130435, 142.1723076923077, 'gini = 0.375\nsamples = 12\nvalue = [3, 9]'),
Text(298.40869565217395, 142.1723076923077, 'gini = 0.355\nsamples = 13\nvalue = [3, 10]'),
Text(312.96521739130435, 175.62461538461537, 'X[0] <= 60.5\ngini = 0.499\nsamples = 48\nvalue = [25, 23]'),
Text(305.68695652173915, 158.89846153846153, 'gini = 0.375\nsamples = 12\nvalue = [9, 3]'),
Text(320.24347826086955, 158.89846153846153, 'X[0] <= 63.5\ngini = 0.494\nsamples = 36\nvalue = [16, 20]'),
Text(312.96521739130435, 142.1723076923077, 'X[0] <= 62.5\ngini = 0.473\nsamples = 26\nvalue = [10, 16]'),
Text(305.68695652173915, 125.44615384615385, 'X[0] <= 61.5\ngini = 0.49\nsamples = 21\nvalue = [9, 12]'),
Text(298.40869565217395, 108.72, 'gini = 0.444\nsamples = 12\nvalue = [4, 8]'),
Text(312.96521739130435, 108.72, 'gini = 0.494\nsamples = 9\nvalue = [5, 4]'),
Text(320.24347826086955, 125.44615384615385, 'gini = 0.32\nsamples = 5\nvalue = [1, 4]'),
Text(327.5217391304348, 142.1723076923077, 'gini = 0.48\nsamples = 10\nvalue = [6, 4]')]
Whoa.
Decision trees can overfit data a lot if they aren’t constrained.
dtree.score(X = xtrain, y = ytrain)
dtree.score(X = xtest, y = ytest)
# Seems like we're overfitting
0.646551724137931
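To see the overfitting more directly, here’s a quick sketch (reusing the split from above) that refits the tree at several depths and compares train vs. test accuracy:
# Compare train/test accuracy as we let the tree grow deeper
for depth in [1, 2, 3, 5, 10, None]:  # None = no depth limit
    d = DecisionTreeClassifier(max_depth = depth, random_state = 1)
    d.fit(X = xtrain, y = ytrain)
    print(depth,
          round(d.score(X = xtrain, y = ytrain), 3),  # training accuracy
          round(d.score(X = xtest, y = ytest), 3))    # test accuracy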
Let’s try this again…
dtree = DecisionTreeClassifier(
# how many layers our decision tree should have (toggle between 1 and 2 and see how this impacts results)
max_depth = 1,
random_state = 1
)
# Fit the model
dtree.fit(X = xtrain, y = ytrain)
DecisionTreeClassifier(max_depth=1, random_state=1)
tree.plot_tree(dtree,
feature_names = ['Age'],
class_names = ['No CHD', 'CHD'],
filled = True
)
[Text(167.4, 163.07999999999998, 'Age <= 50.5\ngini = 0.453\nsamples = 346\nvalue = [226, 120]\nclass = No CHD'),
Text(83.7, 54.360000000000014, 'gini = 0.339\nsamples = 217\nvalue = [170, 47]\nclass = No CHD'),
Text(251.10000000000002, 54.360000000000014, 'gini = 0.491\nsamples = 129\nvalue = [56, 73]\nclass = CHD')]
What’s going on here?
Age <= 50.5: This is the “rule” used to split the data at this node (samples satisfying the rule go to the left branch; the rest go to the right)
gini = 0.453: This refers to the “Gini impurity” of the node. Gini impurity is the loss function used to fit this tree (optimal = 0) (more on this here)
samples = 346: This is the number of training samples reaching this node
value = [226, 120]: This is the number of those samples in each class (value[0] = “No CHD”, value[1] = “CHD”)
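For example, the root node’s gini = 0.453 comes directly from its value = [226, 120] counts; a quick check:
# Gini impurity of the root node: 1 - p_NoCHD^2 - p_CHD^2
p0, p1 = 226 / 346, 120 / 346
1 - p0**2 - p1**2  # ≈ 0.453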
NOTE: With a depth of 1, the two leaves at the bottom classify the training data as follows:
170 people were correctly classified as “No CHD” by this rule (true negatives)
47 people were incorrectly classified as “No CHD” by this rule (false negatives)
56 people were incorrectly classified as “CHD” by this rule (false positives)
73 people were correctly classified as “CHD” by this rule (true positives)
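Those four counts form a confusion matrix on the training data; if you want to verify them, something like this should line up:
from sklearn.metrics import confusion_matrix

# Rows = true class (0 = No CHD, 1 = CHD), columns = predicted class
# Expect [[170, 56], [47, 73]] given the leaf values above
confusion_matrix(y_true = ytrain, y_pred = dtree.predict(X = xtrain))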
Like other classifiers, the decision tree classifier lets us predict values and has functions for assessing prediction accuracy.
# Accuracy on the data
dtree.score(X = xtrain, y = ytrain)
dtree.score(X = xtest, y = ytest)
0.6810344827586207
ypreds = dtree.predict(X = xtest)
ypreds
# Test `score` above
sum(ypreds == ytest) / len(ypreds)
0.6810344827586207
# The "soft classification" probabilities are just the fraction of training samples for the "true" label
# in the leaf where this test item ended up
ypreds_soft = dtree.predict_proba(X = xtest)
ypreds_soft
array([[0.78341014, 0.21658986],
[0.78341014, 0.21658986],
[0.78341014, 0.21658986],
[0.78341014, 0.21658986],
[0.43410853, 0.56589147],
[0.78341014, 0.21658986],
[0.78341014, 0.21658986],
[0.43410853, 0.56589147],
[0.78341014, 0.21658986],
[0.43410853, 0.56589147],
[0.43410853, 0.56589147],
[0.78341014, 0.21658986],
[0.78341014, 0.21658986],
[0.43410853, 0.56589147],
[0.78341014, 0.21658986],
[0.78341014, 0.21658986],
[0.78341014, 0.21658986],
[0.43410853, 0.56589147],
[0.78341014, 0.21658986],
[0.78341014, 0.21658986],
[0.78341014, 0.21658986],
[0.78341014, 0.21658986],
[0.78341014, 0.21658986],
[0.78341014, 0.21658986],
[0.78341014, 0.21658986],
[0.43410853, 0.56589147],
[0.78341014, 0.21658986],
[0.78341014, 0.21658986],
[0.43410853, 0.56589147],
[0.43410853, 0.56589147],
[0.43410853, 0.56589147],
[0.78341014, 0.21658986],
[0.78341014, 0.21658986],
[0.78341014, 0.21658986],
[0.78341014, 0.21658986],
[0.43410853, 0.56589147],
[0.78341014, 0.21658986],
[0.43410853, 0.56589147],
[0.78341014, 0.21658986],
[0.43410853, 0.56589147],
[0.78341014, 0.21658986],
[0.78341014, 0.21658986],
[0.78341014, 0.21658986],
[0.43410853, 0.56589147],
[0.78341014, 0.21658986],
[0.78341014, 0.21658986],
[0.43410853, 0.56589147],
[0.78341014, 0.21658986],
[0.43410853, 0.56589147],
[0.78341014, 0.21658986],
[0.78341014, 0.21658986],
[0.43410853, 0.56589147],
[0.43410853, 0.56589147],
[0.43410853, 0.56589147],
[0.43410853, 0.56589147],
[0.78341014, 0.21658986],
[0.78341014, 0.21658986],
[0.78341014, 0.21658986],
[0.43410853, 0.56589147],
[0.43410853, 0.56589147],
[0.78341014, 0.21658986],
[0.78341014, 0.21658986],
[0.43410853, 0.56589147],
[0.78341014, 0.21658986],
[0.43410853, 0.56589147],
[0.43410853, 0.56589147],
[0.78341014, 0.21658986],
[0.78341014, 0.21658986],
[0.78341014, 0.21658986],
[0.78341014, 0.21658986],
[0.78341014, 0.21658986],
[0.78341014, 0.21658986],
[0.78341014, 0.21658986],
[0.78341014, 0.21658986],
[0.78341014, 0.21658986],
[0.78341014, 0.21658986],
[0.78341014, 0.21658986],
[0.43410853, 0.56589147],
[0.78341014, 0.21658986],
[0.78341014, 0.21658986],
[0.43410853, 0.56589147],
[0.43410853, 0.56589147],
[0.78341014, 0.21658986],
[0.43410853, 0.56589147],
[0.43410853, 0.56589147],
[0.78341014, 0.21658986],
[0.43410853, 0.56589147],
[0.43410853, 0.56589147],
[0.78341014, 0.21658986],
[0.43410853, 0.56589147],
[0.78341014, 0.21658986],
[0.78341014, 0.21658986],
[0.43410853, 0.56589147],
[0.78341014, 0.21658986],
[0.43410853, 0.56589147],
[0.78341014, 0.21658986],
[0.78341014, 0.21658986],
[0.78341014, 0.21658986],
[0.78341014, 0.21658986],
[0.78341014, 0.21658986],
[0.78341014, 0.21658986],
[0.43410853, 0.56589147],
[0.43410853, 0.56589147],
[0.78341014, 0.21658986],
[0.43410853, 0.56589147],
[0.43410853, 0.56589147],
[0.78341014, 0.21658986],
[0.78341014, 0.21658986],
[0.43410853, 0.56589147],
[0.43410853, 0.56589147],
[0.78341014, 0.21658986],
[0.43410853, 0.56589147],
[0.78341014, 0.21658986],
[0.78341014, 0.21658986],
[0.78341014, 0.21658986],
[0.43410853, 0.56589147]])
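Notice there are only two distinct rows: a depth-1 tree has two leaves, so every test item receives one of two probability pairs, and they match the leaf counts in the tree plot above:
# Left leaf (Age <= 50.5): value = [170, 47] -> p(CHD) = 47 / 217
# Right leaf (Age > 50.5): value = [56, 73] -> p(CHD) = 73 / 129
47 / 217, 73 / 129  # ≈ 0.2166 and 0.5659, matching predict_proba above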
We can use the predictions as the basis for better understanding what the tree is doing:
# This reveals the cutoff(s) chosen by our decision tree!
train_preds = dtree.predict(X = xtrain)
g = sns.scatterplot(x = xtrain[:, 0], y = ytrain, hue = ytrain == train_preds, alpha = .25)
# These are the decision boundaries in the tree. You can see how they segment our data into more accurate predictions
g.axvline(50.5)
g.axvline(59.5)
g.axvline(24.5)
<matplotlib.lines.Line2D at 0x7fc519edbb80>
### YOUR CODE HERE
# Make a similar graph to the above with the test data
We can also draw on the same resources that we talked about for assessing our \(k\)-nearest neighbors classifier:
Accuracy / F1 score
ROC curves
from sklearn.metrics import accuracy_score, f1_score
# Test accuracy
accuracy_score(y_true = ytest, y_pred = dtree.predict(X = xtest))
# Test F1 score
f1_score(y_true = ytest,
y_pred = dtree.predict(X = xtest),
labels = [0, 1],
pos_label = 1
)
0.5542168674698795
from sklearn.metrics import roc_curve
# ROC curve
fpr, tpr, thresholds = roc_curve(
y_true = ytest,
y_score = dtree.predict_proba(X = xtest)[:, 1],
pos_label = 1
)
sns.lineplot(x = fpr, y = tpr)
plt.axline(xy1 = (0, 0), slope = 1, c = "r")
plt.xlabel("FPR")
plt.ylabel("TPR")
Text(0, 0.5, 'TPR')
Support Vector Machines (SVMs)¶
Support Vector Machines work by trying to find a line or plane (usually in a high-dimensional space) that maximally separates the training labels in that space.
The intuition for this is relatively straightforward but the implementation can get complicated!
In the plot below, the linear function \(h_3(x_1, x_2)\) is the best way to separate our training data because it maximizes the margin on either side of the line.
SVMs try to find the equivalent of \(h_3\) given some training data. This separator can be defined by the closest points in the data to the line; these are called the “support vectors”. Finding the best separator usually requires mapping the training data into a high-dimensional space where it can be effectively separated.
(Source)
SVMs in python¶
The documentation for SVMs in scikit-learn is here.
from sklearn.svm import SVC
svm = SVC()
svm.fit(xtrain, ytrain)
SVC()
In the case of SVMs, there are class attributes that help you recover the separator that was fit.
We won’t get into these here, but they’re good to know about if you’re interested in learning more!
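For instance, here’s a minimal sketch of recovering the separator when the kernel is linear (the svm we fit above uses the default 'rbf' kernel, so its boundary isn’t a single line and coef_ isn’t defined for it):
# Refit with a linear kernel so that coef_ and intercept_ are available
linear_svm = SVC(kernel = 'linear')
linear_svm.fit(X = xtrain, y = ytrain)

# The separator is w * x + b = 0; hard predictions follow the sign of w * x + b
w, b = linear_svm.coef_[0][0], linear_svm.intercept_[0]
w, b, -b / w  # -b / w is the age at which the predicted label flips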
# svm.intercept_
# svm.coef_ # only for 'linear' kernel
# svm.support_vectors_
# For example, we can view the items in the training set that formed the support vector
sns.scatterplot(x = xtrain[:, 0], y = ytrain, alpha = .25)
plt.title("Training data")
plt.show()
sns.scatterplot(x = xtrain[svm.support_][:, 0], y = ytrain[svm.support_], alpha = .25)
plt.title("Support vectors")
plt.show()
The SVM class has a score function that returns classification accuracy on a given dataset, plus prediction functions.
# Percent of correct classifications
svm.score(X = xtrain, y = ytrain)
svm.score(X = xtest, y = ytest)
0.646551724137931
ypreds = svm.predict(X = xtest)
ypreds
array([0, 0, 0, 0, 1, 0, 0, 1, 0, 1, 1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0,
0, 0, 0, 1, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1,
0, 0, 1, 0, 1, 0, 0, 1, 1, 1, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 1, 1,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 1, 0, 1, 1, 0, 1, 1,
0, 1, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 0, 0, 1, 0,
0, 1, 0, 0, 0, 1])
However, soft prediction requires configuring the initial model to do soft classification (by default, scikit-learn’s SVC only does hard classification).
svm_soft = SVC(probability = True) # indicate that you want the SVM to do soft classification
svm_soft.fit(X = xtrain, y = ytrain)
ypreds_soft = svm_soft.predict_proba(X = xtest)
ypreds_soft
array([[0.76488586, 0.23511414],
[0.7660128 , 0.2339872 ],
[0.71339712, 0.28660288],
[0.7655425 , 0.2344575 ],
[0.39508517, 0.60491483],
[0.76490086, 0.23509914],
[0.77019528, 0.22980472],
[0.48781311, 0.51218689],
[0.7660128 , 0.2339872 ],
[0.40023559, 0.59976441],
[0.43233607, 0.56766393],
[0.76532763, 0.23467237],
[0.76490086, 0.23509914],
[0.4023721 , 0.5976279 ],
[0.76335536, 0.23664464],
[0.71339712, 0.28660288],
[0.76582909, 0.23417091],
[0.5219035 , 0.4780965 ],
[0.76335536, 0.23664464],
[0.77008692, 0.22991308],
[0.76491028, 0.23508972],
[0.76491028, 0.23508972],
[0.74676154, 0.25323846],
[0.7660128 , 0.2339872 ],
[0.76488586, 0.23511414],
[0.41546636, 0.58453364],
[0.7660128 , 0.2339872 ],
[0.75722289, 0.24277711],
[0.41546636, 0.58453364],
[0.39508517, 0.60491483],
[0.45754721, 0.54245279],
[0.76420384, 0.23579616],
[0.7660128 , 0.2339872 ],
[0.76271432, 0.23728568],
[0.76420384, 0.23579616],
[0.59422346, 0.40577654],
[0.76493222, 0.23506778],
[0.41303513, 0.58696487],
[0.76413905, 0.23586095],
[0.59422346, 0.40577654],
[0.76493222, 0.23506778],
[0.76335536, 0.23664464],
[0.76828996, 0.23171004],
[0.55808055, 0.44191945],
[0.76419446, 0.23580554],
[0.7655425 , 0.2344575 ],
[0.43233607, 0.56766393],
[0.76493222, 0.23506778],
[0.41303513, 0.58696487],
[0.76488586, 0.23511414],
[0.75722289, 0.24277711],
[0.41546636, 0.58453364],
[0.48781311, 0.51218689],
[0.4023721 , 0.5976279 ],
[0.65891565, 0.34108435],
[0.76758491, 0.23241509],
[0.76582909, 0.23417091],
[0.71339712, 0.28660288],
[0.41303513, 0.58696487],
[0.45754721, 0.54245279],
[0.77019528, 0.22980472],
[0.76582909, 0.23417091],
[0.62831079, 0.37168921],
[0.76532763, 0.23467237],
[0.55808055, 0.44191945],
[0.40023559, 0.59976441],
[0.75722289, 0.24277711],
[0.75722289, 0.24277711],
[0.76758491, 0.23241509],
[0.76420384, 0.23579616],
[0.76493222, 0.23506778],
[0.76493222, 0.23506778],
[0.77019528, 0.22980472],
[0.76595219, 0.23404781],
[0.76468915, 0.23531085],
[0.76444996, 0.23555004],
[0.764262 , 0.235738 ],
[0.40023559, 0.59976441],
[0.76490086, 0.23509914],
[0.7655425 , 0.2344575 ],
[0.48781311, 0.51218689],
[0.55808055, 0.44191945],
[0.76413905, 0.23586095],
[0.43233607, 0.56766393],
[0.5219035 , 0.4780965 ],
[0.76468915, 0.23531085],
[0.39425539, 0.60574461],
[0.48781311, 0.51218689],
[0.74676154, 0.25323846],
[0.45754721, 0.54245279],
[0.76271432, 0.23728568],
[0.76572784, 0.23427216],
[0.39425539, 0.60574461],
[0.77019528, 0.22980472],
[0.39508517, 0.60491483],
[0.76918779, 0.23081221],
[0.7660128 , 0.2339872 ],
[0.76918779, 0.23081221],
[0.7660128 , 0.2339872 ],
[0.76488586, 0.23511414],
[0.76918779, 0.23081221],
[0.62831079, 0.37168921],
[0.43233607, 0.56766393],
[0.76420384, 0.23579616],
[0.43233607, 0.56766393],
[0.4023721 , 0.5976279 ],
[0.76420384, 0.23579616],
[0.76420384, 0.23579616],
[0.55808055, 0.44191945],
[0.59422346, 0.40577654],
[0.7655425 , 0.2344575 ],
[0.39425539, 0.60574461],
[0.7655425 , 0.2344575 ],
[0.76918779, 0.23081221],
[0.75722289, 0.24277711],
[0.39425539, 0.60574461]])
Classifier Wrap-Up¶
This is just a sample of what’s out there!
There are a number of other common classifiers you should take a look at if you’re interested:
Naive Bayes (here)
Neural networks (here)
Random forests (here) (related to decision trees)
Gradient boosted trees (here)
…
The main goal of this lecture is to show you some of the creative ways that people solve classification problems and how the scikit-learn library supports these solutions.
This should empower you to go off and try some of these other ones on your own!
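Almost all of them plug into the same fit / predict / score pattern we used above. As a minimal sketch (hyperparameters chosen just for illustration), here’s a random forest reusing our train/test split:
from sklearn.ensemble import RandomForestClassifier

# Same workflow as the classifiers above: initialize, fit, then score on held-out data
forest = RandomForestClassifier(n_estimators = 100, max_depth = 2, random_state = 1)
forest.fit(X = xtrain, y = ytrain)
forest.score(X = xtest, y = ytest)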