Plot a grid of plots in Python by iterating over the subplots

In this article, we will make a grid of plots in Python by iterating over the subplot axes and the columns of a pandas dataframe.

Python has a versatile plotting framework in Matplotlib, but the documentation seems extremely poor (or I was not able to find the right docs). It took me a fair amount of time to figure out how to send plots of the columns of a dataframe to individual subplots while rotating the x-labels of each subplot.

Usage

Plotting subplots in Matplotlib begins with the plt.subplots() call.

import pandas as pd
import matplotlib.pyplot as plt


fig, axs = plt.subplots(nrows=2, ncols=2)

We can omit the nrows and ncols args, but I kept them for clarity. This statement generates a grid of 2×2 subplots and returns the overall figure (the object which contains all the plots inside it) together with the individual subplot axes as a 2D NumPy array. The subplots can be accessed using axs[0,0], axs[0,1], axs[1,0], and axs[1,1]. Or they can be unpacked during the assignment as follows.

import pandas as pd
import matplotlib.pyplot as plt


fig, ((ax1, ax2),(ax3, ax4)) = plt.subplots(nrows=2, ncols=2)

When we have 1 row and 4 columns instead of 2 rows and 2 columns, the returned array is one-dimensional and can be unpacked as follows.

import pandas as pd
import matplotlib.pyplot as plt


fig, (ax1, ax2, ax3, ax4) = plt.subplots(nrows=1, ncols=4)

Flattening the grid of subplots

We, however, do not want to unpack each subplot individually. Instead, we would like to flatten the array of subplots and iterate over it rather than assigning each subplot to a variable. The array is flattened with its flatten() method.

axs.flatten()

We identify 4 columns of a dataframe that we want to plot and save the column names in a list that we can iterate over. We then flatten the array of subplots and zip it with the column names, so that each column is paired with its own axes.

import pandas as pd
import matplotlib.pyplot as plt


profiles_file = 'data.csv'
df = pd.read_csv(profiles_file)

cols_to_plot = ['age', 'drinking', 'exercise', 'smoking']

fig, axs = plt.subplots(nrows=2, ncols=2)
fig.set_size_inches(20, 10)
fig.subplots_adjust(wspace=0.2)
fig.subplots_adjust(hspace=0.5)

for col, ax in zip(cols_to_plot, axs.flatten()):
    dftemp = df[col].value_counts()
    ax.bar(dftemp.index, list(dftemp))
    ax.set_title(col)
    ax.tick_params(axis='x', labelrotation=30)

plt.show()

As we iterate over each subplot axes and the column name zipped with it, we plot each subplot with the ax.bar() command, supplying the x and y values manually. I tried plotting with pandas' own df.plot.bar() and assigning the returned object to the ax; that does not work, since pandas plotting functions instead take the target axes through an ax= keyword, as sketched below. The x values for ax.bar() are the index of the value counts (dftemp.index) and the y values are the counts themselves (which I converted to a list, as ax.bar() did not accept a pd.Series in my setup).
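For comparison, here is a minimal sketch of the pandas route that does work: instead of assigning the return value, pass each subplot to the ax keyword of Series.plot.bar() (same hypothetical data.csv and columns as above).

import pandas as pd
import matplotlib.pyplot as plt


df = pd.read_csv('data.csv')
cols_to_plot = ['age', 'drinking', 'exercise', 'smoking']

fig, axs = plt.subplots(nrows=2, ncols=2)

for col, ax in zip(cols_to_plot, axs.flatten()):
    # pandas draws directly onto the subplot passed through ax=
    df[col].value_counts().plot.bar(ax=ax, rot=30, title=col)

plt.show()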

Rotate x-axis of subplots

The x-axis tick labels of each subplot are rotated using

ax.tick_params(axis='x', labelrotation=30)
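If you prefer to set the rotation after plotting, matplotlib's plt.setp() works on the tick labels directly. A small self-contained sketch (ha='right' is just my preferred anchoring for rotated labels):

import matplotlib.pyplot as plt


fig, ax = plt.subplots()
ax.bar(['first', 'second', 'third'], [1, 2, 3])
# rotate and right-align every x tick label of this subplot
plt.setp(ax.get_xticklabels(), rotation=30, ha='right')
plt.show()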

 

Use pandas to convert a date to datetime format

Importing dates from a CSV file is always a hassle. With the myriad of datetime formats possible, we would need to write extensive amounts of code to accommodate all possible formats or put restrictions on the contents of the CSV file. We don't want to do either. Instead of hard-coding commands like

map(lambda string: datetime.strptime(string, "%m/%d/%Y"), dates)

into our code, we can use pandas to convert the dates for us. Pandas can convert an entire column of dates in string format to datetime format. We just need to be careful when importing plain dates rather than full datetimes: pandas converts to Timestamp objects, so if we are importing only dates, the time component is undesirable. We will need to strip off the time part by calling .date() on the result. So instead of

pd.to_datetime(date)

we will need to use

pd.to_datetime(date).date()

An example script illustrates this procedure.

import datetime as dt
import pandas as pd


def dateformat(date):
    # use pandas to convert a date string to datetime format
    # extract just the date since pandas returns a Timestamp object
    # repack the date as datetime using datetime.datetime.combine() with time = 00:00
    date = dt.datetime.combine(pd.to_datetime(date).date(),
                               dt.datetime.min.time())
    return date
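A quick usage check (the input string is just an example date):

>>> dateformat('03/14/2015')
datetime.datetime(2015, 3, 14, 0, 0)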

 

Convert a column from str to float, dropping non-numeric strings

Let us say we have the following dataframe column

>>> df['Amount $']
0        0.07
1       1.154
2       2.596
3     X-Links
4    Amount $
5       0.102
Name: Amount $, dtype: object

And we want to convert all the numbers to float and drop the non-numeric rows. str.isnumeric() will not work since it returns False for strings containing a decimal point. A simple option is to write a small function which tries to convert a string to a float and returns False if that fails. When this function is mapped over the entire column, it returns a boolean Series where True means float and False means non-float. Used as a boolean mask on the dataframe, it filters out the non-numeric rows.

def tryfloat(f):
    # return True if the string parses as a float, False otherwise
    try:
        float(f)
        return True
    except ValueError:
        return False

df[ df['Amount $'].apply(tryfloat) ]

 

The result is this table

>>> df[ df['Amount $'].apply(tryfloat) ]
  Amount $
0     0.07
1    1.154
2    2.596
5    0.102
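Note that the surviving values are still strings. To finish the str-to-float conversion promised by the heading, cast the filtered column with astype() (the .copy() avoids pandas' SettingWithCopyWarning):

df = df[ df['Amount $'].apply(tryfloat) ].copy()
df['Amount $'] = df['Amount $'].astype(float)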

Dataframe manipulation with pandas

Merge dataframes

import pandas as pd


db1 = pd.DataFrame({'Name': ['Jones', 'Will', 'Rory', 'Bryce', 'Hok'],
                    'job_id': [2, 5, 3, 7, 2]}, index=[1, 2, 3, 4, 5])

db2 = pd.DataFrame({'Name': ['CEO', 'Chairman', 'Vice-Chairman',
                             'Senior Engineer'],
                    'job_id': [5, 1, 2, 3]}, index=[1, 2, 3, 4])

df = pd.merge(db1, db2, on='job_id')

>>> df
  Name_x  job_id           Name_y
0  Jones       2    Vice-Chairman
1    Hok       2    Vice-Chairman
2   Will       5              CEO
3   Rory       3  Senior Engineer

merge() defaults to an inner join: rows whose job_id has no match in the other dataframe (Bryce in db1, Chairman in db2) are dropped, and the overlapping Name columns are renamed with _x and _y suffixes.
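If we want to keep the unmatched rows instead, merge() takes a how argument; a sketch (cells with no counterpart are filled with NaN):

# outer join keeps Bryce (job_id 7) and Chairman (job_id 1) with NaN placeholders
df_outer = pd.merge(db1, db2, on='job_id', how='outer')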

https://pandas.pydata.org/pandas-docs/stable/merging.html

Extracting rows from a dataframe by row number using iloc

>>> df.iloc[2]
Name_x    Will
job_id       5
Name_y     CEO
Name: 2, dtype: object
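iloc also accepts slices of row positions, so the first two rows of the merged dataframe come out as:

>>> df.iloc[0:2]
  Name_x  job_id         Name_y
0  Jones       2  Vice-Chairman
1    Hok       2  Vice-Chairman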

Extracting rows which match a string value

Syntax: df[ (df['col'] == "some value") ]

(hpi['RegionName'] == "Mesa") generates a boolean Series which can then be used to extract the rows of hpi that evaluate to True. The parentheses become essential once several such conditions are combined with & or |, as sketched after the example below.

# select all rows where the RegionName is "Mesa"
mesadataall = hpi[ (hpi['RegionName'] == "Mesa")  ]
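A sketch of combining two conditions (the Year column is hypothetical):

# rows where the RegionName is "Mesa" AND the hypothetical Year is after 2010
mesarecent = hpi[ (hpi['RegionName'] == "Mesa") & (hpi['Year'] > 2010) ]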

Cleaning dataframes using replace()

# clean data with sed like REGEX
# remove all (2014) references
moddata.replace(r" \(.*\)", "", inplace=True, regex=True) 
# replace the word unavailable by 0 
moddata.replace(r"unavailable", "0", inplace=True, regex=True) 

 

These regexes clean the data: the first strips parenthesized references such as " (2014)" and the second replaces the placeholder word unavailable with 0, leaving the columns ready for numeric conversion.
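Once the strings are cleaned, the columns can be converted to numbers. A sketch assuming a hypothetical column named 'Rate'; errors='coerce' turns anything still non-numeric into NaN:

# convert the cleaned column to floats, coercing leftovers to NaN
moddata['Rate'] = pd.to_numeric(moddata['Rate'], errors='coerce')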

Web scrape tables from website using pandas

data = pd.read_html(
    'https://en.wikipedia.org/wiki/List_of_countries_by_firearm-related_death_rate')
# read_html returns every <table> on the page as a DataFrame in a list
# the table of interest sits at index 4 of that list

df = data[4]
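Since the position of the table in the list can shift whenever the page is edited, it is worth inspecting what came back before settling on an index:

print(len(data))       # number of tables scraped from the page
print(data[4].head())  # peek at the candidate table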

 

to be continued …