diff --git a/py-polars/polars/internals/dataframe/frame.py b/py-polars/polars/internals/dataframe/frame.py index 24da2819c633..3001d5562679 100644 --- a/py-polars/polars/internals/dataframe/frame.py +++ b/py-polars/polars/internals/dataframe/frame.py @@ -70,6 +70,7 @@ from polars.internals.dataframe.groupby import DynamicGroupBy, GroupBy, RollingGroupBy from polars.internals.io_excel import ( _xl_column_range, + _xl_inject_sparklines, _xl_setup_table_columns, _xl_setup_table_options, _xl_setup_workbook, @@ -2401,11 +2402,12 @@ def write_excel( position: tuple[int, int] | str = "A1", table_style: str | dict[str, Any] | None = None, table_name: str | None = None, - column_formats: dict[str, str] | None = None, column_widths: dict[str, int] | None = None, column_totals: dict[str, str] | Sequence[str] | bool | None = None, + column_formats: dict[str, str] | None = None, conditional_formats: dict[str, str | dict[str, Any]] | None = None, dtype_formats: dict[OneOrMoreDataTypes, str] | None = None, + sparklines: dict[str, Sequence[str] | dict[str, Any]] | None = None, float_precision: int = 3, has_header: bool = True, autofilter: bool = True, @@ -2414,7 +2416,7 @@ def write_excel( hide_gridlines: bool = False, ) -> Workbook: """ - Write data to a table in an Excel workbook/worksheet. + Write frame data to a table in an Excel workbook/worksheet. Parameters ---------- @@ -2428,34 +2430,46 @@ def write_excel( position Table position in Excel notation (eg: "A1"), or a (row,col) integer tuple. table_style - A named Excel table style, such as "Table Style Medium 4", or a - table style/option dictionary containing one or more of the following keys: + A named Excel table style, such as "Table Style Medium 4", or a dictionary + of {"option":bool,} containing one or more of the following keys: "style", "first_column", "last_column", "banded_columns, "banded_rows". table_name - Name of the output table object in the worksheet. 
- column_formats - A {"col":"fmt",} dict matching specific columns to a particular Excel format - string, such as "dd/mm/yyyy", "0.00%", "($#,##0_);[Red]($#,##0)", etc. - (Formats defined here will override those defined in ``dtype_formats``). + Name of the output table object in the worksheet; can be referred to in + the sheet by formulae/charts, or by subsequent xlsxwriter operations. column_widths A {"col":width,} dict that sets (or overrides if autofitting) column widths in integer pixel units. column_totals Add a total row. If True, all numeric columns will have an associated total - using "sum". If a list of colnames, only those listed will have a "sum" + using "sum". If given a list of colnames, those listed will have a "sum" total. For more control, pass a {"col":"fn",} dict. Valid functions include: "average", "count_nums", "count", "max", "min", "std_dev", "sum", "var". + column_formats + A {"col":"fmt",} dict matching specific columns to a particular Excel format + string, such as "dd/mm/yyyy", "0.00%", "($#,##0_);[Red]($#,##0)", etc. + (Formats defined here will override those defined in ``dtype_formats``). conditional_formats - A {"col":"typename",} or {"col":definition,} dict applying conditional - formats to specific columns. If supplying a typename, should be one of the - recognised xlsxwriter types such as "3_color_scale", "data_bar", etc. When - supplying the full definition you have complete flexibility to apply any - supported conditional format, including icon sets, formulae, etc. + A {"col":str,} or {"col":options,} dict that defines conditional formatting + for the specified columns. If supplying a string typename, should be one of + the recognised xlsxwriter types such as "3_color_scale", "data_bar", etc. + If supplying the full definition dictionary you have complete flexibility to + apply any supported conditional format, including icon sets, formulae, etc. 
dtype_formats A {dtype:"fmt",} dict that sets the default Excel format for the given dtype. (This is overridden on a per-column basis by ``column_formats``). It is also valid to use dtype groups such as ``polars.datatypes.FLOAT_DTYPES`` as the dtype/format key, to simplify setting uniform int/float formats. + sparklines + A {"col":colnames,} or {"col":params,} dict that defines one or more + sparklines to be written into a new column in the table. If passing a + list of colnames (used as the source of the sparkline data) the default + sparkline settings are used (eg: will be a line with no markers). For more + control an xlsxwriter-compliant parameter dictionary can be supplied; in + this case three additional polars-specific keys are available: "columns", + "insert_before", and "insert_after". These allow you to define the source + columns and position the sparkline(s) with respect to other table columns. + If no position directive is given, sparklines are added to the end of the + table in the order in which they are defined (eg: to the far right). float_precision Default number of decimals displayed for floating point columns (note that this is purely a formatting directive; the actual values are not rounded). @@ -2472,11 +2486,19 @@ def write_excel( Notes ----- - All conditional formatting dictionaries should provide xlsxwriter-compatible + Conditional formatting parameter dicts should provide xlsxwriter-compatible definitions; polars will take care of how/where they are applied on the - worksheet with respect to the column position. For more details, see: + worksheet with respect to the column position. For supported options, see: https://xlsxwriter.readthedocs.io/working_with_conditional_formats.html + Similarly for sparklines, any parameter definition dictionary should contain + xlsxwriter-compatible key/values, as well as a mandatory polars "columns" key + that defines the sparkline source data; these source cols should be adjacent to + each other. 
Two other polars-specific keys are available to help define where + the sparkline appears in the table: "insert_after", and "insert_before". The + value associated with these keys should be the name of a column in the table. + https://xlsxwriter.readthedocs.io/working_with_sparklines.html + Examples -------- >>> from random import uniform @@ -2521,7 +2543,7 @@ def write_excel( ... position=(3, 1), # specify position as (row,col) coordinates ... conditional_formats={"num": "3_color_scale", "val": "data_bar"}, ... table_style="Table Style Medium 4", - ... ) # doctest: +IGNORE_RESULT + ... ) ... ... # advanced conditional formatting, custom styles ... df.write_excel( @@ -2566,6 +2588,36 @@ def write_excel( ... ws.write(len(df) + 6, 1, "Customised conditional formatting", fmt_title) ... + Export a table containing two different types of sparklines. Use default + options for the "trend" sparkline and customised options (and positioning) + for the "+/-" win_loss sparkline, with non-default integer dtype formatting, + column totals, and hidden worksheet gridlines: + + >>> df = pl.DataFrame( + ... { + ... "id": ["aaa", "bbb", "ccc", "ddd", "eee"], + ... "q1": [100, 55, -20, 0, 35], + ... "q2": [30, -10, 15, 60, 20], + ... "q3": [-50, 0, 40, 80, 80], + ... "q4": [75, 55, 25, -10, -55], + ... } + ... ) + >>> from polars.datatypes import INTEGER_DTYPES + >>> df.write_excel( # doctest: +SKIP + ... table_style="Table Style Light 2", + ... dtype_formats={INTEGER_DTYPES: "#,##0_);(#,##0)"}, + ... sparklines={ + ... "trend": ["q1", "q2", "q3", "q4"], + ... "+/-": { + ... "columns": ["q1", "q2", "q3", "q4"], + ... "insert_after": "id", + ... "type": "win_loss", + ... }, + ... }, + ... column_totals=["q1", "q2", "q3", "q4"], + ... hide_gridlines=True, + ... 
def _xl_inject_dummy_table_columns(
    df: pli.DataFrame, options: dict[str, Sequence[str] | dict[str, Any]]
) -> pli.DataFrame:
    """
    Insert dummy frame columns in order to create empty/named table columns.

    Parameters
    ----------
    df
        Frame to which the placeholder columns are added.
    options
        Mapping of new (placeholder) column name to its definition. If the
        definition is a dict it may carry "insert_after" / "insert_before" keys
        naming a column to position against ("insert_before" takes precedence
        if both are present). Any other definition, or a dict with neither key,
        results in the new column being appended to the end.

    Returns
    -------
    DataFrame
        A selection of the original columns plus one empty-string literal
        column per key in ``options``, in the resolved order.

    Raises
    ------
    DuplicateError
        If a placeholder name matches one of the original frame columns.
    ValueError
        If a position directive names a column that is not (yet) present.

    """
    df_original_columns = set(df.columns)
    df_select_cols = list(df.columns)

    for col, definition in options.items():
        if col in df_original_columns:
            raise DuplicateError(f"Cannot create a second {col!r} column")

        insert_after = insert_before = None
        if isinstance(definition, dict):
            insert_after = definition.get("insert_after")
            insert_before = definition.get("insert_before")

        if insert_after is None and insert_before is None:
            # no position directive: add to the end, in definition order
            df_select_cols.append(col)
            continue

        # "insert_before" wins when both directives are supplied
        directive = "insert_after" if insert_before is None else "insert_before"
        anchor = insert_after if insert_before is None else insert_before
        try:
            anchor_idx = df_select_cols.index(anchor)
        except ValueError:
            # replace the opaque "x is not in list" with an actionable message
            raise ValueError(
                f"Cannot position {col!r}: {directive!r} column {anchor!r} not found"
            ) from None

        insert_idx = anchor_idx + 1 if insert_before is None else anchor_idx
        df_select_cols.insert(insert_idx, col)

    return df.select(
        [
            (name if name in df_original_columns else pli.lit("").alias(name))
            for name in df_select_cols
        ]
    )


def _xl_inject_sparklines(
    ws: Worksheet,
    df: pli.DataFrame,
    table_start: tuple[int, int],
    col: str,
    has_header: bool,
    params: Sequence[str] | dict[str, Any],
) -> None:
    """
    Inject sparklines into (previously-created) empty table columns.

    Parameters
    ----------
    ws
        Worksheet that receives the sparklines (via ``add_sparkline``).
    df
        Frame used to resolve column positions and the number of rows.
    table_start
        Table origin; index 1 is used as the column offset for the data range.
    col
        Name of the column that will hold the sparklines.
    has_header
        Forwarded to ``_xl_column_range`` when locating the sparkline column.
    params
        Either the source column names, or a parameter dict that must contain
        a "columns" key; the "columns", "insert_after" and "insert_before" keys
        are stripped and the remaining keys are passed to ``add_sparkline``.

    Raises
    ------
    ValueError
        If no source columns are supplied.
    RuntimeError
        If the source columns are not adjacent to each other in ``df``.

    """
    from xlsxwriter.utility import xl_rowcol_to_cell

    data_cols = params.get("columns") if isinstance(params, dict) else params
    if not data_cols:
        raise ValueError("Supplying 'columns' is mandatory for sparklines")
    if isinstance(data_cols, str):
        # a bare string would otherwise be iterated character-by-character
        data_cols = [data_cols]

    # each row's sparkline reads from a single cell range, so the source
    # columns have to form one unbroken run of column indexes
    data_idxs = sorted(df.find_idx_by_name(name) for name in data_cols)
    if data_idxs != list(range(data_idxs[0], data_idxs[-1] + 1)):
        raise RuntimeError("sparkline data range/cols must be contiguous")

    # the first two values returned are used as the starting (row, col)
    # of the cells that receive the sparklines
    spk_row, spk_col, _, _ = _xl_column_range(df, table_start, col, has_header)
    data_start_col = table_start[1] + data_idxs[0]
    data_end_col = table_start[1] + data_idxs[-1]

    options: dict[str, Any] = {}
    if isinstance(params, dict):
        # strip polars-specific params before passing to xlsxwriter
        options = {
            name: val
            for name, val in params.items()
            if name not in ("columns", "insert_after", "insert_before")
        }
        if "negative_points" not in options:
            options["negative_points"] = options.get("type") in ("column", "win_loss")

    # one sparkline per frame row, each with its own per-row options dict
    for row in range(spk_row, spk_row + len(df)):
        data_start = xl_rowcol_to_cell(row, data_start_col)
        data_end = xl_rowcol_to_cell(row, data_end_col)
        ws.add_sparkline(
            row, spk_col, {**options, "range": f"{data_start}:{data_end}"}
        )
def test_excel_sparklines() -> None:
    """Write a frame with two sparkline columns and check the table read back."""
    from xlsxwriter import Workbook

    # sparkline export is not (quite) expected to round-trip, because extra
    # empty columns have to be injected into the table to hold the sparklines
    df = pl.DataFrame(
        {
            "id": ["aaa", "bbb", "ccc", "ddd", "eee"],
            "q1": [100, 55, -20, 0, 35],
            "q2": [30, -10, 15, 60, 20],
            "q3": [-50, 0, 40, 80, 80],
            "q4": [75, 55, 25, -10, -55],
        }
    )
    sparkline_defs = {
        "trend": ["q1", "q2", "q3", "q4"],
        "+/-": {
            "columns": ["q1", "q2", "q3", "q4"],
            "insert_after": "id",
            "type": "win_loss",
        },
    }

    # also: confirm that a Workbook can be passed directly to "write_excel"
    xls = BytesIO()
    with Workbook(xls) as wb:
        df.write_excel(
            workbook=wb,
            worksheet="frame_data",
            table_style="Table Style Light 2",
            dtype_formats={INTEGER_DTYPES: "#,##0_);(#,##0)"},
            sparklines=sparkline_defs,
            hide_gridlines=True,
        )

    xldf = pl.read_excel(file=xls, sheet_name="frame_data")  # type: ignore[call-overload]
    # ┌──────┬──────┬─────┬─────┬─────┬─────┬───────┐
    # │ id   ┆ +/-  ┆ q1  ┆ q2  ┆ q3  ┆ q4  ┆ trend │
    # │ ---  ┆ ---  ┆ --- ┆ --- ┆ --- ┆ --- ┆ ---   │
    # │ str  ┆ str  ┆ i64 ┆ i64 ┆ i64 ┆ i64 ┆ str   │
    # ╞══════╪══════╪═════╪═════╪═════╪═════╪═══════╡
    # │ aaa  ┆ null ┆ 100 ┆ 30  ┆ -50 ┆ 75  ┆ null  │
    # │ bbb  ┆ null ┆ 55  ┆ -10 ┆ 0   ┆ 55  ┆ null  │
    # │ ccc  ┆ null ┆ -20 ┆ 15  ┆ 40  ┆ 25  ┆ null  │
    # │ ddd  ┆ null ┆ 0   ┆ 60  ┆ 80  ┆ -10 ┆ null  │
    # │ eee  ┆ null ┆ 35  ┆ 20  ┆ 80  ┆ -55 ┆ null  │
    # └──────┴──────┴─────┴─────┴─────┴─────┴───────┘

    # the placeholder columns hold only the (unreadable) sparklines
    for placeholder in ("+/-", "trend"):
        values = set(xldf[placeholder])
        assert values == {None}

    expected_columns = ["id", "+/-", "q1", "q2", "q3", "q4", "trend"]
    assert xldf.columns == expected_columns
    assert_frame_equal(df, xldf.drop("+/-", "trend"))