Linear discriminant analysis (LDA), also known as normal discriminant analysis (NDA) or discriminant function analysis (DFA), builds on Fisher's linear discriminant, a statistical approach pioneered by Sir Ronald Fisher. It is a dimensionality reduction technique that is used in supervised machine learning.
The primary function of LDA is to project high-dimensional data onto a lower-dimensional space while retaining the data's inherent class separability. LDA can be applied to enhance the operation of classification algorithms such as decision trees or random forests.
LDA adheres to a generative model framework, which models the data distribution for each class and uses Bayes' theorem to classify new data points. It operates under two main assumptions: the data follows a normal or Gaussian distribution and the covariance matrices of the classes are equal.
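Concretely, under these assumptions the score that LDA assigns to a point x for class k is linear in the features. In the standard formulation (shared covariance matrix \Sigma, class mean \mu_k, and class prior \pi_k), the discriminant function is:

\delta_k(x) = x^{\top} \Sigma^{-1} \mu_k - \frac{1}{2} \mu_k^{\top} \Sigma^{-1} \mu_k + \log \pi_k

A new point is assigned to the class with the largest \delta_k(x); the shared-covariance assumption is what makes the decision boundaries between classes linear.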
This dimensionality reduction technique separates data points by using linear transformations, which are analyzed by using eigenvectors and eigenvalues. When plotted on a two-dimensional plane, a vector has both magnitude and direction. Eigenvectors represent directions, while eigenvalues represent magnitude or significance. Eigenvectors are calculated from the data set and collected in two class scatter matrices: the between-class scatter matrix and the within-class scatter matrix.
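In the standard notation (class k contains N_k points with mean \mu_k, and \mu is the overall mean), the two scatter matrices and the eigenproblem that LDA solves are:

S_W = \sum_{k} \sum_{x \in C_k} (x - \mu_k)(x - \mu_k)^{\top}, \qquad S_B = \sum_{k} N_k (\mu_k - \mu)(\mu_k - \mu)^{\top}

S_W^{-1} S_B \, w = \lambda \, w

The eigenvectors w with the largest eigenvalues \lambda define the projection directions that best separate the classes: they maximize between-class scatter relative to within-class scatter.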
LDA is used to identify a linear combination of features that characterizes or separates two or more classes of objects or events. This versatility ensures that LDA can be used for binary and multi-class classification problems, unlike logistic regression, which is limited to binary classification.
In this Python tutorial, we delve deeper into LDA, implementing it to optimize a machine learning model's performance by using the popular Iris data set. The goal is to classify three species of iris flowers based on four features: sepal length, sepal width, petal length, and petal width. By applying LDA, we can effectively reduce the dimensionality of the data set and enhance the classification accuracy of the machine learning (ML) model.
Step 1: Set up your environment
While you can choose from several tools, this tutorial walks you through how to set up an IBM account to use a Jupyter Notebook. Jupyter Notebooks are widely used in data science to combine code, text, images, and data visualizations into a well-formed analysis.
Log in to watsonx.ai using your IBM Cloud account.
From here, a notebook environment opens for you to load your data set and copy code from this beginner tutorial to tackle a simple classification problem.
Step 2: Install and import relevant libraries
We need a few libraries for this tutorial. Import the Python libraries required to work with the Iris data set, perform data preprocessing, and create and evaluate the LDA model. If any of them are not installed, you can resolve this with a quick pip install.
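For example, from a notebook cell (package names assumed to match the imports that follow):
%pip install numpy pandas matplotlib seaborn scikit-learn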
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score, confusion_matrix
In this step, we imported essential Python libraries: NumPy, pandas, Matplotlib, seaborn, and scikit-learn. These libraries are crucial for the data manipulation, visualization, and machine learning tasks in this tutorial.
Step 3: Read and load the data
In this step, you read the Iris data set from the UCI Machine Learning Repository and assign meaningful column names. The data set contains the sepal and petal dimensions of three different species of iris.
url = "https://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data"# Define column names
cls = ['sepal-length', 'sepal-width', 'petal-length', 'petal-width', 'Class']
# Read the data set
dataset = pd.read_csv(url, names=cls)
Step 4: Preprocess the data
In this step, you separate the feature variables (X) and the target variable (y), and encode the target variable y numerically.
# Divide the data set into features (X) and target variable (y)
X = dataset.iloc[:, 0:4].values
y = dataset.iloc[:, 4].values
# Encode the target variable
le = LabelEncoder()
y = le.fit_transform(y)
Here, we divided the data set into two parts: the independent variables (features, represented as X) and the dependent variable (target class, represented as y), then applied label encoding, which converts the string class labels into numerical values. The result is a preprocessed data set ready for further analysis.
Step 5: Perform exploratory data analysis
Before you implement LDA, it's essential to analyze the data set and understand its characteristics. In this step, we perform an exploratory analysis of the Iris data set by using a pair plot, histogram, and correlation heatmap.
Pair plots
The pair plot effectively illustrates how the four features vary and correlate across the three classes of iris flowers. In addition, it provides insights into the separability and overlap of these classes within the feature space, facilitating a deeper understanding of the characteristics and differences among the classes. The diagonal plot elements showcase the distribution of each feature, while the off-diagonal plot elements display scatterplots for each pair of features.
# Create a pair plot to visualize relationships between different features and species.
ax = sns.pairplot(dataset, hue='Class', markers=["o", "s", "D"])
plt.suptitle("Pair Plot of Iris Dataset")
sns.move_legend(
    ax, "lower center",
    bbox_to_anchor=(.5, 1), ncol=3, title=None, frameon=False)
plt.tight_layout()
plt.show()
[Figure: Pair plot of the Iris data set, colored by class]
For instance, look at the first column of the plot, which displays the relationship between sepal length and the other features. We can infer the following insights from the pair plot:
The sepal length follows a slightly right-skewed normal distribution.
Iris virginica has the longest sepal length, followed by Iris versicolor and Iris setosa. Iris setosa has the least variation in sepal length, while Iris virginica exhibits the most.
Sepal length negatively correlates with sepal width. Iris setosa has the widest sepal width, followed by Iris versicolor and Iris virginica. Also, Iris setosa is well separated from the other two classes, while Iris versicolor and Iris virginica have some overlap in this feature pair.
Sepal length positively correlates with petal length. Iris virginica has the longest petal length, followed by Iris versicolor and Iris setosa. Here again, Iris setosa is well separated, while versicolor and virginica overlap.
Sepal length positively correlates with petal width. Iris virginica has the widest petal width, followed by Iris versicolor and Iris setosa. Here again, Iris setosa is well-separated, while versicolor and virginica overlap.
You can apply similar analyses to the remaining columns, maintaining this pattern of feature relationships and class distinctions.
Histograms
Histograms are valuable tools for visualizing the distribution of individual features within the Iris data set. This graph comprises four histograms illustrating the distribution of petal and sepal lengths and widths for the three classes of irises.
# Visualize the distribution of each feature using histograms.
plt.figure(figsize=(12, 6))
for i, feature in enumerate(cls[:-1]):
    plt.subplot(2, 2, i + 1)
    sns.histplot(data=dataset, x=feature, hue='Class', kde=True)
    plt.title(f'{feature} Distribution')
plt.tight_layout()
plt.show()
[Figure: Histograms of the four features, colored by class]
These histograms indicate a balanced data set, but they also showcase the differences among the three classes of iris flowers in terms of petal and sepal lengths and widths, allowing for the following conclusions:
Iris setosa is characterized by shorter and narrower petals and sepals compared to Iris versicolor and Iris virginica. Its histograms display the lowest values for petal and sepal lengths and widths. Its distributions are peaked and narrow, reflecting minimal variation and high consistency in petal and sepal measurements.
Conversely, Iris versicolor and Iris virginica exhibit longer and wider petals and sepals. Their histograms display higher values for petal and sepal lengths and widths. Their distributions are flatter and wider, indicating greater variation and less consistency in petal and sepal measurements.
While Iris versicolor and Iris virginica share similar sepal lengths and widths, their petal measurements differentiate them. The histograms reveal overlapping distributions for sepal dimensions but separated distributions for petal dimensions. This emphasizes that sepal measurements alone cannot distinguish between Iris versicolor and Iris virginica. Petal measurements offer more effective differentiation.
Correlation heatmaps
A correlation heatmap offers insights into the relationships between different features in the Iris data set. It plays a critical role in feature selection as it helps identify positive and negative correlations between two variables. The variables include sepal length, sepal width, petal length, and petal width, which are measured in centimeters (cm).
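A minimal sketch that produces such a heatmap, assuming the dataset DataFrame and the seaborn import from the earlier steps:

# Compute pairwise correlations of the four numeric features
corr = dataset.drop(columns='Class').corr()

# Plot the correlation matrix as an annotated heatmap
plt.figure(figsize=(8, 6))
sns.heatmap(corr, annot=True, cmap='coolwarm', vmin=-1, vmax=1, square=True)
plt.title('Correlation Heatmap of Iris Features')
plt.tight_layout()
plt.show()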
The values in the heatmap range from -0.42 to 1: values closer to 1 signify strong positive correlations, and values closer to -1 signify strong negative correlations. We can draw the following conclusions from the heatmap:
Sepal length has a strong positive correlation with petal length and petal width and a weak negative correlation with sepal width.
Sepal width has a negative correlation with petal length and petal width and a weak negative correlation with sepal length.
Petal length has a strong positive correlation with sepal length and petal width and a negative correlation with sepal width.
Petal width has a strong positive correlation with sepal length and petal length and a negative correlation with sepal width.
Step 6: Split the data set
Next, to evaluate the LDA model, we must split the data set into two distinct sets: the training set and the testing set. The training set trains the LDA model to differentiate between different classes based on the features provided. The LDA model learns from the patterns and relationships present in this training set. On the other hand, the testing set evaluates the model's performance and assesses how well it can generalize its learning to new, unseen data. This step is critical to ensure that the LDA model can accurately predict new, real-world data.
# Split the data set into training and testing sets (80% train, 20% test)
# Note: without a fixed random_state, the split (and the exact accuracy later) varies between runs
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)
Step 7: Implement LDA
In this step, we apply LDA to the training data set, transforming the four-dimensional feature space into a two-dimensional subspace while maximizing the separation between classes. The outcome is a training data set with reduced dimensions.
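A minimal sketch of this transformation with scikit-learn (the names lda, X_train_lda, and X_test_lda are this sketch's choices):

# Fit LDA on the training data and project both sets onto the two discriminant axes
lda = LinearDiscriminantAnalysis(n_components=2)
X_train_lda = lda.fit_transform(X_train, y_train)
X_test_lda = lda.transform(X_test)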
Next, let's visualize the reduced-dimensional data to understand the separability of the classes. We use a scatterplot to help us visualize the transformed data points.
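One way to draw that scatterplot, assuming the X_train_lda array from the sketch above (le.classes_ recovers the original species names from the label encoder):

# Plot the training data in the 2-D LDA subspace, one color per species
plt.figure(figsize=(8, 6))
for label, species in enumerate(le.classes_):
    plt.scatter(X_train_lda[y_train == label, 0],
                X_train_lda[y_train == label, 1],
                label=species)
plt.xlabel('LDA Component 1')
plt.ylabel('LDA Component 2')
plt.title('Training Data in the LDA Subspace')
plt.legend()
plt.show()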
In this step, we classify the data using the random forest classifier. We specifically train the model on the LDA-transformed training data, enabling it to make predictions based on the reduced-dimensional LDA features. The outcome is the predicted class labels for the test data set.
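A minimal sketch of this classification step, assuming the LDA-transformed arrays defined above (the fixed random_state is this sketch's choice):

# Train a random forest on the 2-D LDA features and predict labels for the test set
classifier = RandomForestClassifier(random_state=42)  # fixed seed for reproducibility
classifier.fit(X_train_lda, y_train)
y_pred = classifier.predict(X_test_lda)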
In the final step, we evaluate the classification model by using the accuracy score metric. The accuracy score is typically used to assess the performance of classification tasks on balanced data sets.
# 'y_test' and 'y_pred' are defined in the previous steps
accuracy = accuracy_score(y_test, y_pred)
conf_m = confusion_matrix(y_test, y_pred)

# Display the accuracy
print(f'Accuracy: {accuracy:.2f}')

# Display the confusion matrix as a heatmap
plt.figure(figsize=(6, 6))
sns.heatmap(conf_m, annot=True, fmt="d", cmap="Blues", cbar=False, square=True)
plt.xlabel("Predicted")
plt.ylabel("True")
plt.title("Confusion Matrix")
plt.show()
A confusion matrix can help us visualize these results more easily. For a binary problem, the confusion matrix is divided into four parts: true positives, true negatives, false positives, and false negatives. For this three-class problem, it is a 3x3 grid in which the diagonal cells count correct predictions for each species and the off-diagonal cells count misclassifications.
In this confusion matrix, the number of correct predictions is 29 out of a total of 30 predictions. To compute accuracy, we calculate: (number of correct predictions) / (total number of predictions) = 29/30 ≈ 0.9667.
By projecting the data set with LDA, the random forest model classifies the different samples with an accuracy score of approximately 97%, indicating strong model performance.
Summary and next steps
In this tutorial, you learned how to apply LDA to optimize a machine learning model's performance using the Iris data set. You discovered how LDA reduces the dimensionality of the data set, which makes it easier to visualize and classify the data. You also leveraged data visualization techniques to gain insights into class separability and evaluated the accuracy of the LDA model.
Try watsonx for free
Build an artificial intelligence (AI) strategy for your business on one collaborative AI and data platform called IBM watsonx, which combines new generative AI capabilities powered by foundation models and traditional machine learning into a powerful platform spanning the AI lifecycle. With watsonx.ai, you can train, validate, tune, and deploy models with ease and build AI applications in a fraction of the time with a fraction of the data.
Try watsonx.ai, the next-generation studio for AI builders.
Also, to learn more about classification and regression algorithms that use Python and scikit-learn, see the tutorials in the Getting started with machine learning learning path.