ComputeError: cannot aggregate a literal for literals in .over() #16343

Open
2 tasks done
kevinli1993 opened this issue May 20, 2024 · 11 comments
Labels
invalid A bug report that is not actually a bug python Related to Python Polars

Comments

@kevinli1993

Checks

  • I have checked that this issue has not already been reported.
  • I have confirmed this bug exists on the latest version of Polars.

Reproducible example

import polars as pl
ds = pl.DataFrame(dict(A=[1,1,1,2,2,3]))
ds.with_columns(pl.lit(1).alias("B").sum().over(pl.col("A")))   # This will trigger a compute error

ds.with_columns(pl.repeat(1, pl.len()).alias("B").sum().over(pl.col("A"))) # this works

The error is

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File ".../.venv/lib/python3.12/site-packages/polars/dataframe/frame.py", line 8310, in with_columns
    return self.lazy().with_columns(*exprs, **named_exprs).collect(_eager=True)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".venv/lib/python3.12/site-packages/polars/lazyframe/frame.py", line 1816, in collect
    return wrap_df(ldf.collect(callback))
                   ^^^^^^^^^^^^^^^^^^^^^
polars.exceptions.ComputeError: cannot aggregate a literal

Log output

No response

Issue description

This is similar to this bug.

It seems sensible to allow literals in an .over() context, so that user-defined functions can accept either pl.col("...") or pl.lit(1).

Expected behavior

The behavior should be the same as using

ds.with_columns(pl.repeat(1, pl.len()).alias("B").sum().over(pl.col("A"))) # this works

Installed versions

--------Version info---------
Polars:               0.20.26
Index type:           UInt32
Platform:             macOS-14.4.1-arm64-arm-64bit
Python:               3.12.3 (main, Apr 12 2024, 17:16:04) [Clang 15.0.0 (clang-1500.1.0.2.5)]

----Optional dependencies----
adbc_driver_manager:  <not installed>
cloudpickle:          <not installed>
connectorx:           <not installed>
deltalake:            <not installed>
fastexcel:            <not installed>
fsspec:               <not installed>
gevent:               <not installed>
hvplot:               <not installed>
matplotlib:           3.8.4
nest_asyncio:         <not installed>
numpy:                1.26.4
openpyxl:             <not installed>
pandas:               2.2.2
pyarrow:              16.1.0
pydantic:             <not installed>
pyiceberg:            <not installed>
pyxlsb:               <not installed>
sqlalchemy:           <not installed>
torch:                <not installed>
xlsx2csv:             <not installed>
xlsxwriter:           <not installed>

@kevinli1993 kevinli1993 added bug Something isn't working needs triage Awaiting prioritization by a maintainer python Related to Python Polars labels May 20, 2024
@Julian-J-S
Contributor

Julian-J-S commented May 21, 2024

Do you have a "real-world-example" where this is useful?

Seems like you want to know the number of rows per group?

No computer access currently but this might work and is very easy:
pl.len().over("A")

@ritchie46
Member

pl.len().over("A")

This works indeed.

We do allow literals in over. What we don't allow is summing a literal during a group-by, which is a good decision, since you already expected this to work differently than it would: pl.lit(1).sum() would return 1, which is not very useful.

@ritchie46 ritchie46 added invalid A bug report that is not actually a bug and removed bug Something isn't working needs triage Awaiting prioritization by a maintainer labels May 21, 2024
@kevinli1993
Author

Here is a non-contrived use-case I have. I want to define this function

def grouped_weighted_mean(grouping, x, weight):
    return ((x * weight).sum() / weight.sum()).over(grouping)

So the user may do:

(
    ds
    .with_columns(
        grouped_weighted_mean(pl.col("G"), pl.col("x"), pl.lit(1.0)).alias("XW0"),
        grouped_weighted_mean(pl.col("G"), pl.col("x"), pl.col("weight_one")).alias("XW1"),
        grouped_weighted_mean(pl.col("G"), pl.col("x"), pl.col("weight_one").pow(2)).alias("XW2")
    )
)

Here, the first line does not work due to pl.lit(1.0) being a literal.

@ritchie46, I understand now that pl.lit(1).sum() won't work as expected, but I'm curious whether you have any suggestions for the above use-case. The idea is to keep the function interface consistent, so that the user can pass either an existing column/expression or a literal for a constant weight.

Thanks!

@cmdlineluser
Contributor

I thought something in .meta could help:

>>> pl.col("x").meta.is_column()
True

But I guess it just checks for {"Column":...}

>>> (pl.col("x") + 1).meta.is_column()
False

I'm not sure if there is a way to test if something consists only of literals?

A single is_literal would be easy:

>>> pl.lit(1.0).meta.serialize()
'{"Literal":{"Float":1.0}}'

But perhaps not:

>>> (pl.lit(1.0) + pl.lit(2.0)).meta.serialize()
'{"BinaryExpr":{"left":{"Literal":{"Float":1.0}},"op":"Plus","right":{"Literal":{"Float":2.0}}}}'

@cjackal
Contributor

cjackal commented May 21, 2024

So the user may do:

(
    ds
    .with_columns(
        grouped_weighted_mean(pl.col("G"), pl.col("x"), pl.lit(1.0)).alias("XW0"),
        grouped_weighted_mean(pl.col("G"), pl.col("x"), pl.col("weight_one")).alias("XW1"),
        grouped_weighted_mean(pl.col("G"), pl.col("x"), pl.col("weight_one").pow(2)).alias("XW2")
    )
)

Can the user just set pl.col("weight_one").pow(0) instead of pl.lit(1.0), if they are just interested in grouped exponentially weighted average?

@kevinli1993
Author

kevinli1993 commented May 21, 2024

Can the user just set pl.col("weight_one").pow(0) instead of pl.lit(1.0), if they are just interested in grouped exponentially weighted average?

There are many workarounds (.pow(0) is one, using pl.col("const") where const is all 1.0 is another), but this is missing the point -- the question is whether we could use a literal in situations like this in general.

@cjackal
Contributor

cjackal commented May 21, 2024

There are many workarounds (.pow(0) is one, using pl.col("const") where const is all 1.0 is another), but this is missing the point -- the question is whether we could use a literal in situations like this in general.

Sounds quite like a riddle. How about applying a tautological identity on existing columns, like weight + pl.col("A") - pl.col("A")? I mean, working around a literal is generally harder than forcing it to be non-literal (as is implicit in @cmdlineluser's comment).

@kevinli1993
Author

Yep! For now I'm just passing the burden/choice onto the user; if they want to use a constant weight, they will either need to do pl.col("W") * 0 + 1, or use pl.col("const") (assuming they have a constant column available in their dataframe).

Neither is as aesthetically pleasing as pl.lit(1.0), imo :-)

@cmdlineluser
Contributor

I just noticed an interesting PR for meta.is_column_selection() - #16479

Perhaps a meta.is_literal() isn't as difficult as I initially thought?

Can we say that an expression is a Literal if none of the Column-type nodes appear anywhere in it?

demo.py
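import json

def is_literal(expr):
    # Node types that refer to columns; if none appear, treat expr as literal-only.
    non_literal = {
        "Column", "Columns", "DtypeColumn", "Exclude",
        "IndexColumn", "Nth", "Selector", "Wildcard"
    }

    found = False

    def _is_literal_impl(obj):
        # json.loads calls this hook once for every decoded JSON object.
        nonlocal found
        if not found:
            found = bool(obj.keys() & non_literal)
        return obj

    # Assumes serialize() returns a JSON string, as in the outputs shown earlier;
    # newer Polars releases may need serialize(format="json") here.
    json.loads(expr.meta.serialize(), object_hook=_is_literal_impl)

    return not found
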
>>> is_literal(pl.lit(1.0) + pl.lit(2) + pl.col("foo"))
False
>>> is_literal(pl.lit(1.0) + pl.lit(2))
True

Not sure if there are any other possible cases.
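
To try the idea without depending on how a given Polars version serializes expressions, the same check can be run directly on JSON strings. This is only a sketch: the function name is_literal_json is made up for illustration, the node names are copied from the demo above, and the {"Column": "foo"} payload is a guess at the shape of a serialized column, not actual Polars output. The two literal-only strings are the ones quoted earlier in the thread.

```python
import json

# Node names that refer to columns, copied from the demo above.
COLUMN_NODES = {
    "Column", "Columns", "DtypeColumn", "Exclude",
    "IndexColumn", "Nth", "Selector", "Wildcard",
}


def is_literal_json(serialized: str) -> bool:
    """Return True if no column-type node appears anywhere in the JSON tree."""

    def has_column_node(node) -> bool:
        if isinstance(node, dict):
            # Check this object's keys, then recurse into its values.
            return bool(node.keys() & COLUMN_NODES) or any(
                has_column_node(value) for value in node.values()
            )
        if isinstance(node, list):
            return any(has_column_node(item) for item in node)
        return False  # Scalars (numbers, strings, null) carry no node names.

    return not has_column_node(json.loads(serialized))


single = '{"Literal":{"Float":1.0}}'
binary = '{"BinaryExpr":{"left":{"Literal":{"Float":1.0}},"op":"Plus","right":{"Literal":{"Float":2.0}}}}'
# Hypothetical payload: the exact shape of a serialized column is assumed here.
with_column = '{"BinaryExpr":{"left":{"Column":"foo"},"op":"Plus","right":{"Literal":{"Float":2.0}}}}'

print(is_literal_json(single), is_literal_json(binary), is_literal_json(with_column))
```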

@ritchie46
Member

Here, the first line does not work due to pl.lit(1.0) being a literal.

The caller shouldn't pass a literal, but something of the groups length: pl.ones(pl.len()).

@cmdlineluser
Contributor

cmdlineluser commented Nov 14, 2024
