feat(python): support creation of sparklines when exporting Excel tables (#7333)
alexander-beedie authored Mar 3, 2023
1 parent b542fee commit 5bf4f37
Showing 3 changed files with 219 additions and 23 deletions.
95 changes: 76 additions & 19 deletions py-polars/polars/internals/dataframe/frame.py
@@ -70,6 +70,7 @@
from polars.internals.dataframe.groupby import DynamicGroupBy, GroupBy, RollingGroupBy
from polars.internals.io_excel import (
_xl_column_range,
_xl_inject_sparklines,
_xl_setup_table_columns,
_xl_setup_table_options,
_xl_setup_workbook,
@@ -2401,11 +2402,12 @@ def write_excel(
position: tuple[int, int] | str = "A1",
table_style: str | dict[str, Any] | None = None,
table_name: str | None = None,
column_formats: dict[str, str] | None = None,
column_widths: dict[str, int] | None = None,
column_totals: dict[str, str] | Sequence[str] | bool | None = None,
column_formats: dict[str, str] | None = None,
conditional_formats: dict[str, str | dict[str, Any]] | None = None,
dtype_formats: dict[OneOrMoreDataTypes, str] | None = None,
sparklines: dict[str, Sequence[str] | dict[str, Any]] | None = None,
float_precision: int = 3,
has_header: bool = True,
autofilter: bool = True,
@@ -2414,7 +2416,7 @@
hide_gridlines: bool = False,
) -> Workbook:
"""
Write data to a table in an Excel workbook/worksheet.
Write frame data to a table in an Excel workbook/worksheet.
Parameters
----------
@@ -2428,34 +2430,46 @@
position
Table position in Excel notation (eg: "A1"), or a (row,col) integer tuple.
table_style
A named Excel table style, such as "Table Style Medium 4", or a
table style/option dictionary containing one or more of the following keys:
A named Excel table style, such as "Table Style Medium 4", or a dictionary
of {"option":bool,} containing one or more of the following keys:
"style", "first_column", "last_column", "banded_columns, "banded_rows".
table_name
Name of the output table object in the worksheet.
column_formats
A {"col":"fmt",} dict matching specific columns to a particular Excel format
string, such as "dd/mm/yyyy", "0.00%", "($#,##0_);[Red]($#,##0)", etc.
(Formats defined here will override those defined in ``dtype_formats``).
Name of the output table object in the worksheet; can be referred to in
the sheet by formulae/charts, or by subsequent xlsxwriter operations.
column_widths
A {"col":width,} dict that sets (or overrides if autofitting) column widths
in integer pixel units.
column_totals
Add a total row. If True, all numeric columns will have an associated total
using "sum". If a list of colnames, only those listed will have a "sum"
using "sum". If given a list of colnames, those listed will have a "sum"
total. For more control, pass a {"col":"fn",} dict. Valid functions include:
"average", "count_nums", "count", "max", "min", "std_dev", "sum", "var".
column_formats
A {"col":"fmt",} dict matching specific columns to a particular Excel format
string, such as "dd/mm/yyyy", "0.00%", "($#,##0_);[Red]($#,##0)", etc.
(Formats defined here will override those defined in ``dtype_formats``).
conditional_formats
A {"col":"typename",} or {"col":definition,} dict applying conditional
formats to specific columns. If supplying a typename, should be one of the
recognised xlsxwriter types such as "3_color_scale", "data_bar", etc. When
supplying the full definition you have complete flexibility to apply any
supported conditional format, including icon sets, formulae, etc.
A {"col":str,} or {"col":options,} dict that defines conditional formatting
for the specified columns. If supplying a string typename, should be one of
the recognised xlsxwriter types such as "3_color_scale", "data_bar", etc.
If supplying the full definition dictionary you have complete flexibility to
apply any supported conditional format, including icon sets, formulae, etc.
dtype_formats
A {dtype:"fmt",} dict that sets the default Excel format for the given
dtype. (This is overridden on a per-column basis by ``column_formats``). It
is also valid to use dtype groups such as ``polars.datatypes.FLOAT_DTYPES``
as the dtype/format key, to simplify setting uniform int/float formats.
sparklines
A {"col":colnames,} or {"col":params,} dict that defines one or more
sparklines to be written into a new column in the table. If passing a
list of colnames (used as the source of the sparkline data) the default
sparkline settings are used (eg: will be a line with no markers). For more
control an xlsxwriter-compliant parameter dictionary can be supplied; in
this case three additional polars-specific keys are available: "columns",
"insert_before", and "insert_after". These allow you to define the source
columns and position the sparkline(s) with respect to other table columns.
If no position directive is given, sparklines are added to the end of the
table in the order in which they are defined (eg: to the far right).
float_precision
Default number of decimals displayed for floating point columns (note that
this is purely a formatting directive; the actual values are not rounded).
@@ -2472,11 +2486,19 @@ def write_excel(
Notes
-----
All conditional formatting dictionaries should provide xlsxwriter-compatible
Conditional formatting parameter dicts should provide xlsxwriter-compatible
definitions; polars will take care of how/where they are applied on the
worksheet with respect to the column position. For more details, see:
worksheet with respect to the column position. For supported options, see:
https://xlsxwriter.readthedocs.io/working_with_conditional_formats.html
Similarly for sparklines, any parameter definition dictionary should contain
xlsxwriter-compatible key/values, as well as a mandatory polars "columns" key
that defines the sparkline source data; these source cols should be adjacent to
each other. Two other polars-specific keys are available to help define where
the sparkline appears in the table: "insert_after", and "insert_before". The
value associated with these keys should be the name of a column in the table.
https://xlsxwriter.readthedocs.io/working_with_sparklines.html
Examples
--------
>>> from random import uniform
@@ -2521,7 +2543,7 @@ def write_excel(
... position=(3, 1), # specify position as (row,col) coordinates
... conditional_formats={"num": "3_color_scale", "val": "data_bar"},
... table_style="Table Style Medium 4",
... ) # doctest: +IGNORE_RESULT
... )
...
... # advanced conditional formatting, custom styles
... df.write_excel(
@@ -2566,6 +2588,36 @@ def write_excel(
... ws.write(len(df) + 6, 1, "Customised conditional formatting", fmt_title)
...
Export a table containing two different types of sparklines. Use default
options for the "trend" sparkline and customised options (and positioning)
for the "+/-" win_loss sparkline, with non-default integer dtype formatting,
column totals, and hidden worksheet gridlines:
>>> df = pl.DataFrame(
... {
... "id": ["aaa", "bbb", "ccc", "ddd", "eee"],
... "q1": [100, 55, -20, 0, 35],
... "q2": [30, -10, 15, 60, 20],
... "q3": [-50, 0, 40, 80, 80],
... "q4": [75, 55, 25, -10, -55],
... }
... )
>>> from polars.datatypes import INTEGER_DTYPES
>>> df.write_excel( # doctest: +SKIP
... table_style="Table Style Light 2",
... dtype_formats={INTEGER_DTYPES: "#,##0_);(#,##0)"},
... sparklines={
... "trend": ["q1", "q2", "q3", "q4"],
... "+/-": {
... "columns": ["q1", "q2", "q3", "q4"],
... "insert_after": "id",
... "type": "win_loss",
... },
... },
... column_totals=["q1", "q2", "q3", "q4"],
... hide_gridlines=True,
... )
"""
try:
import xlsxwriter
@@ -2581,13 +2633,14 @@ def write_excel(

# setup table format/columns
table_style, table_options = _xl_setup_table_options(table_style)
table_columns = _xl_setup_table_columns(
table_columns, df = _xl_setup_table_columns(
df=df,
wb=wb,
column_formats=column_formats,
column_totals=column_totals,
dtype_formats=dtype_formats,
float_precision=float_precision,
sparklines=sparklines,
)

# normalise cell refs (eg: "B3" => (2,1)) and establish table start/finish,
@@ -2650,6 +2703,10 @@ def write_excel(
elif options:
ws.set_column(col_idx, col_idx, None, None, options)

# inject any sparklines into the table
for col, params in (sparklines or {}).items():
_xl_inject_sparklines(ws, df, table_start, col, has_header, params)

if can_close:
wb.close()
return wb
91 changes: 88 additions & 3 deletions py-polars/polars/internals/io_excel.py
@@ -8,6 +8,7 @@
Sequence,
)

import polars.internals as pli
from polars.datatypes import (
FLOAT_DTYPES,
INTEGER_DTYPES,
@@ -16,12 +17,12 @@
Datetime,
Time,
)
from polars.exceptions import DuplicateError

if TYPE_CHECKING:
from xlsxwriter import Workbook
from xlsxwriter.worksheet import Worksheet

import polars.internals as pli
from polars.datatypes import OneOrMoreDataTypes, PolarsDataType


@@ -36,6 +37,83 @@
_XL_DEFAULT_DTYPE_FORMATS_[tp] = _XL_DEFAULT_INTEGER_FORMAT_


def _xl_inject_dummy_table_columns(
df: pli.DataFrame, options: dict[str, Sequence[str] | dict[str, Any]]
) -> pli.DataFrame:
"""Insert dummy frame columns in order to create empty/named table columns."""
df_original_columns = set(df.columns)
df_select_cols = df.columns.copy()

for col, definition in options.items():
if col in df_original_columns:
raise DuplicateError(f"Cannot create a second {col!r} column")
elif not isinstance(definition, dict):
df_select_cols.append(col)
else:
insert_after = definition.get("insert_after")
insert_before = definition.get("insert_before")
if insert_after is None and insert_before is None:
df_select_cols.append(col)
else:
insert_idx = (
df_select_cols.index(insert_after) + 1 # type: ignore[arg-type]
if insert_before is None
else df_select_cols.index(insert_before)
)
df_select_cols.insert(insert_idx, col)

df = df.select(
[
(col if col in df_original_columns else pli.lit("").alias(col))
for col in df_select_cols
]
)
return df
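# --- editor's sketch (not part of the commit diff): a minimal illustration of the helper
# --- above, assuming the private module path below is importable in this version of
# --- polars; it reuses the column layout exercised by the unit test at the end of this diff.
import polars as pl
from polars.internals.io_excel import _xl_inject_dummy_table_columns

sketch_df = pl.DataFrame({"id": ["aaa", "bbb"], "q1": [100, 55], "q2": [30, -10]})
sketch_out = _xl_inject_dummy_table_columns(
    sketch_df,
    {
        "trend": ["q1", "q2"],  # no position directive -> appended at the end
        "+/-": {"columns": ["q1", "q2"], "insert_after": "id"},  # placed right after "id"
    },
)
# the injected placeholder columns are filled with empty strings
assert sketch_out.columns == ["id", "+/-", "q1", "q2", "trend"]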


def _xl_inject_sparklines(
ws: Worksheet,
df: pli.DataFrame,
table_start: tuple[int, int],
col: str,
has_header: bool,
params: Sequence[str] | dict[str, Any],
) -> None:
"""Inject sparklines into (previously-created) empty table columns."""
from xlsxwriter.utility import xl_rowcol_to_cell

data_cols = params.get("columns") if isinstance(params, dict) else params
if not data_cols:
raise ValueError("Supplying 'columns' is mandatory for sparklines")

data_idxs = sorted(df.find_idx_by_name(col) for col in data_cols)
if data_idxs != sorted(range(min(data_idxs), max(data_idxs) + 1)):
raise RuntimeError("sparkline data range/cols must be contiguous")

spk_row, spk_col, _, _ = _xl_column_range(df, table_start, col, has_header)
data_start_col = table_start[1] + data_idxs[0]
data_end_col = table_start[1] + data_idxs[-1]

if not isinstance(params, dict):
options = {}
else:
# strip polars-specific params before passing to xlsxwriter
options = {
name: val
for name, val in params.items()
if name not in ("columns", "insert_after", "insert_before")
}
if "negative_points" not in options:
options["negative_points"] = options.get("type") in ("column", "win_loss")

for _ in range(len(df)):
data_start = xl_rowcol_to_cell(spk_row, data_start_col)
data_end = xl_rowcol_to_cell(spk_row, data_end_col)
options["range"] = f"{data_start}:{data_end}"
ws.add_sparkline(spk_row, spk_col, options)
spk_row += 1
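# --- editor's note (not part of the commit diff): a worked trace of the loop above, using
# --- the "+/-" win_loss sparkline from the unit test at the end of this diff. With the
# --- table written at "A1", a header row, and final column order
# --- ["id", "+/-", "q1", "q2", "q3", "q4", "trend"]:
# ---   * data_idxs == [2, 3, 4, 5] (the contiguous "q1".."q4" source columns)
# ---   * the "+/-" sparklines start in cell B2 (the first data row, under the header)
# ---   * row by row, options["range"] becomes "C2:F2", then "C3:F3", and so on
# ---   * "negative_points" defaults to True because the sparkline type is "win_loss"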


def _xl_setup_workbook(
workbook: Workbook | BytesIO | Path | str | None, worksheet: str | None = None
) -> tuple[Workbook, Worksheet, bool]:
@@ -74,8 +152,9 @@ def _xl_setup_table_columns(
column_formats: dict[str, str] | None = None,
column_totals: dict[str, str] | Sequence[str] | bool | None = None,
dtype_formats: dict[OneOrMoreDataTypes, str] | None = None,
sparklines: dict[str, Sequence[str] | dict[str, Any]] | None = None,
float_precision: int = 3,
) -> list[dict[str, Any]]:
) -> tuple[list[dict[str, Any]], pli.DataFrame]:
"""Setup and unify all column-related formatting/defaults."""
total_funcs = (
{col: "sum" for col in column_totals}
@@ -88,6 +167,10 @@
if isinstance(tp, (tuple, frozenset)):
dtype_formats.update(dict.fromkeys(tp, dtype_formats.pop(tp)))

# inject sparkline placeholder(s)
if sparklines:
df = _xl_inject_dummy_table_columns(df, sparklines)

# default float format
zeros = "0" * float_precision
fmt_float = (
@@ -125,14 +208,16 @@
fmt["num_format"] = dtype_formats[tp]
column_formats[col] = wb.add_format(fmt)

return [
# assemble table columns
table_columns = [
{
"header": col,
"format": column_formats.get(col),
"total_function": total_funcs.get(col),
}
for col in df.columns
]
return table_columns, df
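# --- editor's note (not part of the commit diff): each `table_columns` entry is an
# --- xlsxwriter-style column spec, e.g.
# ---   {"header": "q1", "format": <Format>, "total_function": "sum"},
# --- of the shape accepted by Worksheet.add_table's "columns" option; the (possibly
# --- sparkline-augmented) frame is now returned alongside so that write_excel lays out
# --- cells against the final column order.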


def _xl_setup_table_options(
56 changes: 55 additions & 1 deletion py-polars/tests/unit/io/test_excel.py
@@ -8,7 +8,7 @@
import pytest

import polars as pl
from polars.datatypes import FLOAT_DTYPES
from polars.datatypes import FLOAT_DTYPES, INTEGER_DTYPES
from polars.testing import assert_frame_equal


@@ -120,3 +120,57 @@ def test_excel_round_trip(write_params: dict[str, Any]) -> None:
xldf = xldf[:3]

assert_frame_equal(df, xldf)


def test_excel_sparklines() -> None:
from xlsxwriter import Workbook

# note that we don't (quite) expect sparkline export to round-trip
# as we have to inject additional empty columns to hold them...
df = pl.DataFrame(
{
"id": ["aaa", "bbb", "ccc", "ddd", "eee"],
"q1": [100, 55, -20, 0, 35],
"q2": [30, -10, 15, 60, 20],
"q3": [-50, 0, 40, 80, 80],
"q4": [75, 55, 25, -10, -55],
}
)

# also: confirm that we can use a Workbook directly with "write_excel"
xls = BytesIO()
with Workbook(xls) as wb:
df.write_excel(
workbook=wb,
worksheet="frame_data",
table_style="Table Style Light 2",
dtype_formats={INTEGER_DTYPES: "#,##0_);(#,##0)"},
sparklines={
"trend": ["q1", "q2", "q3", "q4"],
"+/-": {
"columns": ["q1", "q2", "q3", "q4"],
"insert_after": "id",
"type": "win_loss",
},
},
hide_gridlines=True,
)

xldf = pl.read_excel(file=xls, sheet_name="frame_data") # type: ignore[call-overload]
# ┌──────┬──────┬─────┬─────┬─────┬─────┬───────┐
# │ id ┆ +/- ┆ q1 ┆ q2 ┆ q3 ┆ q4 ┆ trend │
# │ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- │
# │ str ┆ str ┆ i64 ┆ i64 ┆ i64 ┆ i64 ┆ str │
# ╞══════╪══════╪═════╪═════╪═════╪═════╪═══════╡
# │ aaa ┆ null ┆ 100 ┆ 30 ┆ -50 ┆ 75 ┆ null │
# │ bbb ┆ null ┆ 55 ┆ -10 ┆ 0 ┆ 55 ┆ null │
# │ ccc ┆ null ┆ -20 ┆ 15 ┆ 40 ┆ 25 ┆ null │
# │ ddd ┆ null ┆ 0 ┆ 60 ┆ 80 ┆ -10 ┆ null │
# │ eee ┆ null ┆ 35 ┆ 20 ┆ 80 ┆ -55 ┆ null │
# └──────┴──────┴─────┴─────┴─────┴─────┴───────┘

for sparkline_col in ("+/-", "trend"):
assert set(xldf[sparkline_col]) == {None}

assert xldf.columns == ["id", "+/-", "q1", "q2", "q3", "q4", "trend"]
assert_frame_equal(df, xldf.drop("+/-", "trend"))
