Plot a grid of plots in Python by iterating over the subplots

In this article, we will make a grid of plots in Python by iterating over the subplot axes and the columns of a pandas DataFrame.

Python has a versatile plotting framework in Matplotlib, but I found the documentation hard to navigate (or I was not able to find the right docs). It took me a fair amount of time to figure out how to send plots of DataFrame columns to individual subplots while rotating the x-axis tick labels of each subplot.

Usage

Plotting subplots in Matplotlib begins with a call to plt.subplots().

import pandas as pd
import matplotlib.pyplot as plt


fig, axs = plt.subplots(nrows=2, ncols=2)

The nrows and ncols keyword names can be omitted (plt.subplots(2, 2) is equivalent), but I kept them for clarity. This statement generates a 2×2 grid of subplots and returns the overall figure (the object that contains all the plots) and the individual subplots as a two-dimensional NumPy array of Axes objects, not a tuple. The subplots can be accessed as axs[0,0], axs[0,1], axs[1,0], and axs[1,1], or they can be unpacked during the assignment as follows.

import pandas as pd
import matplotlib.pyplot as plt


fig, ((ax1, ax2),(ax3, ax4)) = plt.subplots(nrows=2, ncols=2)

When we have 1 row and 4 columns instead of 2 rows and 2 columns, plt.subplots() returns a one-dimensional array of four subplots, so a single level of unpacking is enough.

import pandas as pd
import matplotlib.pyplot as plt


fig, (ax1, ax2, ax3, ax4) = plt.subplots(nrows=1, ncols=4)
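
The difference between the two unpacking patterns comes from the shape of the array that plt.subplots() returns. The following is a small sketch that prints those shapes. The variable names and the Agg backend line are my own choices for illustration (the backend line only lets the snippet run without a display and can be dropped in an interactive session).

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend; drop this line for interactive use
import matplotlib.pyplot as plt

# 2x2 grid: a two-dimensional array of Axes objects
fig_a, axs_a = plt.subplots(nrows=2, ncols=2)
print(type(axs_a).__name__, axs_a.shape)  # ndarray (2, 2)

# 1x4 grid: the length-1 dimension is squeezed away by default
fig_b, axs_b = plt.subplots(nrows=1, ncols=4)
print(axs_b.shape)  # (4,)

# squeeze=False keeps the array two-dimensional whatever the grid size
fig_c, axs_c = plt.subplots(nrows=1, ncols=4, squeeze=False)
print(axs_c.shape)  # (1, 4)

plt.close("all")  # release the figures created above
```

Passing squeeze=False can be handy when the grid size is computed at run time, because the result can then always be indexed as axs[row, col].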

Flattening the grid of subplots

We, however, do not want to unpack each subplot individually. Instead, we would like to flatten the array of subplots and iterate over it rather than assigning each subplot to a variable. The array is flattened with its flatten() method.

axs.flatten()
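
As a quick check of what flatten() gives back, here is a minimal sketch on a 2×2 grid. It assumes nothing beyond Matplotlib itself; the titles are placeholders, and the Agg backend line is only there so the snippet runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend; drop this line for interactive use
import matplotlib.pyplot as plt

fig, axs = plt.subplots(nrows=2, ncols=2)

flat = axs.flatten()  # new one-dimensional array holding the same Axes objects
print(flat.shape)     # (4,)

# The order is row by row: top-left, top-right, bottom-left, bottom-right
assert flat[0] is axs[0, 0]
assert flat[1] is axs[0, 1]
assert flat[2] is axs[1, 0]
assert flat[3] is axs[1, 1]

# axs.flat iterates over the same Axes without building a new array
for i, ax in enumerate(axs.flat):
    ax.set_title(f"subplot {i}")

plt.close("all")
```

Either flatten() or the flat attribute works for looping; flatten() is what the rest of this article uses.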

We identify 4 columns of a DataFrame that we want to plot and save the column names in a list that we can iterate over. flatten() returns a one-dimensional array of the subplots, which we pair (zip) with the column names so that each column gets its own subplot.

import pandas as pd
import matplotlib.pyplot as plt


profiles_file = 'data.csv'
df = pd.read_csv(profiles_file)

# Columns to plot, one per subplot
cols_to_plot = ['age', 'drinking', 'exercise', 'smoking']

fig, axs = plt.subplots(nrows=2, ncols=2)
fig.set_size_inches(20, 10)
# Leave room between subplots for titles and rotated tick labels
fig.subplots_adjust(wspace=0.2, hspace=0.5)

# Pair each column name with one subplot
for col, ax in zip(cols_to_plot, axs.flatten()):
    dftemp = df[col].value_counts()  # count of each distinct value
    ax.bar(dftemp.index, list(dftemp))
    ax.set_title(col)
    ax.tick_params(axis='x', labelrotation=30)

plt.show()

As we iterate over the subplot axes and the column names zipped with them, we draw each subplot with ax.bar(), supplying the x and y values ourselves. The x values are the index of the value_counts() result (the distinct values in the column) and the y values are the counts. The code converts the counts to a list, although Matplotlib generally accepts a pandas Series directly as well. I also tried plotting with pandas df.plot.bar() and assigning the returned object to ax. That doesn't work, because rebinding the loop variable does not place the plot in the grid; pandas plotting methods take an ax argument for this purpose instead.
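
For readers who prefer the pandas plotting interface, the following sketch draws the same kind of bar charts by passing each subplot through the ax parameter. The DataFrame here is made-up data standing in for data.csv, and the column values are hypothetical. The Agg backend line only lets the snippet run without a display.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend; drop this line for interactive use
import matplotlib.pyplot as plt
import pandas as pd

# Made-up data standing in for data.csv
df = pd.DataFrame({
    "drinking": ["often", "rarely", "never", "rarely", "often", "rarely"],
    "smoking": ["no", "no", "yes", "no", "sometimes", "no"],
})
cols_to_plot = ["drinking", "smoking"]

fig, axs = plt.subplots(nrows=1, ncols=2, figsize=(10, 4))

for col, ax in zip(cols_to_plot, axs.flatten()):
    counts = df[col].value_counts()
    # ax=ax tells pandas to draw on this subplot instead of choosing its own;
    # rot=30 rotates the x tick labels
    counts.plot.bar(ax=ax, rot=30)
    ax.set_title(col)

# In an interactive session, plt.show() would display the figure here
```

With this approach the x and y values do not have to be supplied by hand, since pandas takes them from the Series index and values.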

Rotate x-axis labels of subplots

The x-axis tick labels of each subplot are rotated using

ax.tick_params(axis='x', labelrotation=30)
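
Rotated labels stay centered on their tick, which can look slightly off with long labels. As a sketch of one alternative, plt.setp() can set the rotation and the alignment of the tick labels together. The category names and bar heights below are hypothetical, and the Agg backend line only lets the snippet run without a display.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend; drop this line for interactive use
import matplotlib.pyplot as plt

labels = ["never", "socially", "very often"]  # hypothetical categories
heights = [5, 12, 3]                          # hypothetical counts

fig, axs = plt.subplots(nrows=1, ncols=2, figsize=(10, 4))
for ax in axs.flatten():
    ax.bar(labels, heights)

# Left subplot: rotation only, as in the article
axs[0].tick_params(axis="x", labelrotation=30)

# Right subplot: rotation plus right alignment, so that the end of each
# label sits under its tick
plt.setp(axs[1].get_xticklabels(), rotation=30, ha="right",
         rotation_mode="anchor")

# In an interactive session, plt.show() would display the figure here
```

Both calls act on a single subplot, so either one fits inside the loop over axs.flatten().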

 
