(Estimated reading time: 20 minutes)
This tutorial provides an overview of statistical data visualization in Python using a library called seaborn. Data visualization is an important component of data science, and is usually the first step in a data analysis process after data collection and preprocessing. It gives researchers a quick sense of the distribution of their data and of potential relationships, which helps shed light on directions for further analysis. It is also a powerful yet intuitive way of presenting data and conveying messages to any target audience.
The tutorial will briefly introduce different types of graphs, explain their advantages and disadvantages and the scenarios in which they can be used, and show how to code them using seaborn. After the tutorial, readers will have a basic understanding of common data visualization methods and be able to implement and customize them in Python using seaborn.
The tutorial consists of the following sections:
Seaborn is only supported on Python 3.6+. To check your Python version, run the following code chunk, or copy the command (without the exclamation mark) and run it from the command line.
!python -V
Python 3.8.8
If you need to install or upgrade Python, check out the official website.
Once the correct Python version has been installed, we can install the library from PyPI or Anaconda by running one of the following two lines of code, or by copying one of the two commands (without the exclamation mark) and running it from the command line.
# installing using pip
!pip install seaborn
# installing using conda
#!conda install seaborn
Required dependencies of seaborn include NumPy, SciPy, pandas, and matplotlib. When seaborn is installed, it will automatically check for these required libraries and install them if needed. If you run into any trouble during installation, visit Installing and getting started on seaborn's official website for more information.
After successful installation, we load the libraries by running the following code. Following the recommendations on seaborn's official website, we import seaborn together with NumPy, pandas, and matplotlib, three of its required dependencies, so that their functionality is at hand as well.
import seaborn as sns
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
The sample data we will use in this tutorial is the daily data on the COVID-19 pandemic in California, Pennsylvania and Massachusetts, United States, provided by the COVID Tracking Project at The Atlantic under a Creative Commons CC BY 4.0 license.
# read the data from the csv files
covid_ca = pd.read_csv("https://covidtracking.com/data/download/california-history.csv")
covid_pa = pd.read_csv("https://covidtracking.com/data/download/pennsylvania-history.csv")
covid_ma = pd.read_csv("https://covidtracking.com/data/download/massachusetts-history.csv")
# concatenate the three data frames
covid = pd.concat([covid_ca, covid_pa, covid_ma])
To keep our demonstration simple, we will retain the five variables shown below, and only keep observations from April 1st, 2020 to December 31st, 2020. For demonstration purposes, we will also add a categorical variable, hospitalizedLevel, based on hospitalizedCurrently.
# retain only the five columns listed
covid = covid[["date", "state", "deathIncrease", "hospitalizedCurrently", "totalTestResultsIncrease"]].copy()
# change the date from string to datetime
covid["date"] = pd.to_datetime(covid["date"], format = "%Y-%m-%d")
# keep data from 2020-04-01 to 2020-12-31 (.copy() avoids pandas' SettingWithCopyWarning when we add a column below)
covid = covid[(covid["date"] >= "2020-04-01") & (covid["date"] <= "2020-12-31")].copy()
# add a new column hospitalizedLevel
def addHospitalizedLevel(row):
    # note: a missing (NaN) value fails both comparisons below and ends up as "medium"
    if row["hospitalizedCurrently"] >= 10000:
        return "high"
    elif row["hospitalizedCurrently"] < 5000:
        return "low"
    else:
        return "medium"
covid["hospitalizedLevel"] = covid.apply(addHospitalizedLevel, axis = 1)
# specify categorical variables
covid = covid.astype({"state": "category", "hospitalizedLevel":"category"})
# show the first five rows of the data frame
covid.head()
| | date | state | deathIncrease | hospitalizedCurrently | totalTestResultsIncrease | hospitalizedLevel |
|---|---|---|---|---|---|---|
| 66 | 2020-12-31 | CA | 428 | 21449.0 | 232406 | high |
| 67 | 2020-12-30 | CA | 432 | 21433.0 | 248605 | high |
| 68 | 2020-12-29 | CA | 242 | 21240.0 | 245955 | high |
| 69 | 2020-12-28 | CA | 64 | 20642.0 | 301820 | high |
| 70 | 2020-12-27 | CA | 237 | 20059.0 | 380154 | high |
# show dimension of the data frame
covid.shape
(825, 6)
As shown above, our final data frame consists of 825 rows (observations) and 6 columns (variables): date, two categorical variables (state and hospitalizedLevel), and three quantitative variables (deathIncrease, hospitalizedCurrently, and totalTestResultsIncrease).
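As an aside, the row-wise apply in addHospitalizedLevel can be replaced by a single vectorized call to pandas' pd.cut. The sketch below runs it on a few made-up hospitalization counts (hypothetical values, not taken from the COVID Tracking data). With right = False every bin includes its left edge, which reproduces the thresholds used above: below 5000 is low, 10000 and above is high. One difference to be aware of: a missing value stays missing here, whereas addHospitalizedLevel labels it "medium" because NaN fails both comparisons.

```python
import numpy as np
import pandas as pd

# hypothetical values standing in for covid["hospitalizedCurrently"]
hospitalized = pd.Series([1200.0, 5000.0, 9999.0, 10000.0, 21449.0, np.nan])

# right=False closes every bin on the left: [-inf, 5000), [5000, 10000), [10000, inf)
levels = pd.cut(
    hospitalized,
    bins=[-np.inf, 5000, 10000, np.inf],
    right=False,
    labels=["low", "medium", "high"],
)

print(levels.tolist())              # the NaN stays NaN instead of becoming "medium"
print(list(levels.cat.categories))  # categories follow the order of the labels
```

Because pd.cut returns an ordered categorical, the levels also come out in their natural low, medium, high order.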
Now that we have preprocessed our data, we will move on to using seaborn for visualization.
When we have one-dimensional quantitative data, we may be interested in its mean (average), median (50% quantile), range, spread, shape of distribution, etc. Below we will see several methods that focus on different aspects of this information.
A histogram can show the distribution of a quantitative variable. On the x-axis is the variable we are visualizing, whose values are separated into different bins, and on the y-axis is usually the count or density of the observations that fall into each bin.
Let's try to visualize the distribution of deathIncrease, the daily increase in death counts in a state.
# set the figure size using matplotlib
plt.figure(figsize=(8, 5))
# histogram with count on the y-axis
sns.histplot(x = "deathIncrease", data = covid)
<AxesSubplot:xlabel='deathIncrease', ylabel='Count'>
By default, histplot puts count on the y-axis. To put density on the y-axis, we specify stat = "density", which normalizes the counts so that the total area of the bins equals 1.
plt.figure(figsize=(8, 5))
# histogram with density on the y-axis
sns.histplot(x = "deathIncrease", data = covid, stat = "density")
<AxesSubplot:xlabel='deathIncrease', ylabel='Density'>
There are two other values we can use for the stat argument: "frequency", which is the count of each bin divided by the bin width, and "probability", which normalizes the counts so that the heights of the bars sum to 1.
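To make the stat options concrete, the NumPy sketch below computes all four of them by hand from the same bin counts. The values and bin edges are made up for illustration, and the bin widths are deliberately unequal so that frequency and density differ visibly from count and probability. It mirrors the definitions above rather than seaborn's internal code.

```python
import numpy as np

# made-up daily values and bin edges (the widths are deliberately unequal)
values = np.array([3, 7, 12, 15, 18, 22, 35, 41, 48, 60])
edges = np.array([0, 10, 20, 40, 60])

counts, _ = np.histogram(values, bins=edges)   # stat = "count"
widths = np.diff(edges)
n = counts.sum()

frequency = counts / widths            # stat = "frequency": count per unit of x
probability = counts / n               # stat = "probability": bar heights sum to 1
density = counts / (n * widths)        # stat = "density": bar areas (height * width) sum to 1

print(counts)
print(probability.sum(), (density * widths).sum())
```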
We can easily change the width of the bins by specifying bins or binwidth. If bins is set to a single value, the histogram uses this value as the total number of bins. If bins is set to a list or tuple of values, the histogram uses these values as the bin edges.
plt.figure(figsize=(8, 5))
# specify the total number of bins
sns.histplot(x = "deathIncrease", data = covid, bins = 10)
<AxesSubplot:xlabel='deathIncrease', ylabel='Count'>
plt.figure(figsize=(8, 5))
# specify the breaks of the bins
sns.histplot(x = "deathIncrease", data = covid,
bins = (0, 100, 200, 300, 400, 500, 600))
<AxesSubplot:xlabel='deathIncrease', ylabel='Count'>
binwidth allows us to control the width of each bin. If both bins and binwidth are specified, binwidth overrides bins.
plt.figure(figsize=(8, 5))
# specify the width of each bin
sns.histplot(x = "deathIncrease", data = covid, binwidth = 50)
<AxesSubplot:xlabel='deathIncrease', ylabel='Count'>
If we don't want to show all values on the x-axis, we can specify the smallest and the largest bin edges using binrange.
plt.figure(figsize=(8, 5))
# limit the histogram to between 0 and the maximum value of deathIncrease
sns.histplot(x = "deathIncrease", data = covid,
binrange = (0, max(covid["deathIncrease"])))
<AxesSubplot:xlabel='deathIncrease', ylabel='Count'>
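The interplay of bins, binwidth, and binrange can be sketched in NumPy. The helper bin_edges below is hypothetical (not part of seaborn); it follows the precedence described in this section, assuming behaviour like seaborn 0.11's: binwidth wins over bins, and binrange replaces the data's minimum and maximum as the outer limits.

```python
import numpy as np

def bin_edges(x, bins="auto", binwidth=None, binrange=None):
    """Hypothetical helper: compute histogram bin edges from histplot-style options."""
    x = np.asarray(x, dtype=float)
    # binrange, if given, replaces the data's min and max as the outer limits
    start, stop = (x.min(), x.max()) if binrange is None else binrange
    if binwidth is not None:
        # binwidth takes precedence: fixed-width bins starting at `start`
        edges = np.arange(start, stop + binwidth, binwidth)
        if edges.max() < stop or len(edges) < 2:
            # guard against round-off leaving `stop` uncovered
            edges = np.append(edges, edges.max() + binwidth)
        return edges
    # an integer means "this many equal-width bins"; a sequence is used as the edges as-is
    return np.histogram_bin_edges(x, bins=bins, range=(start, stop))

x = [0, 12, 25, 37, 50, 90]
print(bin_edges(x, bins=3))                    # three equal-width bins over [0, 90]
print(bin_edges(x, bins=(0, 50, 100)))         # explicit edges
print(bin_edges(x, bins=3, binwidth=25))       # binwidth overrides bins
print(bin_edges(x, bins=2, binrange=(0, 50)))  # edges limited to [0, 50]
```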
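We can overlay a smoothed density curve on the histogram by setting kde = True. Seaborn rescales the curve to match the histogram's y-axis, so this also works with the default counts, but putting density on the y-axis, as below, lets us read the curve directly as an estimated probability density.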
plt.figure(figsize=(8, 5))
# histogram with density on the y-axis and kernel density estimate overlaid
sns.histplot(x = "deathIncrease", data = covid, stat = "density", kde = True)
<AxesSubplot:xlabel='deathIncrease', ylabel='Density'>
To change the color of the bins, we can specify color
.
plt.figure(figsize=(8, 5))
# histogram with a different color
sns.histplot(x = "deathIncrease", data = covid, color = "darksalmon")
<AxesSubplot:xlabel='deathIncrease', ylabel='Count'>
For a comprehensive list of arguments for histplot, check out seaborn.histplot.
Similar to histograms, a density curve is useful if we want to see the shape of the distribution. kdeplot generates a kernel density estimate curve using Gaussian kernels for the given data.
plt.figure(figsize=(8, 5))
# kernel density estimate
sns.kdeplot(x = "deathIncrease", data = covid)
<AxesSubplot:xlabel='deathIncrease', ylabel='Density'>
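Under the hood, a Gaussian kernel density estimate places a normal "bump" on every observation and averages them: f(x) = 1/(n*h) * sum over i of phi((x - x_i)/h), where phi is the standard normal density and h is the bandwidth. The NumPy sketch below evaluates this formula directly on made-up data, choosing h with Scott's rule (sample standard deviation times n^(-1/5)), which is the default rule of scipy.stats.gaussian_kde. It illustrates the formula only; it is not seaborn's implementation.

```python
import numpy as np

def gaussian_kde_1d(grid, data, bandwidth):
    """Average of Gaussian bumps of width `bandwidth`, one centred on each observation."""
    data = np.asarray(data, dtype=float)
    # standardized distance from every grid point (rows) to every observation (columns)
    z = (grid[:, None] - data[None, :]) / bandwidth
    bumps = np.exp(-0.5 * z**2) / np.sqrt(2 * np.pi)   # standard normal density phi(z)
    return bumps.sum(axis=1) / (len(data) * bandwidth)

# made-up sample standing in for deathIncrease
data = np.array([2.0, 5.0, 6.0, 9.0, 15.0, 16.0, 30.0, 42.0])

# Scott's rule: sample standard deviation times n ** (-1/5)
h = data.std(ddof=1) * len(data) ** (-1 / 5)

grid = np.linspace(data.min() - 5 * h, data.max() + 5 * h, 2001)
density = gaussian_kde_1d(grid, data, h)

# a proper density integrates to 1; check with a simple Riemann sum
area = density.sum() * (grid[1] - grid[0])
print(f"bandwidth h = {h:.2f}, area under the curve = {area:.4f}")
```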
We can change the color of the curve using color. To fill the area under the curve, we set fill = True (older versions of seaborn used shade = True, which is deprecated as of version 0.11). Seaborn automatically fills the area with a translucent version of the curve's color.
plt.figure(figsize=(8, 5))
# kernel density estimate with area under curve filled
sns.kdeplot(x = "deathIncrease", data = covid,
color = "darksalmon", fill = True)
<AxesSubplot:xlabel='deathIncrease', ylabel='Density'>
We can limit the range over which the density curve is evaluated, and therefore drawn, using clip, which takes the lowest and highest values on the x-axis.
plt.figure(figsize=(8, 5))
# limit the density curve to between 0 and the maximum value of deathIncrease
sns.kdeplot(x = "deathIncrease", data = covid,
clip = (0, max(covid["deathIncrease"])))
<AxesSubplot:xlabel='deathIncrease', ylabel='Density'>
To change the smoothing bandwidth, we use the bw_method argument. This argument is passed to the gaussian_kde function in SciPy for further calculation. For more information, check out scipy.stats.gaussian_kde.
We can also use bw_adjust, which scales the bandwidth chosen by bw_method, to control the level of smoothing. Larger values correspond to more smoothing and vice versa.
plt.figure(figsize=(8, 5))
# less smoothing
sns.kdeplot(x = "deathIncrease", data = covid, bw_adjust = 0.5)
<AxesSubplot:xlabel='deathIncrease', ylabel='Density'>
plt.figure(figsize=(8, 5))
# more smoothing
sns.kdeplot(x = "deathIncrease", data = covid, bw_adjust = 3)
<AxesSubplot:xlabel='deathIncrease', ylabel='Density'>
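Since bw_adjust multiplies the bandwidth chosen by bw_method, its effect can be checked numerically. The self-contained sketch below uses made-up data with two clusters and Scott's rule as the base bandwidth (an assumption matching the default bw_method), and reports the height of the tallest point of the curve for the two bw_adjust values used above and for the default of 1.

```python
import numpy as np

def kde_peak(data, bw_adjust=1.0, num=2001):
    """Tallest point of a Gaussian KDE whose Scott's-rule bandwidth is scaled by bw_adjust."""
    data = np.asarray(data, dtype=float)
    h = bw_adjust * data.std(ddof=1) * len(data) ** (-1 / 5)
    grid = np.linspace(data.min() - 3 * h, data.max() + 3 * h, num)
    z = (grid[:, None] - data[None, :]) / h
    density = np.exp(-0.5 * z**2).sum(axis=1) / (len(data) * h * np.sqrt(2 * np.pi))
    return density.max()

# made-up sample with two clusters
data = [1.0, 1.5, 2.0, 2.2, 8.0, 8.5, 9.0]

for adjust in (0.5, 1, 3):
    # smaller bandwidth -> sharper, taller peaks; larger bandwidth -> flatter curve
    print(adjust, round(kde_peak(data, adjust), 4))
```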
The kernel density curve we get from kdeplot estimates the probability density function of the original distribution, but we can also estimate the cumulative distribution function by setting cumulative = True.
plt.figure(figsize=(8, 5))
# estimate the cumulative distribution function
sns.kdeplot(x = "deathIncrease", data = covid, cumulative = True)
<AxesSubplot:xlabel='deathIncrease', ylabel='Density'>
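The cumulative curve has a closed form: each observation contributes a normal cumulative distribution function instead of a normal density, so F(x) = 1/n * sum over i of Phi((x - x_i)/h). The sketch below evaluates this with math.erf on made-up data and a hypothetical fixed bandwidth; the values rise from about 0 to about 1, as in the plot above (whose y-axis label nonetheless still reads Density).

```python
import math
import numpy as np

def kde_cdf(x, data, bandwidth):
    """CDF of a Gaussian KDE: the average of normal CDFs centred on the observations."""
    x = np.asarray(x, dtype=float)
    data = np.asarray(data, dtype=float)
    z = (x[:, None] - data[None, :]) / bandwidth
    # standard normal CDF via the error function: Phi(z) = (1 + erf(z / sqrt(2))) / 2
    normal_cdf = 0.5 * (1 + np.vectorize(math.erf)(z / math.sqrt(2)))
    return normal_cdf.mean(axis=1)

data = [2.0, 5.0, 6.0, 9.0, 15.0]    # made-up observations
bandwidth = 2.0                       # hypothetical fixed bandwidth

xs = np.array([-20.0, 6.0, 40.0])
cdf = kde_cdf(xs, data, bandwidth)
print(cdf.round(4))                   # close to 0, in between, close to 1
```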
We used to be able to choose non-Gaussian kernels for density estimation, but unfortunately this is no longer supported. If we want to use non-Gaussian kernels, we can use either statsmodels or scikit-learn.
For a comprehensive list of arguments for kdeplot, check out seaborn.kdeplot.
The two methods above are good for visualizing the shape of the distribution, but not statistics such as median and outliers. If we want to visualize the "summary" statistics of a distribution and do not care about the shape, we can resort to box plots.
A box plot shows the "five-number summary" of a distribution, which includes the minimum, the first quartile (25% quantile), the median, the third quartile (75% quantile), and the maximum. It also marks the outliers separately.
plt.figure(figsize=(8, 4))
# box plot - horizontal
sns.boxplot(x = "deathIncrease", data = covid)
<AxesSubplot:xlabel='deathIncrease'>
In the plot above, the line inside the box is the median. The left and right edges of the box are the first and third quartiles, respectively. The two lines extending out from the box are called whiskers. The left whisker ends at the smallest observation that is no less than Q1 - 1.5 IQR (the first quartile minus 1.5 times the interquartile range), in other words the minimum value that is not considered an outlier. Similarly, the right whisker ends at the largest observation that is no greater than Q3 + 1.5 IQR. The points outside the whiskers are outliers.
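The quantities behind a box plot are easy to compute with NumPy. The sketch below uses a made-up sample with one extreme value. Its quartiles use NumPy's default linear interpolation, which, to our understanding, is also what matplotlib's box plot (and hence seaborn's boxplot) uses by default; other software may compute quartiles slightly differently.

```python
import numpy as np

# made-up sample with one unusually large value
x = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 40])

q1, median, q3 = np.percentile(x, [25, 50, 75])
iqr = q3 - q1

# fences: anything beyond these counts as an outlier
low_fence = q1 - 1.5 * iqr
high_fence = q3 + 1.5 * iqr

# whiskers stop at the most extreme observations that are still inside the fences
whisker_low = x[x >= low_fence].min()
whisker_high = x[x <= high_fence].max()
outliers = x[(x < low_fence) | (x > high_fence)]

print(q1, median, q3)               # 3.25 5.5 7.75
print(whisker_low, whisker_high)    # 1 9
print(outliers)                     # [40]
```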
We can flip the box plot easily by passing the variable as argument y instead of x. As before, we can change the color of the box using color.
plt.figure(figsize=(4, 6))
#box plot - vertical
sns.boxplot(y = "deathIncrease", data = covid, color = "darksalmon")
<AxesSubplot:ylabel='deathIncrease'>
For a comprehensive list of arguments for boxplot, check out seaborn.boxplot.
A strip plot is a variation of a dot plot that plots every single observation. The density of values shows up as the literal density of points on the plot around a given x value. The x-coordinate of each point is its value of the variable, but the y-coordinate is meaningless: the points are spread out along the y-axis only to reduce overlap.
plt.figure(figsize=(8, 4))
# strip plot
sns.stripplot(x = "deathIncrease", data = covid)
<AxesSubplot:xlabel='deathIncrease'>
We can control how spread out the points are with the jitter argument. jitter = 0 means no spread at all, in which case all points lie on the same horizontal line. The default is jitter = True (jitter = 1 has the same effect), which adds a small amount of jitter. Any other value in the range [0, 1) sets the amount of jitter directly, so a larger value spreads the points out more.
plt.figure(figsize=(8, 4))
# strip plot with more jitter
sns.stripplot(x = "deathIncrease", data = covid, jitter = 0.3)
<AxesSubplot:xlabel='deathIncrease'>
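Jitter is nothing more than random noise added along the axis that carries no information. The sketch below illustrates the idea with NumPy on made-up values, assuming the noise is uniform on [-jitter, jitter] around the baseline (our understanding of what stripplot does; treat it as an assumption) and using the jitter = 0.3 from the example above. Each point keeps its real x value; only its y position is random.

```python
import numpy as np

rng = np.random.default_rng(0)             # fixed seed so the sketch is reproducible

values = np.array([12, 15, 15, 15, 40])    # made-up observations; note the ties at 15
jitter = 0.3

# x keeps the real values; y is pure noise around the baseline 0, there only to reduce overlap
x = values.astype(float)
y = rng.uniform(-jitter, jitter, size=len(values))

for xi, yi in zip(x, y):
    print(f"x = {xi:5.1f}, y = {yi:+.3f}")
```

The three tied values at 15 now get different y positions, which is exactly what keeps them from being drawn on top of each other.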
A strip plot alone doesn't contain much information. However, we can overlay it on a box plot as a complement, since a box plot does not visualize density. In the plot below, the color and transparency of the points in the strip plot are changed using color and alpha so that we can still see the box plot.
plt.figure(figsize=(8, 4))
# strip plot on top of box plot
sns.boxplot(x = "deathIncrease", data = covid)
sns.stripplot(x = "deathIncrease", data = covid,
color = "darksalmon", jitter = 0.4, alpha = 0.4)
<AxesSubplot:xlabel='deathIncrease'>
We can still flip the coordinates of a strip plot by passing the variable as the y argument.
plt.figure(figsize=(4, 6))
# strip plot on top of box plot, coordinates flipped
sns.boxplot(y = "deathIncrease", data = covid)
sns.stripplot(y = "deathIncrease", data = covid,
color = "darksalmon", jitter = 0.4, alpha = 0.4)
<AxesSubplot:ylabel='deathIncrease'>
For a comprehensive list of arguments for stripplot, check out seaborn.stripplot.
A swarm plot is similar to a strip plot, but it arranges the points so that they do not overlap, which spreads them out more at values with higher density. Note that swarm plots do not scale well to large datasets. In the code below, we only plot the first 300 observations. If we include more points than can be placed on the plot, we will receive a warning suggesting that we either decrease the point size or use a strip plot instead. We can control the point size using the size argument.
plt.figure(figsize=(8, 4))
# swarm plot
sns.swarmplot(x = "deathIncrease", data = covid.iloc[0:300,])
<AxesSubplot:xlabel='deathIncrease'>
We can flip the coordinates of a swarm plot or overlay it on top of a box plot using the same methods mentioned earlier.
plt.figure(figsize=(4, 6))
# swarm plot on top of box plot, coordinates flipped
sns.boxplot(y = "deathIncrease", data = covid)
sns.swarmplot(y = "deathIncrease", data = covid.iloc[0:300,], color = "darksalmon", alpha = 0.5)
<AxesSubplot:ylabel='deathIncrease'>
For a comprehensive list of arguments for swarmplot, check out seaborn.swarmplot.
A violin plot can be seen as the combination of a box plot and density curves. By adding the kernel density estimate to each side of the (modified) box plot, it makes up for a box plot's disadvantage of not showing the shape of the underlying distribution.
plt.figure(figsize=(8, 4))
# violin plot
sns.violinplot(x = "deathIncrease", data = covid)
<AxesSubplot:xlabel='deathIncrease'>
The inner argument of violinplot controls how the summary information is visualized. By default, it has the value "box", which draws a mini box plot inside the violin. We can also change it to "quartile", which marks the quartiles with lines.
plt.figure(figsize=(8, 4))
# violin plot with the quartiles drawn as lines
sns.violinplot(x = "deathIncrease", data = covid, inner = "quartile")
<AxesSubplot:xlabel='deathIncrease'>
For a comprehensive list of arguments for violinplot, check out seaborn.violinplot.
Recall that our data comes from three states. What if we want to visualize the distribution of deathIncrease conditioned on each state and compare them? In this section, we will see how to achieve this easily by modifying what we have plotted above.
One trick here is the hue argument, which most of seaborn's plotting functions accept. By passing a categorical variable to hue, we tell seaborn to color the plot according to the value of this categorical variable. Another trick is to split a plot into several subplots.
In the examples below, we will visualize the distribution of deathIncrease conditioned on state.
We can change our histograms in several ways to reflect the additional dimension of the data, the categorical variable state. First of all, we can separate each bin into different "sections", each corresponding to one category of state. All we need to do is pass the categorical variable into hue and change multiple as needed.
By default, the value of multiple is "layer", which means the bins of different categories are layered on top of each other and all start from y = 0.
plt.figure(figsize=(8, 5))
# layered histogram
sns.histplot(x = "deathIncrease", data = covid, hue = "state")
<AxesSubplot:xlabel='deathIncrease', ylabel='Count'>
If we want the bins to stack on top of each other instead of overlapping, we can set multiple = "stack". Be careful with stacked histograms, though: since the bars of some colors do not start at the bottom of the y-axis, it is hard to read the conditional distribution of any category that is not at the bottom. Stacked histograms are also not very useful for comparing the conditional distributions of different categories.
plt.figure(figsize=(8, 5))
# stacked histogram
sns.histplot(x = "deathIncrease", data = covid,
hue = "state", multiple = "stack")
<AxesSubplot:xlabel='deathIncrease', ylabel='Count'>
We can also make every bin have the same height, 1, so that the y-axis shows the conditional probability of each value of state given a bin of values for deathIncrease. To do this, we use multiple = "fill". The y-axis label still reads Count by default, which is not what we now have on the y-axis, so we change the label manually using the set method of the returned Axes.
However, this kind of histogram can be confusing at times because we cannot really see the conditional distribution of deathIncrease under each category. What we actually see is the conditional distribution of state given a value of deathIncrease. This can be counterintuitive, so if you choose to use it, be sure to make this clear in the description so that your readers won't misinterpret it the other way around.
plt.figure(figsize=(8, 5))
# histogram showing conditional probability of each category at each bin
sns.histplot(x = "deathIncrease", data = covid,
hue = "state", multiple = "fill").set(ylabel = "conditional probability")
[Text(0, 0.5, 'conditional probability')]
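The heights in a multiple = "fill" histogram are row-normalized proportions: within each bin of deathIncrease, the share of observations coming from each state. In pandas, the same table can be computed with pd.cut followed by pd.crosstab with normalize = "index", as in the sketch below; the observations and bin edges are made up for illustration.

```python
import pandas as pd

# made-up observations standing in for covid[["deathIncrease", "state"]]
df = pd.DataFrame({
    "deathIncrease": [5, 8, 12, 30, 35, 38, 60, 70, 75, 90],
    "state": ["MA", "PA", "MA", "CA", "PA", "CA", "CA", "CA", "PA", "CA"],
})

# bin deathIncrease, then compute P(state | bin): every row of the table sums to 1
bins = pd.cut(df["deathIncrease"], bins=[0, 25, 50, 100])
cond = pd.crosstab(bins, df["state"], normalize="index")
print(cond.round(2))
```

Reading a row of this table corresponds to reading one vertical bar of the filled histogram, which is why it says nothing directly about the distribution of deathIncrease within a state.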
Another option is to facet the histogram, i.e. split one histogram into several subplots, one for each category. We do this by 1) using the displot function, which facets distribution plots, and 2) passing the categorical variable into the row or col argument, depending on whether we want to facet the plots by rows or by columns.
We can add two categorical variables to a faceted histogram by passing one to row and the other to col. This facets the histogram on every combination of the categories of these two variables.
# faceted histogram on two categorical variables
sns.displot(x = "deathIncrease", data = covid,
row = "hospitalizedLevel", col = "state", height = 3)
<seaborn.axisgrid.FacetGrid at 0x7faad4c22fa0>
To plot a separate density curve of the given quantitative variable for each category, we still use hue.
Note that here, each curve is not really a density curve, because the area under each curve is not 1; it is the total area under all the curves that equals 1. Therefore, what we have here is not really the conditional density. If we want to plot the conditional density of deathIncrease given state = MA instead, we can pass in a subset of the data that only contains observations with state = MA.
plt.figure(figsize=(8, 5))
# kernel density estimate
sns.kdeplot(x = "deathIncrease", data = covid, hue = "state")
<AxesSubplot:xlabel='deathIncrease', ylabel='Density'>
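Why the area under each hue curve is below 1: by default, each group's curve is scaled by that group's share of the observations, so the areas add up to 1 across all curves. Seaborn's common_norm argument controls this, and setting common_norm = False normalizes each curve on its own, which gives conditional densities without subsetting the data. The NumPy sketch below reproduces the default scaling on made-up groups; for simplicity it uses one hypothetical fixed bandwidth for every group, whereas seaborn chooses the bandwidth itself.

```python
import numpy as np

def kde(grid, data, h):
    """Gaussian kernel density estimate of `data` evaluated on `grid` with bandwidth h."""
    z = (grid[:, None] - np.asarray(data, dtype=float)[None, :]) / h
    return np.exp(-0.5 * z**2).sum(axis=1) / (len(data) * h * np.sqrt(2 * np.pi))

# made-up groups standing in for deathIncrease split by state
groups = {"CA": [40, 55, 60, 80, 95, 120], "PA": [20, 30, 45], "MA": [10, 15, 25]}
total = sum(len(values) for values in groups.values())
h = 10.0                                      # hypothetical fixed bandwidth
grid = np.linspace(-100, 250, 3501)
dx = grid[1] - grid[0]

areas = {}
for state, values in groups.items():
    share = len(values) / total               # this group's fraction of all observations
    curve = share * kde(grid, values, h)      # each hue curve is scaled by its group's share
    areas[state] = curve.sum() * dx           # numerical area under this curve

print({state: round(area, 3) for state, area in areas.items()})  # matches each group's share
print(round(sum(areas.values()), 3))                              # the areas add up to 1
```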
As with the histograms, we can set multiple to "stack" or "fill" if we prefer; just keep their limitations in mind. To facet the density curves, we can use the exact same code as for faceting histograms, except that we add kind = "kde" to tell seaborn that we want kernel density estimates instead of histograms, which displot draws by default.
However, note that the plot below is in general not a good example of data visualization, because readers may mistake each panel for a conditional density curve when in fact it is not (the area under each curve is not 1).
# facetted density curve
sns.displot(x = "deathIncrease", data = covid,
col = "state", kind = "kde")
<seaborn.axisgrid.FacetGrid at 0x7faaf6715df0>
Adding another categorical variable to a box plot is a common practice, since we can easily compare the quartiles of the given quantitative variable under different categories by putting multiple boxes side by side. To achieve this, we put the quantitative variable on one of the axes, and the categorical variable on the other.
plt.figure(figsize=(6, 8))
# side-by-side box plot
sns.boxplot(x = "state", y = "deathIncrease", data = covid)
<AxesSubplot:xlabel='state', ylabel='deathIncrease'>
We can add another categorical variable to the box plot by specifying hue.
plt.figure(figsize=(10, 8))
# side-by-side box plot grouped by two categorical variables
sns.boxplot(x = "state", y = "deathIncrease", data = covid,
hue = "hospitalizedLevel")
<AxesSubplot:xlabel='state', ylabel='deathIncrease'>
To facet a box plot into different subplots, we use catplot, which facets categorical plots.
# facetted box plots
sns.catplot(x = "state", y = "deathIncrease", data = covid,
col = "hospitalizedLevel", kind = "box", width = 0.5)
<seaborn.axisgrid.FacetGrid at 0x7faad8f58a90>
Similarly, if we want to facet a strip, swarm, or violin plot, we can use the same code with the kind argument set to "strip", "swarm", or "violin", respectively. We will skip the demonstration since the code is mostly the same.
When we have a categorical variable, we may be interested in the distribution of categories and comparisons among categories. We can use plots with different areas, each of which corresponds to one category, and the areas are usually proportionate to the count/percentage of each category.
A bar plot is to categorical data what a histogram is to quantitative data. On the x-axis are the different categories, and on the y-axis is either the count or the percentage of each category. When the bars are laid out side by side along the x-axis like this, they all have the same width, and their heights are proportional to the count/percentage.
plt.figure(figsize=(8, 5))
# bar plot
sns.countplot(x = "hospitalizedLevel", data = covid)
<AxesSubplot:xlabel='hospitalizedLevel', ylabel='count'>
Bar plots allow us to compare the counts among categories easily by comparing their heights. If we want horizontal bars instead, we just pass the variable to y instead of x.
The categorical variable we have here is ordinal, which means its categories (levels) have a natural order. The bars in the plot above are not in a sensible order, so we rearrange them using order.
(In fact, a better way of solving this once and for all is to specify the level order in our data frame using pandas' functions. But since that's technically part of data preprocessing, we will not focus on it here, and will keep using order for demonstration purposes.)
plt.figure(figsize=(8, 5))
# bar plot with levels rearranged
sns.countplot(x = "hospitalizedLevel", data = covid,
order = ["low", "medium", "high"])
<AxesSubplot:xlabel='hospitalizedLevel', ylabel='count'>
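For reference, here is a minimal sketch of the pandas approach mentioned in the note above: declaring the level order once by giving the column an ordered categorical dtype. The labels are made up. Once a column has this dtype, seaborn's categorical functions use the order of its categories by default, so the order argument becomes unnecessary.

```python
import pandas as pd

# made-up labels standing in for covid["hospitalizedLevel"]
levels = pd.Series(["high", "low", "medium", "low", "high"])

# declare the order of the levels once, on the data itself
level_type = pd.CategoricalDtype(categories=["low", "medium", "high"], ordered=True)
levels = levels.astype(level_type)

print(list(levels.cat.categories))    # ['low', 'medium', 'high']
print(levels.sort_values().tolist())  # sorted by level order, not alphabetically
```

Applied to our data, the same dtype would be set with covid["hospitalizedLevel"] = covid["hospitalizedLevel"].astype(level_type).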
If we want to put percentages on the y-axis instead of counts, things get a little messy. Other libraries, such as Dexplot, make this easier. But if we really want to use seaborn for it, there are several workarounds. First, we can make another data frame that stores each category and its percentage, pass that data frame into another function called barplot, and specify the category as x and the percentage as y. Alternatively, if we want proportions on the y-axis, we can use histplot with stat = "probability" and discrete = True, as shown below. However, with this approach we cannot really change the ordering on the x-axis, so it won't always make sense for ordinal variables.
plt.figure(figsize=(8, 5))
# bar plot with percentage
sns.histplot(x = "hospitalizedLevel", data = covid,
stat = "probability", discrete = True)
<AxesSubplot:xlabel='hospitalizedLevel', ylabel='Probability'>
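The first workaround mentioned above, a separate data frame of percentages passed to barplot, takes only a few lines of pandas. In the sketch below the labels are made up; value_counts(normalize = True) returns proportions, and reindex puts the levels in their natural order.

```python
import pandas as pd

# made-up labels standing in for covid["hospitalizedLevel"]
levels = pd.Series(["low", "low", "medium", "low", "high", "medium", "low", "low"])

pct = (
    levels.value_counts(normalize=True)   # proportion of each level
    .reindex(["low", "medium", "high"])   # put the levels in their natural order
    .mul(100)                             # proportions -> percentages
    .rename_axis("hospitalizedLevel")
    .reset_index(name="percentage")
)
print(pct)
```

The resulting data frame can then be drawn with sns.barplot(x = "hospitalizedLevel", y = "percentage", data = pct). Since each level appears exactly once, every bar's height is simply that level's percentage.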
For a comprehensive list of arguments for countplot, check out seaborn.countplot.
You must have seen a pie chart somewhere - I have seen so many pie charts that it would feel unfair not to talk about them here. Seaborn does not support pie charts, so if you really want to make one, you need to resort to matplotlib. But there is a good reason for this missing feature.
In a pie chart, every slice has the same radius, and it's the angles that are proportional to the counts, which means that to compare different categories, we compare the angles of different slices. This becomes a problem because people are in general not good at comparing angles. When the counts across categories are close, it becomes unnecessarily hard to tell at first sight which one is the biggest or smallest.
Let's use some dummy numbers to demonstrate a (very) bad usage of pie charts. Suppose we have three categories whose percentages are 33%, 36%, and 31%, respectively. We then use matplotlib to make a pie chart of the three percentages. The labels are left out intentionally: without any labels or legend, can you tell which color corresponds to each percentage? In other words, can you tell which slice is the biggest or the smallest?
plt.figure(figsize=(5, 5))
# pie chart - demonstration of bad usage
dummy_data = [33, 36, 31]
plt.pie(dummy_data)
You may ask: if it's hard to compare the slices by themselves, why not just overlay the percentage on each slice to avoid confusion? Doing so of course makes a pie chart less confusing, but at that point we are just relying on the numbers themselves to make comparisons, which defeats the purpose of data visualization. If we are not helping the readers make comparisons with intuitive visuals, why not just use a frequency table listing the percentages?
Still, there may be cases where the use of a pie chart can be justified. However, you can almost always use a bar plot instead of a pie chart for the same purpose, with easier comparisons and less confusion.
If we want something like a bar plot, but with a point representing the count of each category and lines connecting the points instead of bars, we can use a point plot. The pointplot function in seaborn visualizes a categorical variable against some estimator of a quantitative variable (the mean, by default). Therefore, to make a point plot of counts, we first need to aggregate our data using pandas' groupby function, so that we have a variable containing the count of each category to pass into pointplot.
plt.figure(figsize=(8, 5))
# group the data by hospitalizedLevel, then count the observations in each group
covid_temp = covid.groupby("hospitalizedLevel", as_index = False).size()
# point plot
sns.pointplot(x = "hospitalizedLevel", y = "size", data = covid_temp,
order = ["low", "medium", "high"]).set(ylabel = "count")
[Text(0, 0.5, 'count')]
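The data frame passed to pointplot above comes from groupby(...).size() with as_index = False, which returns an ordinary data frame whose counts sit in a column named size (hence y = "size"). A small sketch on made-up rows, assuming pandas 1.1 or later, where, as far as we know, this as_index = False behaviour of size() was introduced:

```python
import pandas as pd

# made-up rows standing in for covid
df = pd.DataFrame({"hospitalizedLevel": ["low", "high", "low", "medium", "low", "high"]})

# one row per level, with the number of observations in a column called "size"
counts = df.groupby("hospitalizedLevel", as_index=False).size()
print(counts)
```

The rows come out in sorted key order (high, low, medium here), which is why the plotting call still passes order explicitly.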
A point plot conveys the same information as a bar plot, with two differences. Firstly, we can see the trend of the counts as hospitalizedLevel increases more clearly. In this case, the higher hospitalizedLevel is, the fewer observations there are.
Lines are often interpreted to be related to the trend of data, which makes a point plot an ideal choice only if the categorical variable on the x-axis is ordinal, i.e. when the categories have orders. If there is no ordering on the x-axis, it doesn't make much sense to use a point plot, since the line segments can be misleading by making it seem like you are trying to show some trend across categories.
Secondly, a point plot is visually lighter than a bar plot, since it has no bars taking up most of the space. This trait can be useful when we want to add another categorical variable: three sets of points and lines may give a cleaner visual than three sets of bars. We will discuss this more in the next section.
For a comprehensive list of arguments for pointplot, check out seaborn.pointplot.
In section 4 we saw how to add categorical variables to a plot of 1D quantitative data by faceting the plot, using different colors, and stacking the elements or putting them side by side. Here we will do the same and look at some variations of the plots we just introduced in section 5.
Let's use the same tricks to visualize the distribution of hospitalizedLevel conditioned on state.
To show the count under each combination of hospitalizedLevel and state, we have several options: a side-by-side bar plot, a stacked bar plot, or a faceted bar plot.
To make a side-by-side bar plot, we still use hue.
plt.figure(figsize=(8, 5))
# side-by-side bar plot
sns.countplot(x = "hospitalizedLevel", data = covid,
order = ["low", "medium", "high"], hue = "state")
<AxesSubplot:xlabel='hospitalizedLevel', ylabel='count'>
In the plot above, we can clearly see two things: the conditional distribution of hospitalizedLevel under each state, and the conditional distribution of state under each hospitalizedLevel. However, it is harder to see the marginal distribution of hospitalizedLevel or that of state alone.
If we use a stacked bar plot instead, we will be able to see the marginal distribution of hospitalizedLevel, but may not be able to see the conditional distributions as easily, which defeats the purpose of adding another categorical variable to the bar plot. For this reason, the developer of seaborn chose not to implement stacked bar plots. If we want to make one, we can use matplotlib.
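If you do build a stacked bar plot with matplotlib, the main piece of bookkeeping is the bottom of each segment, which is the running total of the segments stacked beneath it. The sketch below computes these offsets with pd.crosstab on made-up rows. Each state's column of bottoms could then be passed as the bottom argument of plt.bar, for example plt.bar(counts.index, counts[state], bottom = bottoms[state], label = state); the drawing itself is not shown.

```python
import pandas as pd

# made-up rows standing in for covid[["hospitalizedLevel", "state"]]
df = pd.DataFrame({
    "hospitalizedLevel": ["low", "low", "medium", "high", "low", "medium", "high", "low"],
    "state": ["MA", "PA", "PA", "CA", "CA", "CA", "CA", "MA"],
})

# counts per (level, state), with the levels in their natural order
counts = pd.crosstab(df["hospitalizedLevel"], df["state"]).reindex(["low", "medium", "high"])

# each state's segment starts where the segments of the states before it end
bottoms = counts.cumsum(axis=1) - counts
print(counts)
print(bottoms)
```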
Seaborn does support faceted bar plots, which we make by calling catplot as we did for the faceted box plot, but this time we only specify the x argument and set kind = "count". By setting col = "state", we get a separate bar plot for each state, as shown below.
# facetted bar plot
sns.catplot(x = "hospitalizedLevel", data = covid,
order = ["low", "medium", "high"], col = "state", kind = "count")
<seaborn.axisgrid.FacetGrid at 0x7faaf8a4bdc0>
There is not really a "stacked" or "side-by-side" point plot, since we can simply draw multiple sets of lines in the same point plot. We still aggregate the data first, then use hue as usual.
plt.figure(figsize=(8, 5))
# group the data by hospitalizedLevel and state, then count the observations in each group
covid_temp = covid.groupby(["hospitalizedLevel", "state"], as_index = False).size()
# point plot with multiple layers
sns.pointplot(x = "hospitalizedLevel", y = "size", data = covid_temp,
order = ["low", "medium", "high"], hue = "state").set(ylabel = "count")
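The code above plots y = "size", which works because of how pandas names the result of the grouping step. Here is a small sketch on a hypothetical miniature data frame `df`. In recent pandas versions, `groupby(..., as_index=False).size()` returns a data frame that keeps the group keys as ordinary columns and stores each group's row count in a new column named "size".

```python
import pandas as pd

# hypothetical miniature of the covid data: each row is one day in one state
df = pd.DataFrame({
    "hospitalizedLevel": ["low", "low", "high", "low", "high"],
    "state": ["CA", "CA", "CA", "NY", "NY"],
})

# as_index=False keeps the keys as columns; the row count per group goes into "size"
counts = df.groupby(["hospitalizedLevel", "state"], as_index=False).size()
print(counts)
```

Groups are sorted by their keys, so ("low", "CA"), which appears twice, gets a size of 2 and every other combination gets 1.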
Setting dodge = True separates the points of different states at each hospitalizedLevel, so that overlapping points become visible.
plt.figure(figsize=(8, 5))
# point plot with multiple layers, dodged
sns.pointplot(x = "hospitalizedLevel", y = "size", data = covid_temp,
order = ["low", "medium", "high"], hue = "state", dodge = True).set(ylabel = "count")
A faceted point plot may not seem very meaningful here, but it becomes very useful when we add more than one categorical variable. As before, we first aggregate the data by both categorical variables to get the counts. Then we pass the new data frame to catplot and specify col = "state" and kind = "point".
# group the data by hospitalizedLevel and state, then count the observations in each group
covid_temp = covid.groupby(["hospitalizedLevel", "state"], as_index = False).size()
# faceted point plot
sns.catplot(x = "hospitalizedLevel", y = "size", data = covid_temp,
order = ["low", "medium", "high"], col = "state", kind = "point").set(ylabel = "count")
Now let's see how to visualize the relationship between two quantitative variables. This section also covers adding more variables on top of those two. In the examples below, we visualize the relationship between deathIncrease and hospitalizedCurrently.
In a scatter plot, every observation shows up as a point, whose coordinates are determined by the values of the two variables. A scatter plot can provide a quick overview of how the two variables of interest may be related, but beware that any graph alone is not enough to conclude the existence or type of relationship.
plt.figure(figsize=(8, 5))
# scatter plot
sns.scatterplot(x = "hospitalizedCurrently", y = "deathIncrease", data = covid)
When there are so many data points that a scatter plot becomes too crowded, we can always plot a random subset of the data instead. We can also shrink the points by passing s (the marker area), which seaborn forwards to matplotlib's scatter.
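A minimal sketch of both ideas, using hypothetical random data `df` in place of `covid`. `DataFrame.sample` with `random_state` draws a reproducible random subset. The `s` keyword is not a seaborn-specific argument: it is passed through to `matplotlib.axes.Axes.scatter` and sets the marker area in points squared.

```python
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

# hypothetical data standing in for covid: 5000 noisy (x, y) pairs
rng = np.random.default_rng(0)
df = pd.DataFrame({"hospitalizedCurrently": rng.uniform(0, 20000, 5000)})
df["deathIncrease"] = df["hospitalizedCurrently"] * 0.01 + rng.normal(0, 20, 5000)

# keep a reproducible random 10% of the rows
subset = df.sample(frac=0.1, random_state=1)

plt.figure(figsize=(8, 5))
# s is forwarded to matplotlib's scatter and makes the markers smaller
ax = sns.scatterplot(x="hospitalizedCurrently", y="deathIncrease", data=subset, s=10)
```

Subsampling and smaller markers can be combined. Fixing `random_state` keeps the subset, and therefore the figure, the same across runs.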
In a scatter plot, we can easily show how the relationship between two quantitative variables depends on one or more other variables, by using points of different colors, sizes, shapes, or a combination of these aesthetics.
In a scatter plot, these aesthetics can be mapped to either categorical or quantitative variables. The examples below introduce the three arguments that control them, mostly with categorical variables. Keep in mind, though, that in scatterplot you can pass either a categorical or a numerical variable to hue and size (but not to style).
To change the color of the points, we still use hue. We also make the points partially transparent (alpha = 0.7) so that overlapping points stay visible.
plt.figure(figsize=(8, 5))
# scatter plot, points colored by state
sns.scatterplot(x = "hospitalizedCurrently", y = "deathIncrease", data = covid,
hue = "state", alpha = 0.7)
We can also use style to change the shape of points according to a categorical variable.
plt.figure(figsize=(8, 5))
# scatter plot, shaped by hospitalizedLevel
sns.scatterplot(x = "hospitalizedCurrently", y = "deathIncrease", data = covid,
style = "hospitalizedLevel", alpha = 0.7)
To control the size of points according to a variable, either categorical or numerical, we use size.
plt.figure(figsize=(8, 5))
# scatter plot, sized by totalTestResultsIncrease
sns.scatterplot(x = "hospitalizedCurrently", y = "deathIncrease", data = covid,
size = "totalTestResultsIncrease", alpha = 0.7)
We can of course control all of these aesthetics together in one plot, adding up to three more dimensions of information. With that much information, however, the plot quickly becomes cluttered and hard to interpret.
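For reference, here is a sketch that maps all three aesthetics at once, using a small hypothetical data frame `df` whose column names mirror the covid columns used above. Even with six points the legend already has three sections, which shows how quickly such a plot fills up.

```python
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

# hypothetical rows mimicking the covid columns used above
df = pd.DataFrame({
    "hospitalizedCurrently": [1200, 3400, 800, 5600, 2300, 4100],
    "deathIncrease": [15, 40, 8, 75, 22, 51],
    "state": ["CA", "NY", "TX", "NY", "CA", "TX"],
    "hospitalizedLevel": ["low", "medium", "low", "high", "low", "medium"],
    "totalTestResultsIncrease": [20000, 45000, 12000, 60000, 30000, 38000],
})

plt.figure(figsize=(8, 5))
# color encodes state, marker shape encodes hospitalizedLevel,
# and marker size encodes totalTestResultsIncrease
ax = sns.scatterplot(x="hospitalizedCurrently", y="deathIncrease", data=df,
                     hue="state", style="hospitalizedLevel",
                     size="totalTestResultsIncrease", alpha=0.7)
```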
We can facet a scatter plot by using relplot. Although not shown in the code below, we can also pass the aesthetic arguments discussed above (hue, style, and size) to add more variables to each subplot.
# faceted scatter plot
sns.relplot(x = "hospitalizedCurrently", y = "deathIncrease", data = covid, col = "state")
To overlay a linear regression fit on a scatter plot, we use another function, regplot. By default, a confidence interval band is shown. We can set its level by passing a number in [0, 100] to ci, or set ci = None to hide it.
plt.figure(figsize=(8, 5))
# scatter plot with linear regression fit
sns.regplot(x = "hospitalizedCurrently", y = "deathIncrease", data = covid, ci = 90)
We can also overlay a LOWESS (locally weighted scatterplot smoothing) fit instead, using exactly the same syntax plus lowess = True. The model is estimated by another library, statsmodels, so it must be installed before we can draw a LOWESS curve.
If we only want the linear regression or LOWESS line or curve, we use the same code with scatter = False.
For a comprehensive list of arguments for scatterplot and regplot, check out seaborn.scatterplot and seaborn.regplot.
It is not always easy to see the joint distribution of two variables in a scatter plot, because the density of overlapping points is hard to judge. We can use a contour plot instead, where each contour line connects positions with the same density.
We will use kdeplot, the same function we used to draw 1D density curves, and pass in a y variable as well.
plt.figure(figsize=(8, 5))
# contour plot
sns.kdeplot(x = "hospitalizedCurrently", y = "deathIncrease", data = covid)
plt.figure(figsize=(8, 5))
# contour plot with contour lines filled
sns.kdeplot(x = "hospitalizedCurrently", y = "deathIncrease", fill = True, data = covid)
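We can control the lowest contour level to draw using thresh, whose default value is 0.05. This is an iso-proportion level rather than a raw density: with the default, about 5% of the probability mass lies below the lowest contour drawn. To change the color of the contour lines, we can use color.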
As with other functions, we add another categorical variable to a contour plot using hue.
plt.figure(figsize=(8, 5))
# contour plot conditioned on state
sns.kdeplot(x = "hospitalizedCurrently", y = "deathIncrease", data = covid, hue = "state")
An alternative to a contour plot is a heat map, where the density of a given (x, y) value or range is represented by colors of different hues or intensities.
Unfortunately, making a heat map in seaborn may take some extra effort. The heatmap function takes a 2D data set whose row and column labels are values of the two variables we want on the axes. Each entry holds the density, frequency, or magnitude for that (x, y) pair. This is called wide-form data, where each row or column represents one value or level of the variable on that axis. By contrast, the long-form data we have here stores one observation per row and one variable per column.
Most seaborn functions work best with long-form data, but heatmap is an exception. So if we have long-form data and want to make a heat map with seaborn, we first need to convert the data to wide-form. For demonstration purposes, we round hospitalizedCurrently down to the nearest multiple of 1000, and deathIncrease down to the nearest multiple of 10.
plt.figure(figsize=(8, 5))
# round the columns down to multiples of 1000 and 10 (floor division)
covid["hcRounded"] = covid["hospitalizedCurrently"] // 1000 * 1000
covid["diRounded"] = covid["deathIncrease"] // 10 * 10
# group by the rounded values and count the observations for each combination of values
covid_rounded = covid.groupby(["hcRounded", "diRounded"], as_index = False).size()
# change from long to wide format: rows are hcRounded values, columns are diRounded values
new = covid_rounded.pivot(index = "hcRounded", columns = "diRounded", values = "size")
# combinations that never occur become NaN after pivoting; fill them with zeros
new = new.fillna(0)
# keep part of the data frame (matrix) with mostly non-zero values
new = new.iloc[:10, :10]
# heat map
sns.heatmap(new, linewidths = 1, cmap = "flare")
In the call to heatmap above, linewidths sets the width of the borders between adjacent cells, and cmap sets a color palette of our choice.
Our data is probably not a perfect example for demonstrating a heat map, since most of the cells have very low counts. Still, we can see that the joint distribution of hospitalizedCurrently and deathIncrease is densest at around 0-20 deathIncrease and 0-1000 hospitalizedCurrently.
For a comprehensive list of arguments for heatmap, check out seaborn.heatmap.
If we have a series of ordered data points, we can use a line plot. Observations in a line plot are usually ordered by the variable on the x-axis, which is often time, and the line shows the trend of the y-axis variable as the x-axis variable increases. Our data frame records the date of each observation for each state, so let's visualize deathIncrease over date. For simplicity, we use a subset of the data.
plt.figure(figsize=(8, 5))
# keep data from 2020-04-01 to 2020-09-30
covid_subset = covid[(covid["date"] >= "2020-04-01") & (covid["date"] <= "2020-09-30")]
# line plot for time series data
sns.lineplot(x = "date", y = "deathIncrease", data = covid_subset)
In our data, there are three observations for each date, one for each state. When there are multiple y values at the same x value, seaborn automatically aggregates them and shows a confidence interval for the aggregated value. The default estimator is the mean. To aggregate the y values at each x differently, we can pass the name of a pandas method (or a function) to the estimator argument. If we don't want the confidence interval, we set ci = None (in seaborn 0.12 and later, errorbar = None is the preferred spelling).
On the other hand, we can draw a separate line for each category instead of aggregating the values.
plt.figure(figsize=(8, 5))
# line plot for each state
sns.lineplot(x = "date", y = "deathIncrease", data = covid_subset, hue = "state")
We can further give each line a different line type or width by setting style or size, respectively, to the categorical variable (in this case state).
To facet a line plot, we use the same code as for the faceted scatter plot, plus the argument kind = "line" (the default is kind = "scatter").
# faceted line plot
sns.relplot(x = "date", y = "deathIncrease", data = covid_subset,
col = "state", kind = "line")
For a comprehensive list of arguments for lineplot, check out seaborn.lineplot.
In the sections above, we have seen how to visualize 1D categorical and quantitative data and 2D quantitative data. We have also seen how to add more dimensions of categorical data with one or more of the following methods: 1) facet the plot, 2) draw a separate set of elements (lines, points, bars, etc.) for each category, and 3) use different colors/shapes/sizes for each category, or for each numerical value (in the case of a scatter plot). These methods let us visualize high-dimensional data, although we may not want to put all of this information in the same plot. The strength of data visualization is that it represents data intuitively, and too much information can destroy that.
Things get trickier when we have many quantitative variables. We did show how to fit up to four quantitative variables into a scatter plot (x, y, hue, and size), although doing so sacrifices the plot's interpretability. To visualize 3D quantitative data, we can make a 3D scatter plot or contour plot (among other 3D graph types). But 3D plots can be harder to read, since some patterns are only visible from certain angles and no single angle shows everything. Depending on the context, we can consider an interactive or animated 3D plot. Seaborn does not draw 3D plots, so matplotlib is the better choice here. We can add two more quantitative variables to a 3D scatter plot (through the color and size of the points). At that point, though, the plot is most likely unreadable, unless the sample size is small enough for readers to still see the variation between points clearly.
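A minimal matplotlib sketch of a 3D scatter plot, using hypothetical random variables `x`, `y`, `z`, plus `w` mapped to color as a fourth quantitative variable. It assumes matplotlib 3.2 or later, where `projection="3d"` works without importing `mpl_toolkits.mplot3d` explicitly. `view_init` sets the viewing angle, which helps when a pattern only shows from certain directions.

```python
import matplotlib.pyplot as plt
import numpy as np

# hypothetical quantitative variables: x, y, z on the axes, w mapped to color
rng = np.random.default_rng(0)
x, y, z, w = rng.normal(size=(4, 100))

fig = plt.figure(figsize=(8, 6))
# projection="3d" creates a 3D axes object
ax = fig.add_subplot(projection="3d")
ax.scatter(x, y, z, c=w, s=15)
ax.set_xlabel("x")
ax.set_ylabel("y")
ax.set_zlabel("z")
# elevation and azimuth (in degrees) of the camera
ax.view_init(elev=20, azim=45)
```

In an interactive backend, the plot can also be rotated with the mouse, which is often more informative than any single fixed angle.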
As said again and again, avoid stuffing too much information into the same plot. If we have more dimensions of data than can reasonably fit in one plot, make several plots. Try different combinations of variables, and see which ones present the information most clearly when visualized together. Depending on the data and the plot, there may also be other ways of incorporating variables than those covered in this tutorial. But whatever visualization method you choose, the rules of thumb are 1) be clear about what you are trying to show with your plots and 2) always keep readability in mind.
Plot aesthetics are important for making effective and readable graphs, especially when the information has many dimensions. We didn't cover aesthetics beyond changing the color of plot elements, but seaborn offers a powerful set of tools for customizing a plot's look. If you are interested, check out the tutorials on seaborn's official site.
In the last section of this tutorial, let's talk about seaborn, matplotlib, and other visualization libraries in Python.
First of all, seaborn is built on matplotlib and acts as a high-level interface to it. Compared to matplotlib, seaborn has three advantages: 1) it is easier to learn and use thanks to its simpler syntax, 2) it produces better-looking plots with less effort, and 3) it is better integrated with pandas data frames.
In seaborn, you can make complicated plots in a few lines of code, as long as the plot type is supported (see the list of supported graphs on seaborn's official site). However, seaborn's convenience comes at a price. When a feature you need is not supported (such as stacked bar plots and pie charts), it is hard to find a workaround using seaborn's functions alone.
In matplotlib, it's straightforward to make basic plots, but things get harder quickly as plots become more complex. On the other hand, matplotlib is highly versatile and customizable: there is almost nothing you can't plot with it, as long as you can figure out how. It is also friendlier to MATLAB users, since its interface resembles plotting in MATLAB, which smooths the transition from MATLAB to Python.
There are many other choices for data visualization in Python. For R users, ggplot may be more user-friendly because its syntax is similar to that of ggplot2 in R. Note that ggplot is not the Python version of ggplot2. Rather, both implement Leland Wilkinson's Grammar of Graphics, so they share many features and much of their syntax.
There is also plotly, the Python interface of plotly.js, which is a browser-based visualization library. It is compatible with both Jupyter Notebook and many web browsers, and creates highly interactive graphs. Another visualization library that is compatible with browsers is bokeh. There are still many, many more to be mentioned. Don't feel overwhelmed by the vast number of choices here, since you will naturally know more about the advantages and disadvantages of each when different use cases motivate you to learn new libraries as you go.