Week 7, Wed, 5/15

import matplotlib.pyplot as plt
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.inspection import DecisionBoundaryDisplay

# Set random seed for reproducibility
np.random.seed(0)

# Number of samples per class
N = 100

# Generate data: class 1 centered at (1, 1) and class 2 at (-2, -1),
# each with isotropic covariance var * I (per-feature variance 5)
var = 5
x_class1 = np.random.multivariate_normal([1, 1], var*np.eye(2), N)
x_class2 = np.random.multivariate_normal([-2, -1], var*np.eye(2), N)

# Combine into a single dataset
X = np.vstack((x_class1, x_class2))
y = np.concatenate((np.ones(N), np.zeros(N)))

# Create a logistic regression classifier
clf = LogisticRegression()
clf.fit(X, y)

# Plot the decision boundaries using DecisionBoundaryDisplay
fig, ax = plt.subplots()
db_display = DecisionBoundaryDisplay.from_estimator(
    clf,
    X,
    grid_resolution=200,
    # response_method="predict_proba",  # Can be "predict_proba" for probability contours
    response_method="predict",
    cmap='coolwarm',
    alpha=0.5,
    ax=ax
)
# Scatter plot of the data points
scatter = ax.scatter(X[:, 0], X[:, 1], c=y, edgecolors='k', cmap='coolwarm')

# Mark the true Gaussian class means
ax.scatter([1, -2], [1, -1], c='yellow', marker='x', s=100, label='Class Centers')

# Adding color bar
cbar = plt.colorbar(scatter, ax=ax)
# Adding title and labels
ax.set_title('Logistic Regression Decision Boundary')
ax.set_xlabel('Feature 1')
ax.set_ylabel('Feature 2')
[Figure: Logistic Regression Decision Boundary, Feature 1 vs. Feature 2, with the two classes and their centers marked]
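The fitted model exposes its learned weights, and the decision boundary is the line where w·x + b = 0. As a minimal sketch (reusing clf, X, and y from the cell above), the boundary can be overlaid by hand:

# Sketch: recover the decision boundary w1*x1 + w2*x2 + b = 0 from the
# fitted coefficients and plot it over the data.
w = clf.coef_[0]         # shape (2,): one weight per feature
b = clf.intercept_[0]    # scalar bias

xs = np.linspace(X[:, 0].min(), X[:, 0].max(), 100)
ys = -(w[0] * xs + b) / w[1]   # solve w1*x + w2*y + b = 0 for y

plt.plot(xs, ys, 'k--', label='w·x + b = 0')
plt.scatter(X[:, 0], X[:, 1], c=y, cmap='coolwarm', edgecolors='k')
plt.legend()
plt.show()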
import pandas as pd
import seaborn as sns
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
import matplotlib.pyplot as plt
from sklearn.preprocessing import StandardScaler


# Load the dataset
df = sns.load_dataset('penguins')

# Drop rows with missing values
df.dropna(inplace=True)

# Use features to predict sex
features = ['bill_length_mm', 'bill_depth_mm']

# Select features
X = df[features]
y = df['sex']

# Initialize and train the logistic regression model
clf = LogisticRegression(penalty=None)
clf.fit(X, y)

# Predict on the training data
y_pred = clf.predict(X)

# Calculate the training accuracy
score = clf.score(X, y)
print(f"Training accuracy: {score:.2f}")
Training accuracy: 0.78
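Note that train_test_split is imported above but not used, so 0.78 is a training accuracy and may be optimistic. A hedged sketch of a held-out evaluation (variable names here are illustrative):

# Sketch (assumes df and features from above): fit on a training split and
# score on a held-out test split for a less optimistic accuracy estimate.
X_train, X_test, y_train, y_test = train_test_split(
    df[features], df['sex'], test_size=0.2, random_state=0)

clf_split = LogisticRegression(penalty=None)
clf_split.fit(X_train, y_train)
print(f"Test accuracy: {clf_split.score(X_test, y_test):.2f}")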
y
0        Male
1      Female
2      Female
4      Female
5        Male
        ...  
338    Female
340    Female
341      Male
342    Female
343      Male
Name: sex, Length: 333, dtype: object
clf.classes_
array(['Female', 'Male'], dtype=object)
y_pred[:5]
array(['Male', 'Female', 'Female', 'Female', 'Male'], dtype=object)
y_pred_proba = clf.predict_proba(X)
y_pred_proba
array([[0.49992132, 0.50007868],
       [0.70860573, 0.29139427],
       [0.55266441, 0.44733559],
       ...,
       [0.31825225, 0.68174775],
       [0.79220083, 0.20779917],
       [0.28228383, 0.71771617]])
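The columns of predict_proba are ordered as clf.classes_ (['Female', 'Male']), and predict simply picks the higher-probability column. A quick consistency check (sketch):

# Sketch: predicted labels are the argmax of the class probabilities,
# mapped back through clf.classes_.
labels_from_proba = clf.classes_[np.argmax(y_pred_proba, axis=1)]
print(np.all(labels_from_proba == y_pred))   # should print True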
import seaborn as sns
# Plot the decision boundaries using DecisionBoundaryDisplay
fig, ax = plt.subplots()
db_display = DecisionBoundaryDisplay.from_estimator(
    clf,
    X,
    grid_resolution=200,
    response_method="predict",  # Can be "predict_proba" for probability contours
    cmap='coolwarm',
    alpha=0.5,
    ax=ax
)
# Scatter plot of the data points
scatter = sns.scatterplot(data=df, x=features[0], y=features[1], hue='sex', ax=ax)
[Figure: Logistic regression decision boundary for penguin sex over bill_length_mm vs. bill_depth_mm]
from sklearn.metrics import confusion_matrix
conf_matrix = confusion_matrix(y, y_pred)
conf_matrix
array([[130,  35],
       [ 37, 131]])
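Accuracy can be read directly off the confusion matrix: the diagonal holds the correct predictions. A quick sketch:

# Sketch: accuracy = correct predictions / all predictions,
# i.e. the trace of the confusion matrix divided by its sum.
acc_from_cm = np.trace(conf_matrix) / conf_matrix.sum()
print(f"Accuracy from confusion matrix: {acc_from_cm:.2f}")   # should match clf.score(X, y)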
# Plotting the confusion matrix
sns.heatmap(conf_matrix, annot=True, fmt='d', cmap='Blues', xticklabels=clf.classes_, yticklabels=clf.classes_)
plt.xlabel('Predicted Labels')
plt.ylabel('True Labels')
plt.title('Confusion Matrix of Penguin Sex Prediction')
plt.show()
[Figure: Confusion Matrix of Penguin Sex Prediction (heatmap)]
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import PolynomialFeatures
from sklearn.metrics import accuracy_score
from sklearn.inspection import DecisionBoundaryDisplay


# Step 1: Generate data
np.random.seed(0)
n_samples = 1000

r = 1
# Side length chosen so the square's area (s**2 = 2*pi*r**2) is twice the
# unit circle's (pi*r**2), which makes the two classes roughly balanced
s = np.sqrt(2*np.pi) * r
x1 = np.random.uniform(-s/2, s/2, n_samples)
x2 = np.random.uniform(-s/2, s/2, n_samples)
X = np.vstack((x1, x2)).T
y = (x1**2 + x2**2 > r).astype(int)  # label 1 outside the circle, 0 inside (r = 1)

# Step 2: Visualize the data
sns.scatterplot(x=x1, y=x2, hue=y, palette='bwr')
[Figure: Scatter plot of the generated data, colored by class (inside vs. outside the circle)]
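Because the square's area is twice the circle's, roughly half of the uniformly drawn points should land outside the circle. A quick check (sketch):

# Sketch: fraction of points labeled 1 (outside the circle); expect roughly 0.5.
print(f"Fraction labeled 1: {y.mean():.2f}")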
# Step 3: Logistic regression without polynomial features
model = LogisticRegression()
model.fit(X, y)
y_pred = model.predict(X)

# Step 4: Plot decision boundary (without polynomial features)
fig, ax = plt.subplots()
disp = DecisionBoundaryDisplay.from_estimator(model, X, response_method="predict", alpha=0.3, cmap='bwr', ax=ax)
sns.scatterplot(x=x1, y=x2, hue=y, palette='bwr', ax=ax)
plt.xlabel("x1")
plt.ylabel("x2")
plt.title("Logistic Regression (Without Poly Features)")

acc = model.score(X, y)
print(f"Accuracy: {acc}")
Accuracy: 0.442
[Figure: Logistic Regression (Without Poly Features), decision boundary over x1 vs. x2]
# Step 5: Logistic regression with hand-crafted squared features
X_poly = np.vstack((x1, x2, x1**2, x2**2)).T

model_poly = LogisticRegression()
model_poly.fit(X_poly, y)
y_poly_pred = model_poly.predict(X_poly)

acc = model_poly.score(X_poly, y)
print(f"Accuracy: {acc}")

model_poly.coef_
Accuracy: 0.994
array([[0.134259  , 0.13628499, 7.31227108, 7.34633247]])
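PolynomialFeatures is imported above but the squared features were built by hand. As a sketch, the equivalent pipeline approach would look like this (note that degree=2 also adds the interaction term x1*x2):

# Sketch: the same idea using sklearn's PolynomialFeatures inside a Pipeline.
from sklearn.pipeline import make_pipeline

pipe = make_pipeline(PolynomialFeatures(degree=2, include_bias=False),
                     LogisticRegression())
pipe.fit(X, y)
print(f"Pipeline accuracy: {pipe.score(X, y):.3f}")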
import pandas as pd
import seaborn as sns
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import accuracy_score, confusion_matrix
import matplotlib.pyplot as plt
from sklearn.preprocessing import StandardScaler


# Load the dataset
df = sns.load_dataset('penguins')

# Drop rows with missing values
df.dropna(inplace=True)

features = ['bill_length_mm', 'bill_depth_mm']

# Select features
X = df[features]
y = df['species']


# Initialize and train the logistic regression model
clf = LogisticRegression()
clf.fit(X, y)

# Calculate the training accuracy
score = clf.score(X, y)
print(f"Accuracy: {score:.2f}")


# Predict on the training data
y_pred = clf.predict(X)
Accuracy: 0.96
# Evaluate the model
conf_matrix = confusion_matrix(y, y_pred)
print("Confusion Matrix:\n", conf_matrix)

# Plotting the confusion matrix
sns.heatmap(conf_matrix, annot=True, fmt='d', cmap='Blues', xticklabels=clf.classes_, yticklabels=clf.classes_)
plt.xlabel('Predicted Labels')
plt.ylabel('True Labels')
plt.title('Confusion Matrix of Species Classification')
plt.show()
Confusion Matrix:
 [[144   2   0]
 [  4  60   4]
 [  0   2 117]]
[Figure: Confusion Matrix of Species Classification (heatmap)]
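Beyond raw counts, per-class performance can be summarized with precision and recall. A short sketch using the same predictions:

# Sketch: per-class precision, recall, and F1 for the species classifier.
from sklearn.metrics import classification_report
print(classification_report(y, y_pred))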
import matplotlib.pyplot as plt
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.inspection import DecisionBoundaryDisplay

fig, ax = plt.subplots()
db_display = DecisionBoundaryDisplay.from_estimator(
    clf,
    X,
    grid_resolution=200,
    response_method="predict",  # Can be "predict_proba" for probability contours
    cmap='coolwarm',
    alpha=0.5,
    ax=ax
)

# Scatter plot of the data points
scatter = sns.scatterplot(data=df, x='bill_length_mm', y='bill_depth_mm', hue='species', ax=ax)

# Adding title and labels
ax.set_title('3-Class Logistic Regression Decision Boundary')
ax.set_xlabel(features[0])
ax.set_ylabel(features[1])

# Show plot
plt.show()
[Figure: 3-Class Logistic Regression Decision Boundary over bill_length_mm vs. bill_depth_mm]
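For the three-class model, predict_proba returns one column per species (with scikit-learn's current defaults this is a multinomial, i.e. softmax, model), and each row sums to 1. A quick sketch:

# Sketch: three probability columns, ordered as clf.classes_; each row sums to 1.
proba = clf.predict_proba(X)
print(clf.classes_)
print(proba[:3])
print(proba.sum(axis=1)[:3])   # each entry should be 1.0 up to floating point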