mgplot.bar_plot
bar_plot.py: This module contains functions to create bar plots using Matplotlib. Note: bar plots in Matplotlib are not the same as bar charts in other libraries. Bar plots are used to represent categorical data with rectangular bars. As a result, bar plots and line plots typically cannot be plotted on the same axes.
"""
bar_plot.py
This module contains functions to create bar plots using Matplotlib.
Note: bar plots in Matplotlib are not the same as bar charts in other
libraries. Bar plots are used to represent categorical data with
rectangular bars. As a result, bar plots and line plots typically
cannot be plotted on the same axes.
"""

# --- imports
from typing import Any, Final
from collections.abc import Sequence

import numpy as np
from pandas import Series, DataFrame, Period
import matplotlib.pyplot as plt
from matplotlib.pyplot import Axes
import matplotlib.patheffects as pe


from mgplot.settings import DataT, get_setting
from mgplot.utilities import (
    apply_defaults,
    get_color_list,
    get_axes,
    constrain_data,
    default_rounding,
)
from mgplot.kw_type_checking import (
    ExpectedTypeDict,
    validate_expected,
    report_kwargs,
    validate_kwargs,
)
from mgplot.axis_utils import set_labels, map_periodindex, is_categorical
from mgplot.keyword_names import (
    AX,
    STACKED,
    ROTATION,
    MAX_TICKS,
    PLOT_FROM,
    COLOR,
    LABEL_SERIES,
    WIDTH,
    ANNOTATE,
    FONTSIZE,
    FONTNAME,
    ROUNDING,
    ANNOTATE_COLOR,
    ABOVE,
)


# --- constants

# Expected type(s) for each keyword argument accepted by bar_plot().
# bar_plot() passes this to validate_kwargs(); validate_expected() checks
# the table itself once, at import time.
BAR_KW_TYPES: Final[ExpectedTypeDict] = {
    # --- options for the entire bar plot
    AX: (Axes, type(None)),  # axes to plot on, or None for new axes
    STACKED: bool,  # if True, the bars will be stacked. If False, they will be grouped.
    MAX_TICKS: int,
    PLOT_FROM: (int, Period, type(None)),
    # --- options for each bar ...
    COLOR: (str, Sequence, (str,)),
    LABEL_SERIES: (bool, Sequence, (bool,)),
    WIDTH: (float, int),
    # - options for bar annotations
    ANNOTATE: (type(None), bool),  # None, True
    FONTSIZE: (int, float, str),
    FONTNAME: str,  # note: "(str)" is just str, not a tuple
    ROUNDING: int,
    ROTATION: (int, float),  # rotation of annotations in degrees
    ANNOTATE_COLOR: (str, type(None)),  # color of annotations
    ABOVE: bool,  # if True, annotations are above the bar
}
validate_expected(BAR_KW_TYPES, "bar_plot")


# --- functions
def annotate_bars(
    series: Series,
    offset: float,
    base: np.ndarray[tuple[int, ...], np.dtype[Any]],
    axes: Axes,
    **anno_kwargs,
) -> None:
    """Bar plot annotations.

    Write the value of each bar in ``series`` as a text label on ``axes``.

    Parameters
    - series: Series - the bar heights. The index (cast to int) gives the
      x position of each bar. It is assumed to be a contiguous run of
      integers, so that ``index - index.min()`` is a position into ``base``.
    - offset: float - added to each x position (grouped plots use it to
      centre each label over its sub-bar).
    - base: np.ndarray - the bottom of each bar (zeros for grouped plots,
      the running stack height for stacked plots). Must have the same
      length as ``series``.
    - axes: Axes - the axes to write the labels on.
    - **anno_kwargs - annotation options. "annotate", "fontsize",
      "fontname", "color", and "rotation" are expected; "rounding",
      "above" and "foreground" are optional.

    Nothing is drawn when "annotate" is missing or falsy, or when the series
    has more than 30 values. Missing (NaN) or infinite values are skipped.
    """

    # --- only annotate in limited circumstances
    if ANNOTATE not in anno_kwargs or not anno_kwargs[ANNOTATE]:
        return
    max_annotations = 30  # beyond this the labels become unreadable
    if series.empty or len(series) > max_annotations:
        return

    # --- internal logic check
    if len(base) != len(series):
        print(
            f"Warning: base array length {len(base)} does not match series length {len(series)}."
        )
        return

    # --- assemble the annotation parameters
    above: Final[bool | None] = anno_kwargs.get(ABOVE, False)  # None is also False-ish
    annotate_style = {
        FONTSIZE: anno_kwargs.get(FONTSIZE),
        FONTNAME: anno_kwargs.get(FONTNAME),
        COLOR: anno_kwargs.get(COLOR),
        ROTATION: anno_kwargs.get(ROTATION),
    }
    rounding = default_rounding(series=series, provided=anno_kwargs.get(ROUNDING, None))
    # small vertical nudge (2% of the data range) away from the bar edge
    adjustment = (series.max() - series.min()) * 0.02
    int_index = series.index.astype(int)  # mypy syntactic sugar
    # use the int index for the correction too, so base[...] gets an int position
    zero_correction = int_index.min()

    # --- annotate each bar
    for index, value in zip(int_index, series):
        if not np.isfinite(value):
            # matplotlib cannot place text at a NaN/inf position, and a
            # "nan" label carries no information: skip missing values.
            continue
        position = base[index - zero_correction] + (
            adjustment if value >= 0 else -adjustment
        )
        if above:
            position += value  # move from the bar's base to its far end
        text = axes.text(
            x=index + offset,
            y=position,
            s=f"{value:.{rounding}f}",
            ha="center",
            va="bottom" if value >= 0 else "top",
            **annotate_style,
        )
        if not above and "foreground" in anno_kwargs:
            # apply a stroke-effect to within bar annotations
            # to make them more readable with very small bars.
            text.set_path_effects(
                [pe.withStroke(linewidth=2, foreground=anno_kwargs.get("foreground"))]
            )


def grouped(axes, df: DataFrame, anno_args, **kwargs) -> None:
    """
    Plot a grouped bar plot: for each row of ``df`` the columns are drawn
    as narrow bars side by side, centred on the row's x position.

    Parameters
    - axes: Axes - the axes to plot on.
    - df: DataFrame - one column per series. The index is used directly
      as the x positions.
    - anno_args: dict - annotation options forwarded to annotate_bars().
    - **kwargs - per-series bar options: COLOR, WIDTH and LABEL_SERIES,
      each a sequence indexed by column position.
    """

    series_count = len(df.columns)

    for i, col in enumerate(df.columns):
        series = df[col]
        if series.isnull().all():
            # nothing to draw; this column's slot in the group is left empty
            continue
        width = kwargs[WIDTH][i]  # total width of the whole group
        if width < 0 or width > 1:
            width = 0.8
        adjusted_width = width / series_count  # width of one bar in the group
        # far-left + margin + halfway through one grouped column
        left = -0.5 + ((1 - width) / 2.0) + (adjusted_width / 2.0)
        offset = left + (i * adjusted_width)
        foreground = kwargs[COLOR][i]
        axes.bar(
            x=series.index + offset,
            height=series,
            color=foreground,
            width=adjusted_width,
            # labels starting with "_" are left out of matplotlib legends
            label=col if kwargs[LABEL_SERIES][i] else f"_{col}_",
        )
        annotate_bars(
            series=series,
            offset=offset,
            base=np.zeros(len(series)),
            axes=axes,
            foreground=foreground,
            **anno_args,
        )


def stacked(axes, df: DataFrame, anno_args, **kwargs) -> None:
    """
    Plot a stacked bar plot: for each row of ``df`` the columns are stacked
    on one bar. Positive values stack upward from zero and negative values
    stack downward from zero.

    Parameters
    - axes: Axes - the axes to plot on.
    - df: DataFrame - one column per series. The index is used directly
      as the x positions.
    - anno_args: dict - annotation options forwarded to annotate_bars().
    - **kwargs - per-series bar options: COLOR, WIDTH and LABEL_SERIES,
      each a sequence indexed by column position.
    """

    row_count = len(df)  # one stack per row
    # running tops of the positive and negative stacks, per row
    base_plus: np.ndarray[tuple[int, ...], np.dtype[np.float64]] = np.zeros(
        shape=row_count, dtype=np.float64
    )
    base_minus: np.ndarray[tuple[int, ...], np.dtype[np.float64]] = np.zeros(
        shape=row_count, dtype=np.float64
    )
    for i, col in enumerate(df.columns):
        series = df[col]
        # each value starts from the stack on its own side of zero
        base = np.where(series >= 0, base_plus, base_minus)
        foreground = kwargs[COLOR][i]
        axes.bar(
            x=series.index,
            height=series,
            bottom=base,
            color=foreground,
            width=kwargs[WIDTH][i],
            # labels starting with "_" are left out of matplotlib legends
            label=col if kwargs[LABEL_SERIES][i] else f"_{col}_",
        )
        annotate_bars(
            series=series,
            offset=0,
            base=base,
            axes=axes,
            foreground=foreground,
            **anno_args,
        )
        # NaN compares False both ways, so missing values add nothing
        base_plus += np.where(series >= 0, series, 0)
        base_minus += np.where(series < 0, series, 0)
def bar_plot(
    data: DataT,
    **kwargs,
) -> Axes:
    """
    Create a bar plot from the given data.

    Each column of the data is plotted as a separate series of bars. By
    default the series are grouped (drawn side by side). With stacked=True
    they are stacked on top of each other, with positive values above zero
    and negative values below zero.

    Parameters
    - data: Series | DataFrame - The data to plot. A Series is treated as a
      single-column DataFrame.
    - **kwargs: dict Additional keyword arguments for customization.
    # --- options for the entire bar plot
    ax: Axes - axes to plot on, or None for new axes
    stacked: bool - if True, the bars will be stacked. If False (the default),
        they will be grouped.
    max_ticks: int - maximum number of ticks on the x-axis (for PeriodIndex only;
        default 10)
    plot_from: int | Period - if provided, the plot will start from this index.
    # --- options for each bar ...
    color: str | list[str] - the color of the bars (or separate colors for each series)
    label_series: bool | list[bool] - if True, the series will be labeled in the legend
        (default: True when there is more than one series)
    width: float | list[float] - the width of the bars
    # - options for bar annotations
    annotate: bool - if True, annotate the bars with their values (default False)
    fontsize: int | float | str - font size of the annotations (default "small")
    fontname: str - font name of the annotations (default "Helvetica")
    rounding: int - number of decimal places to round to
    annotate_color: str - color of annotations (default "black" when above,
        otherwise "white")
    rotation: int | float - rotation of annotations in degrees (default 0)
    above: bool - if True, annotations are above the bar, else within the bar

    Note: This function does not assume all data is timeseries with a PeriodIndex.

    Returns
    - axes: Axes - The axes for the plot.
    """

    # --- check the kwargs
    me = "bar_plot"
    report_kwargs(called_from=me, **kwargs)
    kwargs = validate_kwargs(BAR_KW_TYPES, me, **kwargs)

    # --- get the data
    # no call to check_clean_timeseries here, as bar plots are not
    # necessarily timeseries data. If the data is a Series, it will be
    # converted to a DataFrame with a single column.
    df = DataFrame(data)  # really we are only plotting DataFrames
    df, kwargs = constrain_data(df, **kwargs)
    item_count = len(df.columns)

    # --- deal with complete PeriodIndex indices
    if not is_categorical(df):
        print(
            "Warning: bar_plot is not designed for incomplete or non-categorical data indexes."
        )
    saved_pi = map_periodindex(df)
    if saved_pi is not None:
        df = saved_pi[0]  # extract the reindexed DataFrame from the PeriodIndex

    # --- set up the default arguments
    chart_defaults: dict[str, Any] = {
        STACKED: False,
        MAX_TICKS: 10,
    }
    chart_args = {k: kwargs.get(k, v) for k, v in chart_defaults.items()}

    # per-series defaults; apply_defaults() expands these to one value per column
    bar_defaults: dict[str, Any] = {
        COLOR: get_color_list(item_count),
        WIDTH: get_setting("bar_width"),
        LABEL_SERIES: (item_count > 1),
    }
    above = kwargs.get(ABOVE, False)
    anno_args = {
        ANNOTATE: kwargs.get(ANNOTATE, False),
        FONTSIZE: kwargs.get(FONTSIZE, "small"),
        FONTNAME: kwargs.get(FONTNAME, "Helvetica"),
        ROTATION: kwargs.get(ROTATION, 0),
        # NOTE(review): a default of True (not an int) is handed to
        # default_rounding(); presumably it means "choose automatically" - confirm.
        ROUNDING: kwargs.get(ROUNDING, True),
        # dark text outside the bar, light text inside it
        COLOR: kwargs.get(ANNOTATE_COLOR, "black" if above else "white"),
        ABOVE: above,
    }
    bar_args, remaining_kwargs = apply_defaults(item_count, bar_defaults, kwargs)

    # --- plot the data
    axes, _rkwargs = get_axes(**remaining_kwargs)
    if chart_args[STACKED]:
        stacked(axes, df, anno_args, **bar_args)
    else:
        grouped(axes, df, anno_args, **bar_args)

    # --- handle complete periodIndex data and label rotation
    if saved_pi is not None:
        set_labels(axes, saved_pi[1], chart_args[MAX_TICKS])
    else:
        # rotate the labels of *these* axes; plt.xticks() would act on the
        # pyplot "current" axes, which need not be the axes passed via ax=.
        axes.tick_params(axis="x", labelrotation=90)

    return axes
def annotate_bars(
    series: Series,
    offset: float,
    base: np.ndarray[tuple[int, ...], np.dtype[Any]],
    axes: Axes,
    **anno_kwargs,
) -> None:
    """Bar plot annotations.

    Write the value of each bar in ``series`` as a text label on ``axes``.

    Parameters
    - series: Series - the bar heights. The index (cast to int) gives the
      x position of each bar. It is assumed to be a contiguous run of
      integers, so that ``index - index.min()`` is a position into ``base``.
    - offset: float - added to each x position (grouped plots use it to
      centre each label over its sub-bar).
    - base: np.ndarray - the bottom of each bar (zeros for grouped plots,
      the running stack height for stacked plots). Must have the same
      length as ``series``.
    - axes: Axes - the axes to write the labels on.
    - **anno_kwargs - annotation options. "annotate", "fontsize",
      "fontname", "color", and "rotation" are expected; "rounding",
      "above" and "foreground" are optional.

    Nothing is drawn when "annotate" is missing or falsy, or when the series
    has more than 30 values. Missing (NaN) or infinite values are skipped.
    """

    # --- only annotate in limited circumstances
    if ANNOTATE not in anno_kwargs or not anno_kwargs[ANNOTATE]:
        return
    max_annotations = 30  # beyond this the labels become unreadable
    if series.empty or len(series) > max_annotations:
        return

    # --- internal logic check
    if len(base) != len(series):
        print(
            f"Warning: base array length {len(base)} does not match series length {len(series)}."
        )
        return

    # --- assemble the annotation parameters
    above: Final[bool | None] = anno_kwargs.get(ABOVE, False)  # None is also False-ish
    annotate_style = {
        FONTSIZE: anno_kwargs.get(FONTSIZE),
        FONTNAME: anno_kwargs.get(FONTNAME),
        COLOR: anno_kwargs.get(COLOR),
        ROTATION: anno_kwargs.get(ROTATION),
    }
    rounding = default_rounding(series=series, provided=anno_kwargs.get(ROUNDING, None))
    # small vertical nudge (2% of the data range) away from the bar edge
    adjustment = (series.max() - series.min()) * 0.02
    int_index = series.index.astype(int)  # mypy syntactic sugar
    # use the int index for the correction too, so base[...] gets an int position
    zero_correction = int_index.min()

    # --- annotate each bar
    for index, value in zip(int_index, series):
        if not np.isfinite(value):
            # matplotlib cannot place text at a NaN/inf position, and a
            # "nan" label carries no information: skip missing values.
            continue
        position = base[index - zero_correction] + (
            adjustment if value >= 0 else -adjustment
        )
        if above:
            position += value  # move from the bar's base to its far end
        text = axes.text(
            x=index + offset,
            y=position,
            s=f"{value:.{rounding}f}",
            ha="center",
            va="bottom" if value >= 0 else "top",
            **annotate_style,
        )
        if not above and "foreground" in anno_kwargs:
            # apply a stroke-effect to within bar annotations
            # to make them more readable with very small bars.
            text.set_path_effects(
                [pe.withStroke(linewidth=2, foreground=anno_kwargs.get("foreground"))]
            )
Bar plot annotations.
Note: "annotate", "fontsize", "fontname", "color", and "rotation" are expected in anno_kwargs.
def grouped(axes, df: DataFrame, anno_args, **kwargs) -> None:
    """
    Plot a grouped bar plot: for each row of ``df`` the columns are drawn
    as narrow bars side by side, centred on the row's x position.

    Parameters
    - axes: Axes - the axes to plot on.
    - df: DataFrame - one column per series. The index is used directly
      as the x positions.
    - anno_args: dict - annotation options forwarded to annotate_bars().
    - **kwargs - per-series bar options: COLOR, WIDTH and LABEL_SERIES,
      each a sequence indexed by column position.
    """

    series_count = len(df.columns)

    for i, col in enumerate(df.columns):
        series = df[col]
        if series.isnull().all():
            # nothing to draw; this column's slot in the group is left empty
            continue
        width = kwargs[WIDTH][i]  # total width of the whole group
        if width < 0 or width > 1:
            width = 0.8
        adjusted_width = width / series_count  # width of one bar in the group
        # far-left + margin + halfway through one grouped column
        left = -0.5 + ((1 - width) / 2.0) + (adjusted_width / 2.0)
        offset = left + (i * adjusted_width)
        foreground = kwargs[COLOR][i]
        axes.bar(
            x=series.index + offset,
            height=series,
            color=foreground,
            width=adjusted_width,
            # labels starting with "_" are left out of matplotlib legends
            label=col if kwargs[LABEL_SERIES][i] else f"_{col}_",
        )
        annotate_bars(
            series=series,
            offset=offset,
            base=np.zeros(len(series)),
            axes=axes,
            foreground=foreground,
            **anno_args,
        )
Plot a grouped bar plot.
def stacked(axes, df: DataFrame, anno_args, **kwargs) -> None:
    """
    Plot a stacked bar plot: for each row of ``df`` the columns are stacked
    on one bar. Positive values stack upward from zero and negative values
    stack downward from zero.

    Parameters
    - axes: Axes - the axes to plot on.
    - df: DataFrame - one column per series. The index is used directly
      as the x positions.
    - anno_args: dict - annotation options forwarded to annotate_bars().
    - **kwargs - per-series bar options: COLOR, WIDTH and LABEL_SERIES,
      each a sequence indexed by column position.
    """

    row_count = len(df)  # one stack per row
    # running tops of the positive and negative stacks, per row
    base_plus: np.ndarray[tuple[int, ...], np.dtype[np.float64]] = np.zeros(
        shape=row_count, dtype=np.float64
    )
    base_minus: np.ndarray[tuple[int, ...], np.dtype[np.float64]] = np.zeros(
        shape=row_count, dtype=np.float64
    )
    for i, col in enumerate(df.columns):
        series = df[col]
        # each value starts from the stack on its own side of zero
        base = np.where(series >= 0, base_plus, base_minus)
        foreground = kwargs[COLOR][i]
        axes.bar(
            x=series.index,
            height=series,
            bottom=base,
            color=foreground,
            width=kwargs[WIDTH][i],
            # labels starting with "_" are left out of matplotlib legends
            label=col if kwargs[LABEL_SERIES][i] else f"_{col}_",
        )
        annotate_bars(
            series=series,
            offset=0,
            base=base,
            axes=axes,
            foreground=foreground,
            **anno_args,
        )
        # NaN compares False both ways, so missing values add nothing
        base_plus += np.where(series >= 0, series, 0)
        base_minus += np.where(series < 0, series, 0)
Plot a stacked bar plot.
def bar_plot(
    data: DataT,
    **kwargs,
) -> Axes:
    """
    Create a bar plot from the given data.

    Each column of the data is plotted as a separate series of bars. By
    default the series are grouped (drawn side by side). With stacked=True
    they are stacked on top of each other, with positive values above zero
    and negative values below zero.

    Parameters
    - data: Series | DataFrame - The data to plot. A Series is treated as a
      single-column DataFrame.
    - **kwargs: dict Additional keyword arguments for customization.
    # --- options for the entire bar plot
    ax: Axes - axes to plot on, or None for new axes
    stacked: bool - if True, the bars will be stacked. If False (the default),
        they will be grouped.
    max_ticks: int - maximum number of ticks on the x-axis (for PeriodIndex only;
        default 10)
    plot_from: int | Period - if provided, the plot will start from this index.
    # --- options for each bar ...
    color: str | list[str] - the color of the bars (or separate colors for each series)
    label_series: bool | list[bool] - if True, the series will be labeled in the legend
        (default: True when there is more than one series)
    width: float | list[float] - the width of the bars
    # - options for bar annotations
    annotate: bool - if True, annotate the bars with their values (default False)
    fontsize: int | float | str - font size of the annotations (default "small")
    fontname: str - font name of the annotations (default "Helvetica")
    rounding: int - number of decimal places to round to
    annotate_color: str - color of annotations (default "black" when above,
        otherwise "white")
    rotation: int | float - rotation of annotations in degrees (default 0)
    above: bool - if True, annotations are above the bar, else within the bar

    Note: This function does not assume all data is timeseries with a PeriodIndex.

    Returns
    - axes: Axes - The axes for the plot.
    """

    # --- check the kwargs
    me = "bar_plot"
    report_kwargs(called_from=me, **kwargs)
    kwargs = validate_kwargs(BAR_KW_TYPES, me, **kwargs)

    # --- get the data
    # no call to check_clean_timeseries here, as bar plots are not
    # necessarily timeseries data. If the data is a Series, it will be
    # converted to a DataFrame with a single column.
    df = DataFrame(data)  # really we are only plotting DataFrames
    df, kwargs = constrain_data(df, **kwargs)
    item_count = len(df.columns)

    # --- deal with complete PeriodIndex indices
    if not is_categorical(df):
        print(
            "Warning: bar_plot is not designed for incomplete or non-categorical data indexes."
        )
    saved_pi = map_periodindex(df)
    if saved_pi is not None:
        df = saved_pi[0]  # extract the reindexed DataFrame from the PeriodIndex

    # --- set up the default arguments
    chart_defaults: dict[str, Any] = {
        STACKED: False,
        MAX_TICKS: 10,
    }
    chart_args = {k: kwargs.get(k, v) for k, v in chart_defaults.items()}

    # per-series defaults; apply_defaults() expands these to one value per column
    bar_defaults: dict[str, Any] = {
        COLOR: get_color_list(item_count),
        WIDTH: get_setting("bar_width"),
        LABEL_SERIES: (item_count > 1),
    }
    above = kwargs.get(ABOVE, False)
    anno_args = {
        ANNOTATE: kwargs.get(ANNOTATE, False),
        FONTSIZE: kwargs.get(FONTSIZE, "small"),
        FONTNAME: kwargs.get(FONTNAME, "Helvetica"),
        ROTATION: kwargs.get(ROTATION, 0),
        # NOTE(review): a default of True (not an int) is handed to
        # default_rounding(); presumably it means "choose automatically" - confirm.
        ROUNDING: kwargs.get(ROUNDING, True),
        # dark text outside the bar, light text inside it
        COLOR: kwargs.get(ANNOTATE_COLOR, "black" if above else "white"),
        ABOVE: above,
    }
    bar_args, remaining_kwargs = apply_defaults(item_count, bar_defaults, kwargs)

    # --- plot the data
    axes, _rkwargs = get_axes(**remaining_kwargs)
    if chart_args[STACKED]:
        stacked(axes, df, anno_args, **bar_args)
    else:
        grouped(axes, df, anno_args, **bar_args)

    # --- handle complete periodIndex data and label rotation
    if saved_pi is not None:
        set_labels(axes, saved_pi[1], chart_args[MAX_TICKS])
    else:
        # rotate the labels of *these* axes; plt.xticks() would act on the
        # pyplot "current" axes, which need not be the axes passed via ax=.
        axes.tick_params(axis="x", labelrotation=90)

    return axes
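Create a bar plot from the given data. Each column of the data is plotted as a separate series of bars. By default the series are grouped (drawn side by side). With stacked=True they are stacked on top of each other, with positive values above zero and negative values below zero.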
Parameters
- data: Series | DataFrame - The data to plot. A Series is treated as a single-column DataFrame.
- **kwargs: dict Additional keyword arguments for customization.
--- options for the entire bar plot
ax: Axes - axes to plot on, or None for new axes stacked: bool - if True, the bars will be stacked. If False (the default), they will be grouped. max_ticks: int - maximum number of ticks on the x-axis (for PeriodIndex only) plot_from: int | Period - if provided, the plot will start from this index.
--- options for each bar ...
color: str | list[str] - the color of the bars (or separate colors for each series) label_series: bool | list[bool] - if True, the series will be labeled in the legend width: float | list[float] - the width of the bars
- options for bar annotations
annotate: bool - if True, annotate the bars with their values. fontsize: int | float | str - font size of the annotations fontname: str - font name of the annotations rounding: int - number of decimal places to round to annotate_color: str - color of annotations rotation: int | float - rotation of annotations in degrees above: bool - if True, annotations are above the bar, else within the bar
Note: This function does not assume all data is timeseries with a PeriodIndex.
Returns
- axes: Axes - The axes for the plot.