"""Convenience functions for data plotting
- chart types: date axis, time axis, confidence bands, bar, hist, scatter
- regression diagnostics
Copyright 2023, Terence Lim
MIT License
"""
from typing import Iterable, Mapping, List, Any, Tuple, Callable, Dict
import numpy as np
import scipy
from datetime import datetime
import pandas as pd
from pandas import DataFrame, Series, Timestamp
from pandas.api.types import is_list_like
import matplotlib.pyplot as plt
from matplotlib import dates as mdates
from matplotlib import colors, cm
from matplotlib.lines import Line2D
import seaborn as sns
import statsmodels.api as sm
from statsmodels.graphics.gofplots import ProbPlot
from pandas.plotting import register_matplotlib_converters
register_matplotlib_converters() # for date formatting in plots
#plt.style.use('seaborn-dark') # plt.style.use('ggplot')
# fig.subplots_adjust(hspace=0.3) # spacing between subplots
# plt.locator_params(axis='x', nbins=len(delist)/2) # number f xtick labels
# get_ipython().magic(u"%matplotlib qt")
# fig, axes = plt.subplots(ncols=3, nrows=5, layout='constrained')
import matplotlib as mpl
[docs]def ColorMap(n, colormap='brg'):
"""Return sampled colors from spectrum of a color map
Args:
n : number of discrete colors
colormap : matplotlib colormaps, e.g. tab20, viridis, Blues, cool
Returns:
n x 4 2D-array of RGBA values
"""
return mpl.colormaps[colormap](np.linspace(0, 1, n))
[docs]def subplots(nfigs: int = 1, ncols: int = 1, nrows: int = 1,
squeeze: bool = False, **kwargs) -> List[plt.Axes]:
"""Wrapper over plt.subplots to create subplots across multiple figures
Args:
nfigs: Number of figures
ncols: Number of columns per figure
nrows: Number of rows per figure
squeeze: ignored, will always return 1D flattened list of axes unless only 1
**kwargs: other keyword arguments passed on to plt.subplots
Returns:
Tuple of (List of figures, List of axes)
"""
figs, axes = [], []
for ifig in range(nfigs):
fig, axs = plt.subplots(ncols=ncols, nrows=nrows, squeeze=False, **kwargs)
figs.append(fig)
axes.extend([a for ax in axs for a in ax])
return (figs, axes)
[docs]def set_xticks(ax: plt.Axes, nbins: int = 0, nskip: int = 0, **kwargs):
"""Set locations for xticks
Args:
ax: Matplotlib axes object from plt.subplots() or plt.gca()
nbins: Number of bins for major ticks, or
nskip: Number of ticks to skip (ignore if nbins > 0)
**kwargs: Arguments for tick_params(), e.g. labelsize, colors, rotation
"""
xticks = ax.get_xticks()
if nbins:
ax.set_xticks(ticks=xticks[::len(xticks) // nbins])
else:
ax.set_xticks(ticks=xticks[::(nskip+1)])
ax.tick_params(axis='x', **kwargs)
ax.margins(x=0) # set tight margins
##############################
#
# 1. Chart types: hist, scatter, bar, date, time, bands
#
##############################
[docs]def plot_date(y1: DataFrame, y2: DataFrame | None = None, ax: Any = None,
xmin: int = 0, xmax: int = 99999999, cn: int = 0,
fontsize: int = 12, rescale: bool = False, yscale: bool = False,
ls: str = '-', marker: str | None = '', ms: int | None = None,
hlines: List[float] = [], vlines: List[int] = [], nbins: int = 0,
vspans: List[Tuple[int, int]] = [], xlabel: str = '',
points: DataFrame | Series | None = None, rotation: float = 0,
title: str = "", ylabel1: str = "", ylabel2: str = "",
legend1: List[str] = [], legend2: List[str] = [],
loc1: str = 'upper left', loc2: str = 'upper right'):
"""Line plot with int date on x-axis, and primary and secondary y frames
Args:
y1: Plot on primary y-axis
y2: Plot on secondary y-axis
ax: Matplotlib axes object from plt.subplots() or plt.gca()
cn: Starting CN color to cycle through
marker: Marker style, None to cycle (default '' means no marker)
xmin: Minimum of x-axis date range
xmax: Maximum of x-axis date range (default is auto)
nbins: Number of bins for xticks placement (default 0 is auto)
rotation: Rotation of x-axis ticks
hlines: Y-axis points where to place horizontal lines
vlines: X-axis points where to place vertical lines
vspans: Vertical regions to highlight
xlabel: X-axis label
ylabel1, ylabel2: Y-axis labels
fontsize: Base font size
points: Points and labels to annotate
title: Main title
legend1, legend2: Lists of legend labels
loc1, loc2: Locations to place legends
ls: Linestyle
ms: Marker size
"""
markers = "os*.x+D8Xv41<2>3os*.x+D8Xv41<2>3os*.x+D8Xv41<2>3"
ax = ax or plt.gca()
ax.xaxis.set_major_formatter(mdates.DateFormatter('%Y%m%d'))
ax.xaxis.set_tick_params(rotation=rotation, labelsize=fontsize)
if y1 is not None:
y1 = DataFrame(y1)
y1 = y1.loc[(y1.index >= xmin) & (y1.index <= xmax)]
base = y1.loc[max(y1.notna().idxmax()),:] if rescale else 1
#sns.lineplot(x = pd.to_datetime(y1.index[f], format='%Y%m%d'),
#y = y1.loc[f], ax=ax)
for ci, c in enumerate(y1.columns):
f = y1.loc[:,c].notnull().values
ax.plot(pd.to_datetime(y1.index[f], format='%Y%m%d'),
y1.loc[f,c] / (base[c] if rescale else 1),
marker=markers[ci] if marker is None else marker,
ms=ms or fontsize,
linestyle=ls,
color=f'C{ci+cn}')
if points is not None:
ax.scatter(pd.to_datetime(points.index, format='%Y%m%d'),
points,
marker='o',
ms=ms or fontsize,
color='r')
if len(y1.columns) > 1 or legend1:
ax.set_ylabel('')
ax.legend(legend1 or y1.columns,
fontsize=fontsize,
loc=loc1)
if ylabel1:
ax.set_ylabel(ylabel1, fontsize=fontsize)
ax.set_xlabel(xlabel, fontsize=fontsize)
ax.yaxis.set_tick_params(labelsize=fontsize)
if y2 is not None:
y2 = DataFrame(y2)
y2 = y2.loc[(y2.index >= xmin) & (y2.index <= xmax)]
base = y2.loc[max(y2.notna().idxmax()),:] if rescale else 1
bx = ax.twinx()
for cj, c in enumerate(y2.columns):
g = y2.loc[:,c].notnull().values
bx.plot(pd.to_datetime(y2.index[g], format='%Y%m%d'),
y2.loc[g, c] / (base[c] if rescale else 1),
marker=markers[ci+cj] if marker is None else marker,
linestyle=ls,
ms=ms or fontsize,
color=f"C{ci+cj+cn+1}")
if yscale:
amin, amax = ax.get_ylim()
bmin, bmax = bx.get_ylim()
bx.set_ylim(min(amin, bmin), max(amax, bmax))
if len(y2.columns) > 1 or legend2:
bx.set_ylabel('')
bx.legend(legend2 or y2.columns,
fontsize=fontsize,
loc=loc2)
if ylabel2:
bx.set_ylabel(ylabel2, fontsize=fontsize+2)
bx.yaxis.set_tick_params(labelsize=fontsize)
for hline in hlines:
ax.axhline(hline,
linestyle='-.',
color='y')
for vline in vlines:
ax.axvline(pd.to_datetime(vline, format='%Y%m%d'),
ls='-.',
color='y')
for vspan in vspans:
ax.axvspan(*([pd.to_datetime(v, format='%Y%m%d') for v in vspan]),
alpha=0.5,
color='grey')
ax.set_title(title, fontsize=fontsize+4)
if nbins:
# ax.locator_params(axis='x', nbins=nbins) # numeric ticks
set_xticks(ax, nbins) # non-numeric ticks
#plt.tight_layout()
return ax
[docs]def plot_groupbar(df: DataFrame, ax: Any = None, labels: DataFrame | None = None):
"""Plot a grouped bar chart
Args:
df : DataFrame to plot
ax : Axis object
labels : optional DataFrame of annotations
"""
ax = ax or plt.gca()
x = np.arange(len(df)) # the x-axis locations
width = 1 / (len(df.columns) + 1) # the width of the bars
multiplier = 0
for label in df.columns:
y = df[label]
offset = width * multiplier
rects = ax.bar(x=x+offset, height=y, width=width, label=label)
ax.bar_label(rects, labels=labels[label], padding=3, rotation=0,
fontsize='x-small')
multiplier += 1
ax.set_xticks(x + width, df.index)
return ax
#
# TODO:
# - drop x, mean should be y (a series)
# - stderr is either list-like or a value (then multiply ones)
#
[docs]def plot_bands(mean: Series, stderr: Series, width: float = 1.96,
x: List[int] = [], ylabel: str = '', xlabel: str = '',
c: str = "b", loc: str = 'best', legend: List[str] = [],
ax: Any = None, fontsize: int = 10, title: str = '',
hline: List = [], vline: List = []):
"""Line plot a series with confidence bands
Args:
mean : Mean values to plot
stderr : Stderr values to plot confidence bands
width : Multipler on stderr for confidence bands
c : Color to fill bands
x : X-axis values
ylabel, xlabel : Axis labels
legend : List of legend labels
loc : Location to display legend
ax: Axis object
fontsize : Base font size
title : Main title string
hline : List of y-axis values to plot horizontal lines
vline : List of x-axis values to plot vertical lines
"""
ax = ax or plt.gca()
if not x:
x = np.arange(len(mean)) # x-axis is event day number
for line in hline:
ax.axhline(line, linestyle=':', color='g')
for line in vline:
ax.axvline(line, linestyle=':', color='g')
ax.plot(x, mean, ls='-', c=c)
ax.fill_between(x, mean-(width*np.array(stderr)),
mean+(width*np.array(stderr)), alpha=0.3, color=c)
if legend:
ax.legend(legend, loc=loc, fontsize=fontsize)
ax.set_title(title, fontsize=fontsize+4)
ax.set_ylabel(ylabel, fontsize=fontsize+2)
ax.set_xlabel(xlabel, fontsize=fontsize+2)
# plt.tight_layout()
return ax
[docs]def plot_scatter(x: Series, y: Series, labels: List = [], ax: Any = None,
xlabel: str = '', ylabel: str = '', c: Any = None,
cmap: Any = None, alpha: float = 0.75, edgecolor: Any = None,
s: float = 10, marker: str = 'o', title: str = '',
abline: bool | None = True, fontsize: int = 12):
"""Scatter plot, optionally with abline slope and point labels
Args:
x: Series to plot on horizontal axis
y: Series to plot on horizontal axis
labels: List of annotations for points
ax: Matplotlib axes object, from plt.subplots() or plt.gca()
xlabel: Horizontal axis label
ylabel: Vertical axis label
title: Title of plot
abline: To plot abline or 45-degree line. If none, do not plot slope
labels: List of 3-tuples (text, x, y) to annotate
alpha: transparency of scatter points
edgecolor: edge color of scatter points
marker: marker type of scatter points
s: Marker area size
cmap: Color map to use for scatter points
abline: True for fitted slope, False for 45-degree line, None is no plot
fontsize: Base font size
title: Main title string
"""
if ax is None:
ax = plt.gca()
ax.cla()
ax.clear()
if c is not None and cmap is not None:
cmin = min(c)
cmax = max(c)
norm = colors.Normalize(cmin - (cmax-cmin)/2, cmax)
c = cm.ScalarMappable(norm=norm, cmap=cmap).to_rgba(c)
cmap = None
cax = ax.scatter(x, y, marker=marker, s=s, c=c, alpha=alpha,
edgecolor=edgecolor, cmap=cmap)
#cmap=plt.cm.get_cmap('tab10', 3)
if abline is not None:
xmin, xmax, ymin, ymax = ax.axis()
if abline: # plot fitted slope
f = ~(np.isnan(x) | np.isnan(y))
slope, intercept = np.polyfit(list(x[f]), list(y[f]), 1)
y_pred = [slope * i + intercept for i in list(x[f])]
ax.plot(x[f], y_pred, 'g-')
else: # plot 45-degree line
bottom_left, top_right = min(xmin, ymin), max(xmax, ymax)
ax.plot([bottom_left, top_right], [bottom_left, top_right], 'g-')
xlabel = xlabel or (x.name if hasattr(x, 'name') else "")
ylabel = ylabel or (y.name if hasattr(y, 'name') else "")
ax.set_xlabel(xlabel, fontsize=fontsize)
ax.set_ylabel(ylabel, fontsize=fontsize)
if len(labels):
for t, xt, yt in zip(labels, x, y):
ax.text(xt * 1.01, yt * 1.01, t, fontsize=fontsize)
ax.set_title(title, fontsize=fontsize+4)
mfc = cax.get_fc()[0]
return Line2D([0], [0], marker=marker, mfc=mfc, ms=10, ls='', c=mfc)
[docs]def plot_hist(*args, kde: bool = True, hist: bool = False,
bins: List[float] = [],
pdf: Callable | List | Dict = scipy.stats.norm.pdf,
ax: Any = None, title: str = '', xlabel: str = '',
ylabel: str = 'density', fontsize: int = 12):
"""Histogram bar plot with a benchmark probability density
Args:
ax: Axis object
bins: List of bin values
hist: Plots histogram
kde: Plots kernel density curve
pdf: Benchmark probability density
ylabel, xlabel: Axis labels
fontsize: Base font size
title: Main title text
"""
ax = ax or plt.gca()
for arg in args:
frame = DataFrame(arg)
for col in frame.columns:
y = frame[col].notnull().values
sns.distplot(frame[col][y], kde=kde, hist=hist,
bins=bins, label=col, ax=ax)
if pdf:
if not types.is_list_like(pdf):
pdf = [pdf]
if isinstance(pdf, dict):
labels = list(pdf.keys())
pdf = list(pdf.values())
else:
labels = None
pdf = list(pdf)
bx = ax.twinx() if args else ax
bx.yaxis.set_tick_params(rotation=0, labelsize=fontsize)
x= np.linspace(*ax.get_xlim(), 100)
for i, p in enumerate(pdf):
bx.plot(x, p(x), label=labels[i] if labels else None,
color=f"C{len(args)+i}")
if labels:
bx.legend(labels, loc='center right')
ax.legend(loc='center left')
ax.xaxis.set_tick_params(rotation=0, labelsize=fontsize)
ax.yaxis.set_tick_params(rotation=0, labelsize=fontsize)
ax.set_title(title, fontsize=fontsize+4)
ax.set_ylabel(ylabel, fontsize=fontsize+4)
ax.set_xlabel(xlabel, fontsize=fontsize+4)
# plt.tight_layout()
return ax
[docs]def plot_bar(y: DataFrame, ax: Any = None, labels: List[str] = [],
xlabel: str = '', ylabel: str = '', fontsize: int = 12,
title: str = '', legend: List[str] = [], loc: str = 'best',
alpha: float = .8, labelsize: int = 8, rotation: float = 0.):
"""Bar plot with annotated points
Args:
y: DataFrame of y-values, observations in rows, variables in columns
ax: Axis object
labels: List of labels to annotate
labelsize: Font size of annotation labels text
rotation: Rotate annotation labels text
xlabel, ylabel: Axis labels
fontsize: Base font size
alpha: Transparency of bars
title: Main title text
legend: List of legend names
loc: Location for legend
"""
ax = ax or plt.gca()
bars = y.plot.bar(ax=ax, width=0.8, alpha=alpha)
pts = np.ravel(bars.containers, order='F').tolist()
ax.set_title(title, fontsize=fontsize+4)
ax.xaxis.set_tick_params(rotation=0, labelsize=fontsize)
ax.yaxis.set_tick_params(rotation=0, labelsize=fontsize)
if xlabel:
ax.set_xlabel(xlabel, fontsize=fontsize+2)
if ylabel:
ax.set_ylabel(ylabel, fontsize=fontsize+2)
if loc:
if legend:
ax.legend(legend, loc=loc)
else:
ax.legend(loc=loc)
if labels:
for pt, label in zip(pts, labels):
ax.annotate(str(label),
fontsize=labelsize,
xy=(pt.get_x() + pt.get_width() / 2, pt.get_height()),
xytext=(0, 3),
textcoords="offset points",
ha='center',
va='bottom',
rotation=rotation)
# plt.tight_layout()
return ax
open_t = pd.to_datetime('1900-01-01T09:30') # usual NYSE open
close_t = pd.to_datetime('1900-01-01T16:00') # usual NYSE close
[docs]def plot_time(y1: DataFrame, y2: DataFrame | None = None, ax: Any = None,
xmin: Timestamp = open_t, xmax: Timestamp = close_t, title: str='',
marker: str = ' ', fontsize: int = 8):
"""Plot lines with Timestamp time on x-axis; primary and secondary y-axis
Args:
y1 : DataFrame to plot on left axis
y2 : DataFrame (or None) to plot on right axis
ax : Axes object to plot in
xmin : Left-most x-axis time, None to include all
xmax : Right-most x-axis time, None to include all
marker : Style of marker to plot
fontsize : Font size of tick labels
"""
ax = ax or plt.gca()
ax.xaxis.set_major_formatter(mdates.DateFormatter('%H:%M'))
cn = 0 # to cycle through matplotlib 'CN' color palette
left = DataFrame(y1)
if xmin:
left = left.loc[(left.index >= xmin)]
if xmax:
left = left.loc[(left.index <= xmax)]
for cn, c in enumerate(left.columns):
f = left.loc[:, c].notnull().values
if cn: # kludgy hack with time-axis
ax.plot(left.index[f], left.loc[f, c],
marker=marker, color = 'C' + str(cn))
else:
sns.lineplot(x=left.index[f], y=left.loc[f, c],
marker=marker, color='C' + str(cn),
ax=ax, legend=False)
ax.legend(left.columns, loc='upper left')
if y2 is not None:
right = DataFrame(y2)
if xmin:
right = right.loc[(right.index >= xmin)]
if xmax:
right = right.loc[(right.index <= xmax)]
bx = ax.twinx()
bx.xaxis.set_major_formatter(mdates.DateFormatter('%H:%M'))
for i, c in enumerate(right.columns):
g = right.loc[:, c].notnull().values
bx.plot(right.index[g], right.loc[g, c], marker=marker,
color='C' + str(cn+i+1))
bx.legend(right.columns, loc='upper right')
ax.xaxis.set_tick_params(rotation=0, labelsize=fontsize)
ax.yaxis.set_tick_params(rotation=0, labelsize=fontsize)
ax.set_title(title, fontsize=fontsize+4)
# plt.tight_layout()
return ax
##############################
#
# 2. Regression diagnostics: fitted, qq, leverage, scale
#
##############################
[docs]def plot_fitted(fitted: Series, resid: Series, n: int = 3, ax: Any = None,
title: str = "Residuals vs Fitted", fontsize: int = 12,
strftime: str = '%Y-%m-%d') -> Series:
"""Plot residuals and identify outliers
Args:
ax: Axis object
fitted: Fitted Series
residual: Residual Series
n: Number of outlier points in each end to identify
strftime: string to format time display
title: Main title text
fontsize: Base font size
"""
ax = ax or plt.gca()
outliers = np.argpartition(resid.abs().values, -n)[-n:]
sns.regplot(x=fitted, y=resid, lowess=True, ax=ax,
scatter_kws={"s": 20, 'alpha': 0.5},
line_kws={"color": "r", "lw": 1})
ax.scatter(fitted.iloc[outliers], resid.iloc[outliers], c='m', alpha=.25)
for i in outliers:
if strftime:
label = resid.index[i].strftime(strftime)
else:
label = str(resid.index[i])
ax.annotate(label, xy=(fitted.iloc[i], resid.iloc[i]), c='m', fontsize='x-small')
ax.set_title(title)
ax.set_xlabel("Fitted values")
ax.set_ylabel("Residuals")
# plt.tight_layout()
return resid.iloc[outliers].rename('outliers')
[docs]def plot_qq(resid: Series, title: str = 'Normal Q-Q', ax: Any = None,
z: float = 2.807, strftime: str = '%Y-%m-%d') -> DataFrame:
"""QQ Plot
Args:
resid: Residual Series
strftime: string to format time display
title: Main title text
z: Z-value to identify outliers
"""
pp = ProbPlot(resid, fit=True)
outliers = abs(pp.sample_quantiles) > z
ax = ax or plt.gca()
pp.qqplot(ax=ax,
color='C0',
alpha=.5)
sm.qqline(ax=ax,
line='45',
fmt='r--',
lw=1)
z = resid.sort_values().index[outliers]
for x, y, i in zip(pp.theoretical_quantiles[outliers],
pp.sample_quantiles[outliers],
z):
if strftime:
label = i.strftime(strftime)
else:
label = str(i)
ax.annotate(label, xy=(x,y), c='m', fontsize='x-small')
ax.set_title(title)
ax.set_ylabel('Standardized residuals')
# plt.tight_layout()
return DataFrame({'residuals': pp.sorted_data[outliers],
'standardized': pp.sample_quantiles[outliers]}, index=z)
[docs]def plot_scale(fitted: Series, resid: Series, ax: Any = None,
title: str = "Scale-Location", n: int = 3, alpha: float = 0.5,
strftime: str = '%Y-%m-%d'):
"""Plot scale of residuals with outliers
Args:
fitted: Fitted Series
resid: Residual Series
ax: Axis object
title: Main title text
alpha: Transparency of points
n: number of outliers in each end to identify
strftime: string to format time display
"""
ax = ax or plt.gca()
resid = np.sqrt(np.abs(resid/resid.std()))
ax.scatter(fitted,
resid,
alpha=alpha);
sns.regplot(x=fitted,
y=resid,
scatter=False,
ci=False,
lowess=True,
line_kws={'color': 'r', 'lw': 1});
ax.set_title("Scale-Location")
ax.set_ylabel('$\sqrt{|Standardized \ residuals|}$');
ax.set_xlabel('Fitted values')
outliers = np.argpartition(resid.values, -n)[-n:]
ax.scatter(fitted.iloc[outliers], resid.iloc[outliers], c='m', alpha=.25)
for i in outliers:
if strftime:
label = resid.index[i].strftime(strftime)
else:
label = str(resid.index[i])
ax.annotate(label, xy=(fitted.iloc[i], resid.iloc[i]), c='m', fontsize='x-small')
# plt.tight_layout()
return outliers
[docs]def plot_leverage(resid: Series, hat: np.array, dist: np.array, ddof: int,
title: str = "Residuals vs Leverage", ax: Any = None,
strftime='%Y-%m-%d') -> DataFrame:
"""Plot leverage and identify influential points
Args:
resid: Residual Series
hat: Hat values
dist: Distance values
ddof: Degrees of freedom of model
ax: Axis object
title: Main title text
strftime: string to format time display
"""
ax = ax or plt.gca()
s = np.sqrt(np.sum(np.array(resid)**2 * (1 - hat))/(len(hat) - ddof))
r = resid/s # studentized residual
hat = Series(hat)
sns.regplot(x=hat,
y=r,
scatter=True,
ci=False,
lowess=True,
scatter_kws={'alpha': 0.5},
line_kws={'color': 'r', 'lw': 1})
influential = np.where(dist > 1)[0]
ax.scatter(hat.iloc[influential], r.iloc[influential], c='c', alpha=.5)
annotate = np.where(dist > 0.5)[0]
for i in annotate:
if strftime:
label = r.index[i].strftime(strftime)
else:
label = str(r.index[i])
ax.annotate(label, xy=(hat.iloc[i], r.iloc[i]),
c=('r' if dist[i] > 1 else 'c'), fontsize='x-small')
ax.set_title(title)
ax.set_xlabel("Leverage")
ax.set_ylabel("Standardized residuals")
legend = None
x = np.linspace(0.001, ax.get_xlim()[1], 50)
for sign in [1, -1]: # plot Cook's Distance thresholds (both signs)
for thresh in [.5, 1]:
y = sign * np.sqrt(thresh * ddof / x) * (1 - x)
g = (y > ax.get_ylim()[0]) & (y < ax.get_ylim()[1])
if g.any():
legend, = ax.plot(x[g], y[g], color='m', lw=.5+thresh, ls='--')
ax.annotate(str(thresh), c='m', fontsize='x-small',
xy=(max(x[g]), max(y[g]) if sign < 0 else min(y[g])))
if legend:
legend.set_label("Cook's Distance")
ax.legend(loc='best')
# plt.tight_layout()
return DataFrame({'influential': resid.iloc[influential],
"cook's D": dist[influential],
"leverage": hat.iloc[influential]},
index=resid.iloc[influential].index)