download notebook
view notebook w/o solutions
More on plotting
files needed = ('chile.xlsx', 'broadband_size.xlsx', 'auto_data.dta')
The seaborn package
We have been using matplotlib to plot. This is a very low-level package, meaning that we have a lot of fine-grained control of the elements of our plots. I like fine-grained control because it means I can makes things look exactly how I want it. It also means that we have to type a lot of code to create a figure.
In this notebook we introduce the seaborn package. Seaborn is written on top of matplotlib and automates some of the creation of plots. This can be very helpful, but remember do not trust the defaults.
Seaborn also includes some plot types that are not easy to do in matplotlib. Great!
We will cover:
* regplot()
which adds the line-of-best-fit to a scatter plot
* jointplot()
which adds the marginal distributions to the axes
* Scaling scatter markers with a third variable
* Facet plots
Along the way, we will take a look at the World Development Indicators.
import pandas as pd # our go-to for data handling
import matplotlib.pyplot as plt # make plots, but doesn't automate much
import seaborn as sns # some new plot types, more automation
pd.set_option('precision', 3) # this tells pandas to print out 3 decimal places when we print a DataFrame
World Development Indicators
The World Bank's World Development Indicators is a great source of economic and social data for many countries and many years. I have already downloaded output and consumption data for a few countries and saved it as 'wdi.csv'.
The database is great: lots of variables, countries and time. I have extracted a small data file for Chile. The files that come out of the download facility are a mess. Let's clean them up.
# I looked at the workbook and noticed the footer and that nas are '..'.
wdi = pd.read_excel('chile.xlsx', na_values='..', skipfooter=5)
wdi.head()
Country Name | Country Code | Series Name | Series Code | 1960 [YR1960] | 1961 [YR1961] | 1962 [YR1962] | 1963 [YR1963] | 1964 [YR1964] | 1965 [YR1965] | ... | 2011 [YR2011] | 2012 [YR2012] | 2013 [YR2013] | 2014 [YR2014] | 2015 [YR2015] | 2016 [YR2016] | 2017 [YR2017] | 2018 [YR2018] | 2019 [YR2019] | 2020 [YR2020] | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | Chile | CHL | GDP (constant LCU) | NY.GDP.MKTP.KN | 1.594e+13 | 1.678e+13 | 1.745e+13 | 1.847e+13 | 1.894e+13 | 1.912e+13 | ... | 125823838388000 | 1.325e+14 | 137876215768070 | 1.403e+14 | 1.435e+14 | 1.460e+14 | 147736095622350 | 153570668110240 | 155189982580250 | NaN |
1 | Chile | CHL | Final consumption expenditure (constant LCU) | NE.CON.TOTL.KN | 1.336e+13 | 1.402e+13 | 1.461e+13 | 1.515e+13 | 1.514e+13 | 1.529e+13 | ... | 93785024589700 | 9.909e+13 | 103336792108318 | 1.063e+14 | 1.090e+14 | 1.128e+14 | 116887472687087 | 121369080634627 | 122348544930031 | NaN |
2 rows × 65 columns
There is a lot to not like about this DataFrame.
- There are unneeded variables.
- The year data are a mix of numbers and letters.
- The unit of observation is a year, so the I want the years in the index.
1. Drop unneeded variables
wdi = wdi.drop(['Country Name', 'Country Code', 'Series Code'], axis=1)
wdi.head()
Series Name | 1960 [YR1960] | 1961 [YR1961] | 1962 [YR1962] | 1963 [YR1963] | 1964 [YR1964] | 1965 [YR1965] | 1966 [YR1966] | 1967 [YR1967] | 1968 [YR1968] | ... | 2011 [YR2011] | 2012 [YR2012] | 2013 [YR2013] | 2014 [YR2014] | 2015 [YR2015] | 2016 [YR2016] | 2017 [YR2017] | 2018 [YR2018] | 2019 [YR2019] | 2020 [YR2020] | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | GDP (constant LCU) | 1.594e+13 | 1.678e+13 | 1.745e+13 | 1.847e+13 | 1.894e+13 | 1.912e+13 | 2.127e+13 | 22039895161000 | 2.283e+13 | ... | 125823838388000 | 1.325e+14 | 137876215768070 | 1.403e+14 | 1.435e+14 | 1.460e+14 | 147736095622350 | 153570668110240 | 155189982580250 | NaN |
1 | Final consumption expenditure (constant LCU) | 1.336e+13 | 1.402e+13 | 1.461e+13 | 1.515e+13 | 1.514e+13 | 1.529e+13 | 1.700e+13 | 17541684272603 | 1.823e+13 | ... | 93785024589700 | 9.909e+13 | 103336792108318 | 1.063e+14 | 1.090e+14 | 1.128e+14 | 116887472687087 | 121369080634627 | 122348544930031 | NaN |
2 rows × 62 columns
2. Clean up the years and make numbers
Cleaning up the dates is bit complicated but not impossible. I will slice each column name and take the first four characters. Then convert those first four characters to an int
. The first column name is not a date, so I need to skip that.
temp = [wdi.columns[0]] + [int(c[0:4]) for c in wdi.columns[1:]]
temp
wdi.columns = temp
wdi.head()
Series Name | 1960 | 1961 | 1962 | 1963 | 1964 | 1965 | 1966 | 1967 | 1968 | ... | 2011 | 2012 | 2013 | 2014 | 2015 | 2016 | 2017 | 2018 | 2019 | 2020 | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | GDP (constant LCU) | 1.594e+13 | 1.678e+13 | 1.745e+13 | 1.847e+13 | 1.894e+13 | 1.912e+13 | 2.127e+13 | 22039895161000 | 2.283e+13 | ... | 125823838388000 | 1.325e+14 | 137876215768070 | 1.403e+14 | 1.435e+14 | 1.460e+14 | 147736095622350 | 153570668110240 | 155189982580250 | NaN |
1 | Final consumption expenditure (constant LCU) | 1.336e+13 | 1.402e+13 | 1.461e+13 | 1.515e+13 | 1.514e+13 | 1.529e+13 | 1.700e+13 | 17541684272603 | 1.823e+13 | ... | 93785024589700 | 9.909e+13 | 103336792108318 | 1.063e+14 | 1.090e+14 | 1.128e+14 | 116887472687087 | 121369080634627 | 122348544930031 | NaN |
2 rows × 62 columns
3. Years in the index
Looking better! Let's set the 'Series Name' as the index and then transpose the DataFrame. We will learn more about reshaping data in the future.
wdi = wdi.set_index('Series Name')
wdi
wdi = wdi.transpose()
wdi.head()
Series Name | GDP (constant LCU) | Final consumption expenditure (constant LCU) |
---|---|---|
1960 | 1.594e+13 | 1.336e+13 |
1961 | 1.678e+13 | 1.402e+13 |
1962 | 1.745e+13 | 1.461e+13 |
1963 | 1.847e+13 | 1.515e+13 |
1964 | 1.894e+13 | 1.514e+13 |
wdi.columns = ['gdp_real', 'cons_real']
wdi.head()
gdp_real | cons_real | |
---|---|---|
1960 | 1.594e+13 | 1.336e+13 |
1961 | 1.678e+13 | 1.402e+13 |
1962 | 1.745e+13 | 1.461e+13 |
1963 | 1.847e+13 | 1.515e+13 |
1964 | 1.894e+13 | 1.514e+13 |
Lastly, let's compute the growth rates of real gdp and consumption.
wdi_gr = wdi.pct_change()*100
wdi_gr.head()
gdp_real | cons_real | |
---|---|---|
1960 | NaN | NaN |
1961 | 5.245 | 4.953 |
1962 | 4.027 | 4.237 |
1963 | 5.840 | 3.682 |
1964 | 2.557 | -0.113 |
sns.regplot( )
Our go-to scatterplot in matplotlib does not offer an easy way to add a simple line of best fit. We can separately estimate the regression and then plot the fitted values, but seaborn provides a simple way to get that same look.
The regplot( )
function (docs) adds a fitted regression and confidence interval to a scatter.
my_fig, my_ax = plt.subplots(figsize=(10,5))
sns.regplot(x='gdp_real', # column to put on x axis
y='cons_real', # column to put on y axis
data=wdi_gr, # the data
ax = my_ax, # an axis object
color = 'black',
ci = 0) # confidence interval, 0 supresses it
# Easier than matplotlib!
sns.despine(ax = my_ax)
# Since this is all in a matplotlib axis/figure, our usual labeling applies.
my_ax.set_title('Consumption and output growth in Chile')
my_ax.set_ylabel('real consumption growth (%)')
my_ax.set_xlabel('real output growth (%)')
plt.show()
Looking good! The default regression spec is OLS, but you can specify a more complicated model. Interestingly, you cannot recover the coefficients (slope, intercept) of the estimated regression. The person who developed seaborn is apparently quite adamant about this.
Notice the different syntax seaborn exposes.
sns.regplot(x, y, data, ax)
We do not pass the columns of data directly, like we would with matplotlib. Instead, we pass the DataFrame and the column names of the variables we want to plot.
We can pass regplot()
an axis object, ax = my_ax
in the code above, to attach the plot to a specific axis. If we omit the ax argument, regplot()
will create an axis object for us.
Confidence interval
Let's specify a 95 percent confidence interval.
my_fig, my_ax = plt.subplots(figsize=(10,5))
sns.regplot(x='gdp_real', # column to put on x axis
y='cons_real', # column to put on y axis
data=wdi_gr, # the data
ax = my_ax, # an axis object
color = 'red',
ci = 95) # confidence interval, 0 supresses it
# Easier than matplotlib!
sns.despine(ax = my_ax)
# Since this is all in a matplotlib axis/figure, our usual labeling applies.
my_ax.set_title('Consumption and output growth in Chile')
my_ax.set_ylabel('real consumption growth (%)')
my_ax.set_xlabel('real output growth (%)')
plt.show()
sns.jointplot( )
jointplot( )
adds the marginal distributions of the plotted variables to the axis of a regplot (docs).
This may be useful for visualizing the marginal distributions, but is it important that your reader see the marginal distributions? Is it telling them something important? Potentially, but remember, just because you can do something, doesn't mean you always should.
I occasionally use this plot when I am doing preliminary data exploration.
# Rather than call plt.subplots, let seaborn create the fig and axes.
# h is the axis-like object.
h = sns.jointplot(x='gdp_real',
y='cons_real',
kind='reg', # specify a regplot in the main plot area
data=wdi_gr,
ci=95,
color = 'black',
height = 5 # we can still control the figure size
)
h.set_axis_labels('real gdp growth', 'real consumption growth')
# what is h?
print('h is a', type(h))
plt.show()
h is a <class 'seaborn.axisgrid.JointGrid'>
Notice that I did not start my figure by creating fig and axes objects. Instead, I let seaborn create the fig and axes for me. The return from sns.jointplot()
is a JointGrid object created by seaborn. This is a more complicated figure, so it needs more complicated axes ojects.
Practice
The OECD has a project studying broadband internet coverage across countries. It tracks data on numbers of subscribers, speed, and prices.
- Load 'broadband_size.xlsx'. It contains data on broadband accounts per 100 people, GDP per capita, and population (in thousands) for several countries. Are all your variables okay?
- Give the columns some reasonable names.
broad = pd.read_excel('broadband_size.xlsx', thousands=',')
broad.columns = ['cty', 'broad_pen', 'gdp_cap', 'pop']
broad.head()
cty | broad_pen | gdp_cap | pop | |
---|---|---|---|---|
0 | Australia | 31.796 | 50588.149 | 24451 |
1 | Austria | 28.543 | 52467.527 | 8735 |
2 | Belgium | 38.588 | 47941.661 | 11429 |
3 | Canada | 37.847 | 46704.892 | 36624 |
4 | Chile | 16.515 | 24012.915 | 18055 |
- Create a
.regplot()
with broadband penetration on the y axis and GDP per capita on the x axis. Add the 95 percent confidence interval. Apply the principles of graphical excellence to your figure.
fig, ax = plt.subplots(figsize=(12,7))
sns.regplot(x='gdp_cap', y='broad_pen', data=broad, # the data
ax = ax, # an axis object
color = 'blue', # make it blue
ci = 95) # confidence interval: pass it the percent
sns.despine(ax = ax)
ax.set_title('Broadband penetration and income')
ax.set_ylabel('broadband subscribers per 100 people')
ax.set_xlabel('GDP per capita')
plt.show()
- The relationship doesn't look very linear to me. Replot your solution from 3. but try adding the
logx=True
option to regplot to regress y = log(x). As always, consult the documentation if you need help.
fig, ax = plt.subplots(figsize=(12,7))
sns.regplot(x='gdp_cap', y='broad_pen', data=broad, # the data
ax = ax, # an axis object
color = 'blue', # make it blue
ci = 95, # confidence interval: pass it the percent
logx = True) # fit the line to y = log(x)
sns.despine(ax = ax)
ax.set_title('Broadband penetration and income')
ax.set_ylabel('broadband subscribers per 100 people')
ax.set_xlabel('GDP per capita')
plt.show()
Bubble plot (and passing keywords)
A bubble plot is a scatter plot in which the size of the data markers (usually a circle) varies with a third variable.
We can actually make these plots in matplotlib. The syntax is
ax.plot(x, y, s)
where s
is the variable corresponding to marker size. Since seaborn is built on top of matplotlib, we can pass scatter keyword arguments to .regplot( )
and these get passed through to the underlying scatter.
If we pass a single number to s
it changes the size of all the bubbles. If we pass it a Series of data, then each bubble gets scaled according to its value in the series.
The syntax for the option is scatter_kws={'s': data_var}
. This sets the s
argument of scatter to data_var
.
fig, ax = plt.subplots(figsize=(10,5))
sns.regplot(x='gdp_cap', y='broad_pen', data=broad, # the data
ax = ax, # an axis object
scatter_kws={'s': broad['pop']/1000}, # make the marker proportional to population
#scatter_kws={'s': 25},
color = 'blue', # make it blue
ci = 95, # confidence interval: pass it the percent
logx = True) # semi-log regression
# We need to let the reader know what the bubble sizes represent.
ax.text(50000, 20, 'Marker size proportional to population size')
sns.despine(ax = ax)
ax.set_title('Broadband penetration and income')
ax.set_ylabel('broadband subscribers per 100 people')
ax.set_xlabel('GDP per capita')
plt.show()
Notice that I have scaled population by 1000. The issue is that s
is interpreted as points^2 (points squared) [docs]. The idea is that the area of the marker increases proportional to the square of the width. There is a good discussion of it at stack overflow.
If you try to use s
and your whole figure turns the color of your marker, you probably need to scale your measure for s
.
Another example of the scatter_kws useage is to customize the scatter colors and alpha.
fig, ax = plt.subplots(figsize=(10,5))
# To keep the call to regplot from getting out of control, I define the scatter keywords dict here.
my_kws={'s': broad['pop']/1000, 'alpha':0.25, 'color':'black'}
sns.regplot(x='gdp_cap', y='broad_pen', data=broad, # the data
ax = ax, # an axis object
scatter_kws = my_kws, # pass parameters to scatter
color = 'blue', # make it blue
ci = 95, # confidence interval: pass it the percent
logx = True) # semi-log regression
# We need to let the reader know what the bubble sizes represent.
ax.text(50000, 20, 'Marker size proportional to population size')
sns.despine(ax = ax)
ax.set_title('Broadband penetration and income')
ax.set_ylabel('broadband subscribers per 100 people')
ax.set_xlabel('GDP per capita')
plt.show()
Facet plots
Facet plots are grids of plots with the same x- and y-axes. Each plot in the grid is a different subset of the sample. Seaborn gives us simple way to make these plots.
We often use facet plots in initial exploratory analysis. If we do not know what we are looking for, a facet plot is a good way to start "eye-balling" relationships. Once we have some ideas, we can narrow down our focus and use more precise tools. In general, we do not include large grids of figures in our finished analysis. They contain too much unnecessary information.
Load the file 'auto_data.dta' which contains data on automobile characteristics in the European market. These data are from Miravete, Moral, and Thurk.
df = pd.read_stata('auto_data.dta')
df.head(3)
CODE | ORIG | FIRM_ID | FIRM | BRAND | MODEL | YEAR | PRICE | QUANTITY | HP | LENGTH | WIDTH | SIZE | WEIGHT | FUEL | MPG | FUELPRICE | SEGMENT | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 10102 | 1 | 4.0 | Fiat | Alfa Romeo | ALFA 164 | 1995 | 27.581 | 179.0 | 0.048 | 179.134 | 69.291 | 1.241 | 3039.647 | 0 | 40.554 | 0.71 | 4 |
1 | 10103 | 1 | 4.0 | Fiat | Alfa Romeo | ALFA 145 | 1995 | 20.202 | 4934.0 | 0.041 | 161.024 | 67.323 | 1.084 | 2511.013 | 0 | 38.560 | 0.71 | 2 |
2 | 10104 | 1 | 4.0 | Fiat | Alfa Romeo | ALFA 155 | 1995 | 23.651 | 1017.0 | 0.048 | 174.803 | 67.039 | 1.172 | 2671.586 | 0 | 35.107 | 0.71 | 3 |
# Recode the FUEL variable so I can easily understand it.
df['FUEL'] = df['FUEL'].replace({0:'gasoline', 1:'diesel'})
df.sample(5)
CODE | ORIG | FIRM_ID | FIRM | BRAND | MODEL | YEAR | PRICE | QUANTITY | HP | LENGTH | WIDTH | SIZE | WEIGHT | FUEL | MPG | FUELPRICE | SEGMENT | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
955 | 20206 | 1 | 19.0 | Volkswagen | Audi | A8 | 1999 | 33.303 | 2.0 | 0.042 | 198.031 | 74.016 | 1.466 | 3572.247 | diesel | 44.955 | 0.651 | 4 |
54 | 11602 | 2 | 10.0 | Mazda | Mazda | 626 | 1995 | 24.940 | 572.0 | 0.046 | 185.039 | 68.898 | 1.275 | 2544.053 | gasoline | 38.560 | 0.710 | 3 |
313 | 20105 | 1 | 4.0 | Fiat | Alfa Romeo | ALFA 146 | 1996 | 27.266 | 189.0 | 0.033 | 167.323 | 67.323 | 1.126 | 2742.291 | diesel | 47.043 | 0.589 | 2 |
823 | 10204 | 1 | 19.0 | Volkswagen | Audi | A4 | 1999 | 28.356 | 55.0 | 0.046 | 175.984 | 68.110 | 1.199 | 2698.238 | gasoline | 40.554 | 0.849 | 3 |
75 | 12005 | 1 | 6.0 | GM | Opel | VECTRA | 1995 | 24.507 | 683.0 | 0.049 | 174.913 | 66.929 | 1.171 | 2796.035 | gasoline | 39.624 | 0.710 | 3 |
Looking at the data, we see that a unit of observation is a model at a point in time. We see prices and quantities sold and characteristics about the model. Let's cut the data down to VW and try some plots.
vw = df[df['FIRM']=='Volkswagen']
vw.sample(8)
CODE | ORIG | FIRM_ID | FIRM | BRAND | MODEL | YEAR | PRICE | QUANTITY | HP | LENGTH | WIDTH | SIZE | WEIGHT | FUEL | MPG | FUELPRICE | SEGMENT | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
1254 | 22803 | 1 | 19.0 | Volkswagen | Skoda | FELICIA | 2000 | 12.138 | 23945.0 | 0.029 | 151.575 | 64.173 | 0.973 | 2191.630 | diesel | 54.701 | 0.828 | 1 |
101 | 12702 | 1 | 19.0 | Volkswagen | Seat | MARBELLA | 1995 | 6.295 | 566831.0 | 0.027 | 137.008 | 59.055 | 0.809 | 1497.797 | gasoline | 50.046 | 0.710 | 1 |
7 | 10204 | 1 | 19.0 | Volkswagen | Audi | A4 | 1995 | 38.669 | 4.0 | 0.046 | 175.984 | 68.110 | 1.199 | 2698.238 | gasoline | 40.554 | 0.710 | 3 |
709 | 12704 | 1 | 19.0 | Volkswagen | Seat | CORDOBA | 1998 | 16.832 | 7877.0 | 0.033 | 162.992 | 64.567 | 1.052 | 2257.709 | gasoline | 39.867 | 0.790 | 2 |
374 | 22703 | 1 | 19.0 | Volkswagen | Seat | TOLEDO | 1996 | 22.220 | 1032.0 | 0.034 | 170.079 | 65.354 | 1.112 | 2314.097 | diesel | 51.711 | 0.589 | 3 |
1048 | 10205 | 1 | 19.0 | Volkswagen | Audi | A6 | 2000 | 49.093 | 1.0 | 0.054 | 188.583 | 71.260 | 1.344 | 3083.700 | gasoline | 35.107 | 1.042 | 4 |
1256 | 22805 | 1 | 19.0 | Volkswagen | Skoda | FABIA | 2000 | 19.523 | 1758.0 | 0.026 | 155.358 | 64.516 | 1.002 | 2457.269 | diesel | 58.252 | 0.828 | 1 |
115 | 13203 | 1 | 19.0 | Volkswagen | Volkswagen | POLO | 1995 | 13.223 | 55956.0 | 0.026 | 146.063 | 64.961 | 0.949 | 2125.463 | gasoline | 46.943 | 0.710 | 1 |
Q: How are vehicle weight and fuel efficiency related? Does it vary by fuel type? Does it vary by brand?
- Volkswagen has four brands during this period: Audi, Seat, Skoda, and Volkswagen.
- There are two fuel types: gasoline and diesel.
Let's make a grid of plots where the rows are the brands and columns are the fuel types. This is a 4x2 grid.
In each plot, we will scatter weight vs. mpg.
g = sns.FacetGrid(vw, row='BRAND', col='FUEL')
g.map(plt.scatter, 'WEIGHT', 'MPG', color='blue')
plt.show()
What is a point in a plot? It is a model-year.
What do we see?
- Diesel vehicles tend to get higher mpg
- Within a fuel type, the range of mpg are similar
- Weight and mpg are usually negatively correlated (Skoda diesel?)
- There is heterogeneity in the number of models per brand
Facet plot syntax
We first create the grid using FacetGrid()
. We specify which DataFrame we are plotting and which variables we want for the rows and columns. These variables should be categorical and should have relatively few potential values. Otherwise, the grid would get very large and it would be hard to interpret.
g = sns.FacetGrid(vw, row='BRAND', col='FUEL')
Next, we map a plot type to the grid using map()
. We can make many types of plots. In this case we have used the scatter()
plot from matplotlib. Notice the plt
that precedes the scatter
. We can also pass any keyword arguments that plt.scatter
accepts.
g.map(plt.scatter, 'WEIGHT', 'MPG', color='blue')
Now lets try a different plot type, regplot()
from Seaborn. It's easier to see the relationship between weight and mpg.
g = sns.FacetGrid(vw, row='BRAND', col='FUEL', height=4)
g.map(sns.regplot, 'WEIGHT', 'MPG', color='red', ci=90)
plt.show()
Q: Are more powerful cars more expensive? Does it depend on fuel type? Brand?
We are not limited to just one type of data in each plot. We can use color to differentiate further. In the next figure we add Ford, PSA, and Fiat to the firms in our DataFrame. Each firm has several brands and each brand has several models.
- Columns are still fuel type
- Rows are now firms (VW, Ford, PSA, Fiat)
- Hue (color) is brand (Ford's Volvo; Fiat's Alfa Romeo, etc.)
g = sns.FacetGrid(to_plot, hue='BRAND', col='FUEL', row='FIRM')
In each plot we have * Price vs HP * I'm using a scatter plot
g.map(plt.scatter, 'PRICE', 'HP')
# Create a dataframe with just these firms
firms = ['Ford', 'PSA', 'Volkswagen', 'Fiat']
to_plot = df[df['FIRM'].isin(firms)]
g = sns.FacetGrid(to_plot, hue='BRAND', col='FUEL', row='FIRM')
g.map(plt.scatter, 'PRICE', 'HP')
g.add_legend()
plt.show()
Practice: Facet Plots
- Q: How is size related to price? Does it differ by firm? By brand? By fuel type?
Use a facet plot to explore these questions. Restrict the DataFrame to include only Ford, PSA, Volkswagen, and Fiat.
# The .isin() saves us from syntax like
# to_plot = df[(df['FIRM']=='Ford') | (df['FIRM'] == 'PSA') | (df['FIRM']==Volkswagen) | (df['FIRM']=='Fiat')]
firms = ['Ford', 'PSA', 'Volkswagen', 'Fiat']
to_plot = df[df['FIRM'].isin(firms)]
g = sns.FacetGrid(to_plot, row='FUEL', col='FIRM', hue='BRAND')
g.map(plt.scatter, 'SIZE', 'PRICE')
g.add_legend()
plt.show()
- Let's explore a related concept, the
pairplot
[docs]. Try
g=sns.pairplot(df, vars=['PRICE', 'HP', 'WEIGHT'])
What does pairplot
do? Why do we only need to look at the upper or lower triangle of the figure? What is on the diagonal?
g=sns.pairplot(df, vars=['PRICE', 'HP', 'WEIGHT', 'MPG'])
# pairplot plots scatter plots for each possible pair of variables.
# The figure is symmetric, so any plot in the top triangle is also in
# bottom triangle, but with the axes reversed.
# The diagonal is the histogram of the variable.
- How do these relationships differ by fuel type? Use the 'hue' option (and the documentation).
g=sns.pairplot(df, vars=['PRICE', 'HP', 'WEIGHT', 'MPG'], hue='FUEL')
# It is clear that gasoline has a right-shifted HP distribution and a left-shifted MPG distribution.
# For diesel, HP doesn't look correlated with weight or mpg. For gas HP and MPG look negatively correlated.
# The weight-mpg correlation looks to have a similar slope and a different intercept by fuel type.