Skip to content

Commit

Permalink
Adding pre-commit (for linting) to test dependencies. (#671)
Browse files Browse the repository at this point in the history
* Adding pre-commit (for linting) to test dependencies.

* Running pre-commit in CI via pre-built action.

* Change name of pre-commit hook. Run pre-commit.

* Fixing formatting.
  • Loading branch information
alxmrs authored Feb 11, 2025
1 parent 4f8ed38 commit afd519a
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 4 deletions.
20 changes: 20 additions & 0 deletions .github/workflows/pre-commit.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
name: pre-commit

on:
push:
branches: [main]
pull_request:
branches: [main]

jobs:
checks:
name: pre-commit hooks
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4.1.5

- uses: actions/setup-python@v5
with:
python-version: '3.10'

- uses: pre-commit/action@v3.0.1
26 changes: 22 additions & 4 deletions cubed/array_api/statistical_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def mean(x, /, *, axis=None, keepdims=False, split_every=None):
# pair of fields needed to keep per-chunk counts and totals for computing
# the mean.
dtype = x.dtype
#TODO(#658): Should these be default dtypes?
# TODO(#658): Should these be default dtypes?
intermediate_dtype = [("n", nxp.int64), ("total", nxp.float64)]
extra_func_kwargs = dict(dtype=intermediate_dtype)
return reduction(
Expand Down Expand Up @@ -110,7 +110,16 @@ def min(x, /, *, axis=None, keepdims=False, split_every=None):


def prod(x, /, *, axis=None, dtype=None, keepdims=False, split_every=None, device=None):
dtype = _upcast_integral_dtypes(x, dtype, allowed_dtypes=("numeric", "boolean",), fname="prod", device=device)
dtype = _upcast_integral_dtypes(
x,
dtype,
allowed_dtypes=(
"numeric",
"boolean",
),
fname="prod",
device=device,
)
extra_func_kwargs = dict(dtype=dtype)
return reduction(
x,
Expand All @@ -136,7 +145,16 @@ def std(x, /, *, axis=None, correction=0.0, keepdims=False, split_every=None):


def sum(x, /, *, axis=None, dtype=None, keepdims=False, split_every=None, device=None):
dtype = _upcast_integral_dtypes(x, dtype, allowed_dtypes=("numeric", "boolean",), fname="sum", device=device)
dtype = _upcast_integral_dtypes(
x,
dtype,
allowed_dtypes=(
"numeric",
"boolean",
),
fname="sum",
device=device,
)
extra_func_kwargs = dict(dtype=dtype)
return reduction(
x,
Expand All @@ -163,7 +181,7 @@ def var(
if x.dtype not in _real_floating_dtypes:
raise TypeError("Only real floating-point dtypes are allowed in var")
dtype = x.dtype
#TODO(#658): Should these be default dtypes?
# TODO(#658): Should these be default dtypes?
intermediate_dtype = [("n", nxp.int64), ("mu", nxp.float64), ("M2", nxp.float64)]
extra_func_kwargs = dict(dtype=intermediate_dtype, correction=correction)
return reduction(
Expand Down

0 comments on commit afd519a

Please sign in to comment.