Lecture 20 (5/11/2022)

Announcements

Last time we covered:

  • ROC curves

Today’s agenda:

  • Common classification models

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

from sklearn.model_selection import train_test_split

Common Classification Models

  • \(k\)-nearest neighbors

  • Logistic regression

  • Decision trees

  • Support Vector Machines (SVMs)

  • Other: naive Bayes, neural networks, discriminant analysis

Data: Predicting Heart Disease

From source:

A retrospective sample of males in a heart-disease high-risk region of the Western Cape, South Africa. There are roughly two controls per case of CHD. Many of the CHD positive men have undergone blood pressure reduction treatment and other programs to reduce their risk factors after their CHD event. In some cases the measurements were made after these treatments. These data are taken from a larger dataset, described in Rousseauw et al, 1983, South African Medical Journal.

  • sbp: systolic blood pressure

  • tobacco: cumulative tobacco (kg)

  • ldl: low density lipoprotein cholesterol

  • adiposity

  • famhist: family history of heart disease (Present, Absent)

  • typea: type-A behavior

  • obesity

  • alcohol: current alcohol consumption

  • age: age at onset

  • chd: response, coronary heart disease

data = pd.read_csv('https://web.stanford.edu/~hastie/ElemStatLearn/datasets/SAheart.data')

data
row.names sbp tobacco ldl adiposity famhist typea obesity alcohol age chd
0 1 160 12.00 5.73 23.11 Present 49 25.30 97.20 52 1
1 2 144 0.01 4.41 28.61 Absent 55 28.87 2.06 63 1
2 3 118 0.08 3.48 32.28 Present 52 29.14 3.81 46 0
3 4 170 7.50 6.41 38.03 Present 51 31.99 24.26 58 1
4 5 134 13.60 3.50 27.78 Present 60 25.99 57.34 49 1
... ... ... ... ... ... ... ... ... ... ... ...
457 459 214 0.40 5.98 31.72 Absent 64 28.45 0.00 58 0
458 460 182 4.20 4.41 32.10 Absent 52 28.61 18.72 52 1
459 461 108 3.00 1.59 15.23 Absent 40 20.09 26.64 55 0
460 462 118 5.40 11.61 30.79 Absent 64 27.35 23.97 40 0
461 463 132 0.00 4.82 33.41 Present 62 14.70 0.00 46 1

462 rows × 11 columns

Setting up our classifiers:

Let’s stick to just a single feature (age at onset) and see how different methods use this feature to predict the outcome label (CHD).

x_vals = np.array(data['age']).reshape(len(data), 1)
y_vals = np.array(data['chd'])

xtrain, xtest, ytrain, ytest = train_test_split(x_vals, y_vals, random_state = 1)

Now, let’s get started!

Logistic Regression

How it works:

In linear regression, the relationship between our predictor \(x\) and our response variable \(y\) was:

\(y = \beta_0 + \beta_1 x\)

If our \(y\) values are all 0 or 1 (and assumed to be Bernoulli distributed with probability \(p\)), this approach doesn’t work very well:

  1. It predicts values <0 and >1 for some inputs \(x\)

  2. It doesn’t accommodate the fact that getting closer and closer to 1 gets harder and harder: a one-unit change in \(x\) shouldn’t always produce the same change in \(p(y = 1)\).

So what do we do about this?

Instead, we postulate the following relationship between \(x\) and \(y\):

\(\log \dfrac{p(y=1)}{p(y=0)} = \beta_0 + \beta_1 x\).

Every unit increase in \(x\) leads to a \(\beta_1\) increase in the log odds of \(y = 1\) (or, equivalently, multiplies the odds of \(y = 1\) by \(e^{\beta_1}\)).

This logit transform of our response variable \(y\) solves both of the problems with linear regression above.
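
Why does this fix the range problem? Solving the equation above for \(p(y=1)\) (using \(p(y=0) = 1 - p(y=1)\)) gives

\(p(y=1) = \dfrac{e^{\beta_0 + \beta_1 x}}{1 + e^{\beta_0 + \beta_1 x}}\),

which always lies between 0 and 1 no matter what \(x\) is. This is the same logistic function we’ll use for classification below.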

However, the goal today isn’t to get into the nitty-gritty of logistic regression. Instead, let’s talk about how we use it as a classifier!

Classification

When we’ve fit a logistic regression to our data, we can output a probability \(p(y)\) for any given \(x\):

\(p(y) = \dfrac{e^{h(x)}}{1+ e^{h(x)}}\)

for \(h(x) = \beta_0 + \beta_1x\).

\(\dfrac{e^{h(x)}}{1+ e^{h(x)}}\) is the logistic function that maps from our \(x\) variable to \(p(y)\).

We can use this function as the basis for classification: any observation whose \(p(y)\) exceeds a threshold \(T\) gets the label estimate \(\hat{y} = 1\), and everything else gets \(\hat{y} = 0\).
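
Here’s a minimal sketch of that two-step recipe in numpy (the variable names below are made up for illustration, not part of the lecture code):

import numpy as np

def logistic(h):
    # Squash any real-valued h(x) into the interval (0, 1)
    return np.exp(h) / (1 + np.exp(h))

h_vals = np.array([-4, -1, 0, 1, 4])   # hypothetical values of beta_0 + beta_1 * x
p_vals = logistic(h_vals)              # approx. [0.018, 0.269, 0.5, 0.731, 0.982]
y_hat = (p_vals > 0.5).astype(int)     # hard labels using a threshold of T = 0.5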

Fitting parameters

Even though logistic regression produces regression coefficients (intercept + slopes) similar to linear regression, these parameters are not estimated using the Ordinary Least Squares process we saw with linear regression. Instead, they are most often estimated using a more complicated process called Maximum Likelihood Estimation.
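
To give a flavor of what maximum likelihood means here, it searches for the \(\beta\) values that make the observed 0/1 labels most probable, which is the same as minimizing the negative log-likelihood sketched below. This is a sketch of the objective only, not scikit-learn’s code (its optimizer does the real work and adds regularization by default):

import numpy as np

def neg_log_likelihood(beta0, beta1, x, y):
    # p(y=1) under the logistic model, for each observation
    p = np.exp(beta0 + beta1 * x) / (1 + np.exp(beta0 + beta1 * x))
    # Bernoulli log-likelihood of the observed labels, summed and negated
    return -np.sum(y * np.log(p) + (1 - y) * np.log(1 - p))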

Logistic regression in python

You can read the scikit-learn documentation here.

# Import the LogisticRegression class
from sklearn.linear_model import LogisticRegression

# Initialize the logistic regression
log_reg = LogisticRegression(random_state = 1)

# Fit the model
log_reg.fit(X = xtrain, y = ytrain)
LogisticRegression(random_state=1)

What attributes do we get from this model fit?

log_reg.classes_

log_reg.intercept_ # What does this mean?
# np.exp(log_reg.intercept_[0]) / (1 + np.exp(log_reg.intercept_[0]))

log_reg.coef_ # What does this mean?
# np.exp(log_reg.coef_[0][0])
array([[0.06469053]])

What functions does the model class give us?

binary_preds = log_reg.predict(xtest)
binary_preds

soft_preds = log_reg.predict_proba(xtest)
soft_preds
# soft_preds[:, 0] # probability of 0


# Accuracy of hard classification predictions
log_reg.score(X = xtest, y = ytest) 
0.6637931034482759

How did we do?

# Let's show the actual test data
g = sns.scatterplot(x = xtest[:, 0], y = ytest, hue = ytest == binary_preds)

# Now, let's plot our logistic regression curve
sns.lineplot(x = xtest[:, 0], y = soft_preds[:, 1])

# What is the "hard classification" boundary?
sns.lineplot(x = xtest[:, 0], y = binary_preds)
plt.axhline(0.5, linestyle = "--", color = "k") # this is what produces our classification boundary


g.set_xlabel("Age")
g.set_ylabel("CDH probability")
plt.legend(title = "Correct")

plt.show()
../../_images/Lecture_20-pre_18_0.png

Let’s look at where the blue line above comes from.

Our logistic regression is formalized as follows:

For \(h(x) = \beta_0 + \beta_1x\),

\(p(y) = \dfrac{e^{h(x)}}{1+ e^{h(x)}}\)

# Let's implement the above transformation here
ypreds = np.exp(log_reg.intercept_ + log_reg.coef_*xtest) / (1 + np.exp(log_reg.intercept_ + log_reg.coef_*xtest))

# Now we can confirm that this worked
g = sns.lineplot(x = xtest[:, 0], y = ypreds[:, 0])
g.set_ylim(0, 1)
g.set_xlabel("Age")
g.set_ylabel("p(CDH)")
plt.show()

# Finally, let's look at the "linear" relationship underlying logistic regression
h = sns.lineplot(x = xtest[:, 0], y = np.log(ypreds[:, 0]/(1-ypreds[:, 0])))
h.set_xlabel("Age")
h.set_ylabel("Log odds of CDH")
plt.show()
../../_images/Lecture_20-pre_20_0.png ../../_images/Lecture_20-pre_20_1.png

Decision Trees

Decision trees are a form of classification that fits a model by generating successive rules based on the input feature values. These rules are optimized to classify the training data as accurately as possible.

decision_tree

Above, the percentages are the percent of data points in each node and the proportions are the probability of survival (Source).

Take a second to interpret this.

Decision trees have the advantage of being super intuitive (like \(k\)-nearest neighbors, they’re similar to how people often think about classification).

There’s a great article about how they work here and a nice explanation of how the decision boundaries are identified here.
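
To make the split-finding idea concrete, here is a minimal sketch (not scikit-learn’s actual implementation) of how a single cutoff on one feature can be chosen: try candidate thresholds and keep the one that minimizes the weighted Gini impurity of the two resulting groups.

import numpy as np

def gini(labels):
    # Gini impurity: 1 minus the sum of squared class proportions
    _, counts = np.unique(labels, return_counts = True)
    p = counts / counts.sum()
    return 1 - np.sum(p ** 2)

def best_split(x, y):
    # Candidate cutoffs: midpoints between consecutive sorted unique feature values
    vals = np.unique(x)
    cutoffs = (vals[:-1] + vals[1:]) / 2
    # Weighted average impurity of the two groups created by each cutoff
    scores = [(np.sum(x <= c) * gini(y[x <= c]) + np.sum(x > c) * gini(y[x > c])) / len(y)
              for c in cutoffs]
    return cutoffs[np.argmin(scores)]

# e.g. best_split(xtrain[:, 0], ytrain) searches age cutoffs in the same spirit
# as the root split of the tree fit below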

Decision tree classifiers in python

You can read the decision tree classifier documentation here.

# Import the DecisionTreeClassifier class
from sklearn.tree import DecisionTreeClassifier


# Initialize the decision tree classifier
dtree = DecisionTreeClassifier(random_state = 1)

# Fit the model
dtree.fit(X = xtrain, y = ytrain)
DecisionTreeClassifier(random_state=1)
from sklearn import tree

tree.plot_tree(dtree)
[Text(172.40380434782608, 209.07692307692307, 'X[0] <= 50.5\ngini = 0.453\nsamples = 346\nvalue = [226, 120]'),
 Text(52.767391304347825, 192.35076923076923, 'X[0] <= 24.5\ngini = 0.339\nsamples = 217\nvalue = [170, 47]'),
 Text(14.556521739130435, 175.62461538461537, 'X[0] <= 19.5\ngini = 0.038\nsamples = 51\nvalue = [50, 1]'),
 Text(7.278260869565218, 158.89846153846153, 'gini = 0.0\nsamples = 38\nvalue = [38, 0]'),
 Text(21.834782608695654, 158.89846153846153, 'X[0] <= 20.5\ngini = 0.142\nsamples = 13\nvalue = [12, 1]'),
 Text(14.556521739130435, 142.1723076923077, 'gini = 0.278\nsamples = 6\nvalue = [5, 1]'),
 Text(29.11304347826087, 142.1723076923077, 'gini = 0.0\nsamples = 7\nvalue = [7, 0]'),
 Text(90.97826086956522, 175.62461538461537, 'X[0] <= 31.5\ngini = 0.401\nsamples = 166\nvalue = [120, 46]'),
 Text(58.22608695652174, 158.89846153846153, 'X[0] <= 28.5\ngini = 0.305\nsamples = 32\nvalue = [26, 6]'),
 Text(43.66956521739131, 142.1723076923077, 'X[0] <= 25.5\ngini = 0.415\nsamples = 17\nvalue = [12, 5]'),
 Text(36.391304347826086, 125.44615384615385, 'gini = 0.5\nsamples = 2\nvalue = [1, 1]'),
 Text(50.947826086956525, 125.44615384615385, 'X[0] <= 27.5\ngini = 0.391\nsamples = 15\nvalue = [11, 4]'),
 Text(43.66956521739131, 108.72, 'X[0] <= 26.5\ngini = 0.375\nsamples = 8\nvalue = [6, 2]'),
 Text(36.391304347826086, 91.99384615384615, 'gini = 0.375\nsamples = 4\nvalue = [3, 1]'),
 Text(50.947826086956525, 91.99384615384615, 'gini = 0.375\nsamples = 4\nvalue = [3, 1]'),
 Text(58.22608695652174, 108.72, 'gini = 0.408\nsamples = 7\nvalue = [5, 2]'),
 Text(72.78260869565217, 142.1723076923077, 'X[0] <= 30.5\ngini = 0.124\nsamples = 15\nvalue = [14, 1]'),
 Text(65.50434782608696, 125.44615384615385, 'gini = 0.0\nsamples = 8\nvalue = [8, 0]'),
 Text(80.06086956521739, 125.44615384615385, 'gini = 0.245\nsamples = 7\nvalue = [6, 1]'),
 Text(123.7304347826087, 158.89846153846153, 'X[0] <= 32.5\ngini = 0.419\nsamples = 134\nvalue = [94, 40]'),
 Text(116.45217391304348, 142.1723076923077, 'gini = 0.494\nsamples = 9\nvalue = [5, 4]'),
 Text(131.0086956521739, 142.1723076923077, 'X[0] <= 38.5\ngini = 0.41\nsamples = 125\nvalue = [89, 36]'),
 Text(94.61739130434783, 125.44615384615385, 'X[0] <= 35.5\ngini = 0.342\nsamples = 32\nvalue = [25, 7]'),
 Text(80.06086956521739, 108.72, 'X[0] <= 34.5\ngini = 0.298\nsamples = 11\nvalue = [9, 2]'),
 Text(72.78260869565217, 91.99384615384615, 'X[0] <= 33.5\ngini = 0.32\nsamples = 10\nvalue = [8, 2]'),
 Text(65.50434782608696, 75.2676923076923, 'gini = 0.278\nsamples = 6\nvalue = [5, 1]'),
 Text(80.06086956521739, 75.2676923076923, 'gini = 0.375\nsamples = 4\nvalue = [3, 1]'),
 Text(87.33913043478262, 91.99384615384615, 'gini = 0.0\nsamples = 1\nvalue = [1, 0]'),
 Text(109.17391304347827, 108.72, 'X[0] <= 36.5\ngini = 0.363\nsamples = 21\nvalue = [16, 5]'),
 Text(101.89565217391305, 91.99384615384615, 'gini = 0.444\nsamples = 3\nvalue = [2, 1]'),
 Text(116.45217391304348, 91.99384615384615, 'X[0] <= 37.5\ngini = 0.346\nsamples = 18\nvalue = [14, 4]'),
 Text(109.17391304347827, 75.2676923076923, 'gini = 0.278\nsamples = 6\nvalue = [5, 1]'),
 Text(123.7304347826087, 75.2676923076923, 'gini = 0.375\nsamples = 12\nvalue = [9, 3]'),
 Text(167.4, 125.44615384615385, 'X[0] <= 43.5\ngini = 0.429\nsamples = 93\nvalue = [64, 29]'),
 Text(152.84347826086957, 108.72, 'X[0] <= 42.5\ngini = 0.473\nsamples = 39\nvalue = [24, 15]'),
 Text(145.56521739130434, 91.99384615384615, 'X[0] <= 39.5\ngini = 0.451\nsamples = 32\nvalue = [21, 11]'),
 Text(138.28695652173914, 75.2676923076923, 'gini = 0.49\nsamples = 7\nvalue = [4, 3]'),
 Text(152.84347826086957, 75.2676923076923, 'X[0] <= 40.5\ngini = 0.435\nsamples = 25\nvalue = [17, 8]'),
 Text(145.56521739130434, 58.541538461538465, 'gini = 0.42\nsamples = 10\nvalue = [7, 3]'),
 Text(160.12173913043478, 58.541538461538465, 'X[0] <= 41.5\ngini = 0.444\nsamples = 15\nvalue = [10, 5]'),
 Text(152.84347826086957, 41.81538461538463, 'gini = 0.444\nsamples = 6\nvalue = [4, 2]'),
 Text(167.4, 41.81538461538463, 'gini = 0.444\nsamples = 9\nvalue = [6, 3]'),
 Text(160.12173913043478, 91.99384615384615, 'gini = 0.49\nsamples = 7\nvalue = [3, 4]'),
 Text(181.95652173913044, 108.72, 'X[0] <= 44.5\ngini = 0.384\nsamples = 54\nvalue = [40, 14]'),
 Text(174.67826086956524, 91.99384615384615, 'gini = 0.0\nsamples = 7\nvalue = [7, 0]'),
 Text(189.23478260869567, 91.99384615384615, 'X[0] <= 45.5\ngini = 0.418\nsamples = 47\nvalue = [33, 14]'),
 Text(181.95652173913044, 75.2676923076923, 'gini = 0.5\nsamples = 8\nvalue = [4, 4]'),
 Text(196.51304347826087, 75.2676923076923, 'X[0] <= 49.5\ngini = 0.381\nsamples = 39\nvalue = [29, 10]'),
 Text(189.23478260869567, 58.541538461538465, 'X[0] <= 46.5\ngini = 0.35\nsamples = 31\nvalue = [24, 7]'),
 Text(181.95652173913044, 41.81538461538463, 'gini = 0.444\nsamples = 6\nvalue = [4, 2]'),
 Text(196.51304347826087, 41.81538461538463, 'X[0] <= 47.5\ngini = 0.32\nsamples = 25\nvalue = [20, 5]'),
 Text(189.23478260869567, 25.089230769230767, 'gini = 0.0\nsamples = 1\nvalue = [1, 0]'),
 Text(203.7913043478261, 25.089230769230767, 'X[0] <= 48.5\ngini = 0.33\nsamples = 24\nvalue = [19, 5]'),
 Text(196.51304347826087, 8.363076923076932, 'gini = 0.355\nsamples = 13\nvalue = [10, 3]'),
 Text(211.0695652173913, 8.363076923076932, 'gini = 0.298\nsamples = 11\nvalue = [9, 2]'),
 Text(203.7913043478261, 58.541538461538465, 'gini = 0.469\nsamples = 8\nvalue = [5, 3]'),
 Text(292.04021739130434, 192.35076923076923, 'X[0] <= 59.5\ngini = 0.491\nsamples = 129\nvalue = [56, 73]'),
 Text(271.1152173913043, 175.62461538461537, 'X[0] <= 57.5\ngini = 0.472\nsamples = 81\nvalue = [31, 50]'),
 Text(251.1, 158.89846153846153, 'X[0] <= 55.5\ngini = 0.494\nsamples = 56\nvalue = [25, 31]'),
 Text(232.90434782608696, 142.1723076923077, 'X[0] <= 53.5\ngini = 0.476\nsamples = 46\nvalue = [18, 28]'),
 Text(218.34782608695653, 125.44615384615385, 'X[0] <= 52.5\ngini = 0.497\nsamples = 26\nvalue = [12, 14]'),
 Text(211.0695652173913, 108.72, 'X[0] <= 51.5\ngini = 0.48\nsamples = 15\nvalue = [6, 9]'),
 Text(203.7913043478261, 91.99384615384615, 'gini = 0.49\nsamples = 7\nvalue = [3, 4]'),
 Text(218.34782608695653, 91.99384615384615, 'gini = 0.469\nsamples = 8\nvalue = [3, 5]'),
 Text(225.62608695652173, 108.72, 'gini = 0.496\nsamples = 11\nvalue = [6, 5]'),
 Text(247.4608695652174, 125.44615384615385, 'X[0] <= 54.5\ngini = 0.42\nsamples = 20\nvalue = [6, 14]'),
 Text(240.1826086956522, 108.72, 'gini = 0.278\nsamples = 6\nvalue = [1, 5]'),
 Text(254.73913043478262, 108.72, 'gini = 0.459\nsamples = 14\nvalue = [5, 9]'),
 Text(269.295652173913, 142.1723076923077, 'X[0] <= 56.5\ngini = 0.42\nsamples = 10\nvalue = [7, 3]'),
 Text(262.0173913043478, 125.44615384615385, 'gini = 0.48\nsamples = 5\nvalue = [3, 2]'),
 Text(276.5739130434783, 125.44615384615385, 'gini = 0.32\nsamples = 5\nvalue = [4, 1]'),
 Text(291.1304347826087, 158.89846153846153, 'X[0] <= 58.5\ngini = 0.365\nsamples = 25\nvalue = [6, 19]'),
 Text(283.8521739130435, 142.1723076923077, 'gini = 0.375\nsamples = 12\nvalue = [3, 9]'),
 Text(298.40869565217395, 142.1723076923077, 'gini = 0.355\nsamples = 13\nvalue = [3, 10]'),
 Text(312.96521739130435, 175.62461538461537, 'X[0] <= 60.5\ngini = 0.499\nsamples = 48\nvalue = [25, 23]'),
 Text(305.68695652173915, 158.89846153846153, 'gini = 0.375\nsamples = 12\nvalue = [9, 3]'),
 Text(320.24347826086955, 158.89846153846153, 'X[0] <= 63.5\ngini = 0.494\nsamples = 36\nvalue = [16, 20]'),
 Text(312.96521739130435, 142.1723076923077, 'X[0] <= 62.5\ngini = 0.473\nsamples = 26\nvalue = [10, 16]'),
 Text(305.68695652173915, 125.44615384615385, 'X[0] <= 61.5\ngini = 0.49\nsamples = 21\nvalue = [9, 12]'),
 Text(298.40869565217395, 108.72, 'gini = 0.444\nsamples = 12\nvalue = [4, 8]'),
 Text(312.96521739130435, 108.72, 'gini = 0.494\nsamples = 9\nvalue = [5, 4]'),
 Text(320.24347826086955, 125.44615384615385, 'gini = 0.32\nsamples = 5\nvalue = [1, 4]'),
 Text(327.5217391304348, 142.1723076923077, 'gini = 0.48\nsamples = 10\nvalue = [6, 4]')]
../../_images/Lecture_20-pre_25_1.png

Whoa.

Decision trees can overfit data a lot if they aren’t constrained.

Let’s try this again…

dtree = DecisionTreeClassifier(
    max_depth = 1,
    random_state = 1
)

# Fit the model
dtree.fit(X = xtrain, y = ytrain)
DecisionTreeClassifier(max_depth=1, random_state=1)
tree.plot_tree(dtree,
               feature_names = ['Age'],
               class_names = ['No CHD', 'CHD'],
               filled = True
              )
[Text(167.4, 163.07999999999998, 'Age <= 50.5\ngini = 0.453\nsamples = 346\nvalue = [226, 120]\nclass = No CHD'),
 Text(83.7, 54.360000000000014, 'gini = 0.339\nsamples = 217\nvalue = [170, 47]\nclass = No CHD'),
 Text(251.10000000000002, 54.360000000000014, 'gini = 0.491\nsamples = 129\nvalue = [56, 73]\nclass = CHD')]
../../_images/Lecture_20-pre_28_1.png

What’s going on here?

  • Age <= 50.5: This is the “rule” being used to split the node (samples where the condition is true, i.e. “Yes”, go to the left branch; samples where it is false, “No”, go to the right branch)

  • gini = 0.453: This refers to the “Gini impurity” of the node. Gini impurity is the loss function used to fit this tree (optimal = 0); a worked check of this 0.453 value follows this list (more on this here)

  • samples = 346: This is the number of samples in the group that the node is dividing

  • value = [226, 120]: This is the number of training samples of each class at this node: value[0] = 226 with no CHD and value[1] = 120 with CHD
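
As a quick check of the Gini number above: with 226 “No CHD” and 120 “CHD” training samples at the root, \(1 - (226/346)^2 - (120/346)^2 \approx 1 - 0.427 - 0.120 \approx 0.453\), matching the printed value.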

NOTE: With a depth of 1, the two leaf nodes at the bottom tell us that (on the training data):

  • 170 people were correctly classified as “No CHD” by this rule (true negatives)

  • 47 people were incorrectly classified as “No CHD” by this rule (false negatives)

  • 56 people were incorrectly classified as “CHD” by this rule (false positives)

  • 73 people were correctly classified as “CHD” by this rule (true positives)

Like other classifiers, the decision tree classifier lets us predict values and has functions for assessing prediction accuracy.

# Accuracy on the data
dtree.score(X = xtrain, y = ytrain)
dtree.score(X = xtest, y = ytest)
0.6810344827586207
ypreds = dtree.predict(X = xtest)
ypreds

# Test "score" above
sum(ypreds == ytest) / len(ypreds)
0.6810344827586207
# The "soft classification" probabilities are just the fraction of training samples for the "true" label 
# in the leaf where this test item ended up

ypreds_soft = dtree.predict_proba(X = xtest)
ypreds_soft
array([[0.78341014, 0.21658986],
       [0.78341014, 0.21658986],
       [0.78341014, 0.21658986],
       [0.78341014, 0.21658986],
       [0.43410853, 0.56589147],
       [0.78341014, 0.21658986],
       [0.78341014, 0.21658986],
       [0.43410853, 0.56589147],
       [0.78341014, 0.21658986],
       [0.43410853, 0.56589147],
       [0.43410853, 0.56589147],
       [0.78341014, 0.21658986],
       [0.78341014, 0.21658986],
       [0.43410853, 0.56589147],
       [0.78341014, 0.21658986],
       [0.78341014, 0.21658986],
       [0.78341014, 0.21658986],
       [0.43410853, 0.56589147],
       [0.78341014, 0.21658986],
       [0.78341014, 0.21658986],
       [0.78341014, 0.21658986],
       [0.78341014, 0.21658986],
       [0.78341014, 0.21658986],
       [0.78341014, 0.21658986],
       [0.78341014, 0.21658986],
       [0.43410853, 0.56589147],
       [0.78341014, 0.21658986],
       [0.78341014, 0.21658986],
       [0.43410853, 0.56589147],
       [0.43410853, 0.56589147],
       [0.43410853, 0.56589147],
       [0.78341014, 0.21658986],
       [0.78341014, 0.21658986],
       [0.78341014, 0.21658986],
       [0.78341014, 0.21658986],
       [0.43410853, 0.56589147],
       [0.78341014, 0.21658986],
       [0.43410853, 0.56589147],
       [0.78341014, 0.21658986],
       [0.43410853, 0.56589147],
       [0.78341014, 0.21658986],
       [0.78341014, 0.21658986],
       [0.78341014, 0.21658986],
       [0.43410853, 0.56589147],
       [0.78341014, 0.21658986],
       [0.78341014, 0.21658986],
       [0.43410853, 0.56589147],
       [0.78341014, 0.21658986],
       [0.43410853, 0.56589147],
       [0.78341014, 0.21658986],
       [0.78341014, 0.21658986],
       [0.43410853, 0.56589147],
       [0.43410853, 0.56589147],
       [0.43410853, 0.56589147],
       [0.43410853, 0.56589147],
       [0.78341014, 0.21658986],
       [0.78341014, 0.21658986],
       [0.78341014, 0.21658986],
       [0.43410853, 0.56589147],
       [0.43410853, 0.56589147],
       [0.78341014, 0.21658986],
       [0.78341014, 0.21658986],
       [0.43410853, 0.56589147],
       [0.78341014, 0.21658986],
       [0.43410853, 0.56589147],
       [0.43410853, 0.56589147],
       [0.78341014, 0.21658986],
       [0.78341014, 0.21658986],
       [0.78341014, 0.21658986],
       [0.78341014, 0.21658986],
       [0.78341014, 0.21658986],
       [0.78341014, 0.21658986],
       [0.78341014, 0.21658986],
       [0.78341014, 0.21658986],
       [0.78341014, 0.21658986],
       [0.78341014, 0.21658986],
       [0.78341014, 0.21658986],
       [0.43410853, 0.56589147],
       [0.78341014, 0.21658986],
       [0.78341014, 0.21658986],
       [0.43410853, 0.56589147],
       [0.43410853, 0.56589147],
       [0.78341014, 0.21658986],
       [0.43410853, 0.56589147],
       [0.43410853, 0.56589147],
       [0.78341014, 0.21658986],
       [0.43410853, 0.56589147],
       [0.43410853, 0.56589147],
       [0.78341014, 0.21658986],
       [0.43410853, 0.56589147],
       [0.78341014, 0.21658986],
       [0.78341014, 0.21658986],
       [0.43410853, 0.56589147],
       [0.78341014, 0.21658986],
       [0.43410853, 0.56589147],
       [0.78341014, 0.21658986],
       [0.78341014, 0.21658986],
       [0.78341014, 0.21658986],
       [0.78341014, 0.21658986],
       [0.78341014, 0.21658986],
       [0.78341014, 0.21658986],
       [0.43410853, 0.56589147],
       [0.43410853, 0.56589147],
       [0.78341014, 0.21658986],
       [0.43410853, 0.56589147],
       [0.43410853, 0.56589147],
       [0.78341014, 0.21658986],
       [0.78341014, 0.21658986],
       [0.43410853, 0.56589147],
       [0.43410853, 0.56589147],
       [0.78341014, 0.21658986],
       [0.43410853, 0.56589147],
       [0.78341014, 0.21658986],
       [0.78341014, 0.21658986],
       [0.78341014, 0.21658986],
       [0.43410853, 0.56589147]])
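
As a sanity check, the two distinct probability rows above are exactly the class fractions in the two leaves of the depth-1 tree (value = [170, 47] on the left, value = [56, 73] on the right):

# Left leaf (age <= 50.5): 170 + 47 = 217 training samples
170 / 217, 47 / 217    # -> (0.7834..., 0.2165...)

# Right leaf (age > 50.5): 56 + 73 = 129 training samples
56 / 129, 73 / 129     # -> (0.4341..., 0.5658...)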

We can use the predictions as the basis for better understanding what the tree is doing:

# This reveals the cutoff(s) chosen by our decision tree! 
train_preds = dtree.predict(X = xtrain)
g = sns.scatterplot(x = xtrain[:, 0], y = ytrain, hue = ytrain == train_preds)
g.axvline(50.5)
# g.axvline(59.5)
# g.axvline(24.5)
<matplotlib.lines.Line2D at 0x7fb7b9e4d7c0>
../../_images/Lecture_20-pre_35_1.png
### YOUR CODE HERE

# Make a similar graph to the above with the test data
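
If you want to compare against one possible answer afterwards, here is a sketch (assuming the same dtree, xtest, and ytest from above; your approach may differ):

test_preds = dtree.predict(X = xtest)
g = sns.scatterplot(x = xtest[:, 0], y = ytest, hue = ytest == test_preds)
g.axvline(50.5)           # the cutoff chosen by the depth-1 tree
g.set_xlabel("Age")
plt.legend(title = "Correct")
plt.show()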

We can also draw on the same evaluation metrics that we discussed for assessing our \(k\)-nearest neighbors classifier.

from sklearn.metrics import accuracy_score, f1_score


# Test accuracy
accuracy_score(y_true = ytest, y_pred = dtree.predict(X = xtest))

# Test F1 score
f1_score(y_true = ytest,
         y_pred = dtree.predict(X = xtest),
         labels = [0, 1],
         pos_label = 1
        )
0.5542168674698795
from sklearn.metrics import roc_curve

# ROC curve
fpr, tpr, thresholds = roc_curve(
    y_true = ytest,
    y_score = dtree.predict_proba(X = xtest)[:, 1],
    pos_label = 1
)


sns.lineplot(x = fpr, y = tpr)
plt.axline(xy1 = (0, 0), slope = 1, c = "r")

plt.xlabel("FPR")
plt.ylabel("TPR")
Text(0, 0.5, 'TPR')
../../_images/Lecture_20-pre_39_1.png

Support Vector Machines (SVMs)

Support Vector Machines work by trying to find a line or plane (usually in a high-dimensional space) that maximally separates the classes in the training data.

The intuition for this is relatively straightforward but the implementation can get complicated!

In the plot below, the linear function \(h_3(x_1, x_2)\) is the best way to separate our training data because it maximizes the margin on either side of the line.

SVMs try to find the equivalent of \(h_3\) given some training data. This separator can be defined by the closest points in the data to the line; these are called the “support vectors”. Finding the best separator usually requires mapping the training data into a high-dimensional space where it can be effectively separated.

svm

(Source)
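
To make this concrete, here is a minimal sketch on made-up 2D data (not the heart-disease data) using a linear kernel, where the fitted separator \(w \cdot x + b = 0\) can be read off from the model’s attributes:

import numpy as np
from sklearn.svm import SVC

# Toy data: two linearly separable clusters
rng = np.random.default_rng(1)
X = np.vstack([rng.normal(0, 0.5, size = (20, 2)),
               rng.normal(3, 0.5, size = (20, 2))])
y = np.array([0] * 20 + [1] * 20)

linear_svm = SVC(kernel = 'linear')
linear_svm.fit(X, y)

w = linear_svm.coef_[0]           # weights w defining the separating line
b = linear_svm.intercept_[0]      # offset b of the separating line
sv = linear_svm.support_vectors_  # the training points closest to the boundary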

SVMs in python

The documentation for SVMs in scikit-learn is here.

from sklearn.svm import SVC

svm = SVC()

svm.fit(xtrain, ytrain)
SVC()

In the case of SVMs, there are class attributes that help you recover the separator that was fit.

We won’t get into these but if you’re interested in learning more it’s good to know about!

# svm.intercept_
# svm.coef_ # only for 'linear' kernel
# svm.support_vectors_

# For example, we can view the items in the training set that serve as the support vectors
sns.scatterplot(x = xtrain[:, 0], y = ytrain)
plt.title("Training data")
plt.show()

sns.scatterplot(x = xtrain[svm.support_][:, 0], y = ytrain[svm.support_])
plt.title("Support vectors")
plt.show()
../../_images/Lecture_20-pre_44_0.png ../../_images/Lecture_20-pre_44_1.png

The SVM class has a score function that returns the accuracy on a test set, plus prediction functions.

# Percent of correct classifications
svm.score(X = xtrain, y = ytrain)
svm.score(X = xtest, y = ytest)
0.646551724137931
ypreds = svm.predict(X = xtest)
ypreds
array([0, 0, 0, 0, 1, 0, 0, 1, 0, 1, 1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0,
       0, 0, 0, 1, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1,
       0, 0, 1, 0, 1, 0, 0, 1, 1, 1, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 1, 1,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 1, 0, 1, 1, 0, 1, 1,
       0, 1, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 0, 0, 1, 0,
       0, 1, 0, 0, 0, 1])

However, soft prediction requires configuring the model up front: by default, scikit-learn’s SVC only does hard classification, so you need to set probability = True to get predict_proba.

svm_soft = SVC(probability = True) # indicate that you want the SVM to do soft classification
svm_soft.fit(X = xtrain, y = ytrain)

ypreds_soft = svm_soft.predict_proba(X = xtest)
ypreds_soft
array([[0.74990451, 0.25009549],
       [0.75092313, 0.24907687],
       [0.70394272, 0.29605728],
       [0.75049796, 0.24950204],
       [0.42793442, 0.57206558],
       [0.74991807, 0.25008193],
       [0.75470904, 0.24529096],
       [0.50843848, 0.49156152],
       [0.75092313, 0.24907687],
       [0.4323866 , 0.5676134 ],
       [0.46023513, 0.53976487],
       [0.75030375, 0.24969625],
       [0.74991807, 0.25008193],
       [0.43423508, 0.56576492],
       [0.74852213, 0.25147787],
       [0.70394272, 0.29605728],
       [0.75075704, 0.24924296],
       [0.53790416, 0.46209584],
       [0.74852213, 0.25147787],
       [0.75461085, 0.24538915],
       [0.74992658, 0.25007342],
       [0.74992658, 0.25007342],
       [0.7336035 , 0.2663965 ],
       [0.75092313, 0.24907687],
       [0.74990451, 0.25009549],
       [0.44558218, 0.55441782],
       [0.75092313, 0.24907687],
       [0.74299424, 0.25700576],
       [0.44558218, 0.55441782],
       [0.42793442, 0.57206558],
       [0.48215983, 0.51784017],
       [0.74928835, 0.25071165],
       [0.75092313, 0.24907687],
       [0.74794345, 0.25205655],
       [0.74928835, 0.25071165],
       [0.60006623, 0.39993377],
       [0.74994641, 0.25005359],
       [0.44347322, 0.55652678],
       [0.74922983, 0.25077017],
       [0.60006623, 0.39993377],
       [0.74994641, 0.25005359],
       [0.74852213, 0.25147787],
       [0.75298329, 0.24701671],
       [0.56902307, 0.43097693],
       [0.74927988, 0.25072012],
       [0.75049796, 0.24950204],
       [0.46023513, 0.53976487],
       [0.74994641, 0.25005359],
       [0.44347322, 0.55652678],
       [0.74990451, 0.25009549],
       [0.74299424, 0.25700576],
       [0.44558218, 0.55441782],
       [0.50843848, 0.49156152],
       [0.43423508, 0.56576492],
       [0.65601183, 0.34398817],
       [0.75234516, 0.24765484],
       [0.75075704, 0.24924296],
       [0.70394272, 0.29605728],
       [0.44347322, 0.55652678],
       [0.48215983, 0.51784017],
       [0.75470904, 0.24529096],
       [0.75075704, 0.24924296],
       [0.62944664, 0.37055336],
       [0.75030375, 0.24969625],
       [0.56902307, 0.43097693],
       [0.4323866 , 0.5676134 ],
       [0.74299424, 0.25700576],
       [0.74299424, 0.25700576],
       [0.75234516, 0.24765484],
       [0.74928835, 0.25071165],
       [0.74994641, 0.25005359],
       [0.74994641, 0.25005359],
       [0.75470904, 0.24529096],
       [0.75086833, 0.24913167],
       [0.74972677, 0.25027323],
       [0.74951068, 0.25048932],
       [0.74934089, 0.25065911],
       [0.4323866 , 0.5676134 ],
       [0.74991807, 0.25008193],
       [0.75049796, 0.24950204],
       [0.50843848, 0.49156152],
       [0.56902307, 0.43097693],
       [0.74922983, 0.25077017],
       [0.46023513, 0.53976487],
       [0.53790416, 0.46209584],
       [0.74972677, 0.25027323],
       [0.42721767, 0.57278233],
       [0.50843848, 0.49156152],
       [0.7336035 , 0.2663965 ],
       [0.48215983, 0.51784017],
       [0.74794345, 0.25205655],
       [0.7506655 , 0.2493345 ],
       [0.42721767, 0.57278233],
       [0.75470904, 0.24529096],
       [0.42793442, 0.57206558],
       [0.75379628, 0.24620372],
       [0.75092313, 0.24907687],
       [0.75379628, 0.24620372],
       [0.75092313, 0.24907687],
       [0.74990451, 0.25009549],
       [0.75379628, 0.24620372],
       [0.62944664, 0.37055336],
       [0.46023513, 0.53976487],
       [0.74928835, 0.25071165],
       [0.46023513, 0.53976487],
       [0.43423508, 0.56576492],
       [0.74928835, 0.25071165],
       [0.74928835, 0.25071165],
       [0.56902307, 0.43097693],
       [0.60006623, 0.39993377],
       [0.75049796, 0.24950204],
       [0.42721767, 0.57278233],
       [0.75049796, 0.24950204],
       [0.75379628, 0.24620372],
       [0.74299424, 0.25700576],
       [0.42721767, 0.57278233]])

Classifier Wrap-Up

This is just a sample of what’s out there!

There are a number of other common classifiers you should take a look at if you’re interested:

  • Naive Bayes (here)

  • Discriminant analysis (linear and quadratic)

  • Neural networks (here)

  • Random forests (here) (related to decision trees)

  • Gradient boosted trees (here)

The main goal of this lecture is to show you some of the creative ways that people solve classification problems and how the scikit-learn library supports these solutions.
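
For instance, each of the classifiers listed above plugs into the same fit / predict / score interface we used throughout this lecture. A sketch with a random forest (any of the others could be swapped in):

from sklearn.ensemble import RandomForestClassifier

# Same interface as LogisticRegression, DecisionTreeClassifier, and SVC above
forest = RandomForestClassifier(random_state = 1)
forest.fit(X = xtrain, y = ytrain)
forest.score(X = xtest, y = ytest)  # test accuracy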

This should empower you to go off and try some of these other ones on your own!