Skip to content

Commit

Permalink
🏷️ types(lax): fix pyi type imports
Browse files Browse the repository at this point in the history
Signed-off-by: Nathaniel Starkman <nstarman@users.noreply.github.com>
  • Loading branch information
nstarman committed Jan 20, 2025
1 parent 1c9e9d5 commit b3f2e7c
Show file tree
Hide file tree
Showing 8 changed files with 722 additions and 671 deletions.
7 changes: 6 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,10 @@
"jax>=0.4.3",
"jaxlib>=0.4.3",
"jaxtyping>=0.2.34",
"optype>=0.8.0",
"plum-dispatch>=2.5.2",
"quax>=0.0.5",
]
]

[project.urls]
Homepage = "https://github.com/GalacticDynamics/quaxed"
Expand Down Expand Up @@ -170,6 +171,10 @@
"TD003", # Missing issue link on the line following this TODO
]

[tool.ruff.lint.isort]
combine-as-imports = true
extra-standard-library = ["typing_extensions"]

[tool.ruff.lint.per-file-ignores]
"src/quaxed/**" = ["A004"]
"tests/**" = ["ANN", "INP001", "PLR0913", "PLR2004", "S101", "T20", "TID252"]
Expand Down
346 changes: 185 additions & 161 deletions src/quaxed/lax/__init__.pyi
Original file line number Diff line number Diff line change
@@ -1,178 +1,202 @@
# ----- Operators -----
# isort: split
from jax.lax.linalg import abs as abs
from jax.lax.linalg import acos as acos
from jax.lax.linalg import acosh as acosh
from jax.lax.linalg import add as add
from jax.lax.linalg import approx_max_k as approx_max_k
from jax.lax.linalg import approx_min_k as approx_min_k
from jax.lax.linalg import argmax as argmax
from jax.lax.linalg import argmin as argmin
from jax.lax.linalg import asin as asin
from jax.lax.linalg import asinh as asinh
from jax.lax.linalg import atan as atan
from jax.lax.linalg import atan2 as atan2
from jax.lax.linalg import atanh as atanh
from jax.lax.linalg import batch_matmul as batch_matmul
from jax.lax.linalg import bessel_i0e as bessel_i0e
from jax.lax.linalg import bessel_i1e as bessel_i1e
from jax.lax.linalg import betainc as betainc
from jax.lax.linalg import bitcast_convert_type as bitcast_convert_type
from jax.lax.linalg import bitwise_and as bitwise_and
from jax.lax.linalg import bitwise_not as bitwise_not
from jax.lax.linalg import bitwise_or as bitwise_or
from jax.lax.linalg import bitwise_xor as bitwise_xor
from jax.lax.linalg import broadcast as broadcast
from jax.lax.linalg import broadcast_in_dim as broadcast_in_dim
from jax.lax.linalg import broadcast_shapes as broadcast_shapes
from jax.lax.linalg import broadcast_to_rank as broadcast_to_rank
from jax.lax.linalg import broadcasted_iota as broadcasted_iota
from jax.lax.linalg import cbrt as cbrt
from jax.lax.linalg import ceil as ceil
from jax.lax.linalg import clamp as clamp
from jax.lax.linalg import clz as clz
from jax.lax.linalg import collapse as collapse
from jax.lax.linalg import complex as complex
from jax.lax.linalg import concatenate as concatenate
from jax.lax.linalg import conj as conj
from jax.lax.linalg import conv as conv
from jax.lax.linalg import conv_dimension_numbers as conv_dimension_numbers
from jax.lax.linalg import conv_general_dilated as conv_general_dilated
from jax.lax.linalg import conv_general_dilated_local as conv_general_dilated_local
from jax.lax.linalg import conv_general_dilated_patches as conv_general_dilated_patches
from jax.lax.linalg import conv_transpose as conv_transpose
from jax.lax.linalg import conv_with_general_padding as conv_with_general_padding
from jax.lax.linalg import convert_element_type as convert_element_type
from jax.lax.linalg import cos as cos
from jax.lax.linalg import cosh as cosh
from jax.lax.linalg import cumlogsumexp as cumlogsumexp
from jax.lax.linalg import cummax as cummax
from jax.lax.linalg import cummin as cummin
from jax.lax.linalg import cumprod as cumprod
from jax.lax.linalg import cumsum as cumsum
from jax.lax.linalg import digamma as digamma
from jax.lax.linalg import div as div
from jax.lax.linalg import dot as dot
from jax.lax.linalg import dot_general as dot_general
from jax.lax.linalg import dynamic_index_in_dim as dynamic_index_in_dim
from jax.lax.linalg import dynamic_slice as dynamic_slice
from jax.lax.linalg import dynamic_slice_in_dim as dynamic_slice_in_dim
from jax.lax.linalg import dynamic_update_index_in_dim as dynamic_update_index_in_dim
from jax.lax.linalg import dynamic_update_slice as dynamic_update_slice
from jax.lax.linalg import dynamic_update_slice_in_dim as dynamic_update_slice_in_dim
from jax.lax.linalg import eq as eq
from jax.lax.linalg import erf as erf
from jax.lax.linalg import erf_inv as erf_inv
from jax.lax.linalg import erfc as erfc
from jax.lax.linalg import exp as exp
from jax.lax.linalg import expand_dims as expand_dims
from jax.lax.linalg import expm1 as expm1
from jax.lax.linalg import fft as fft
from jax.lax.linalg import floor as floor
from jax.lax.linalg import full as full
from jax.lax.linalg import full_like as full_like
from jax.lax.linalg import gather as gather
from jax.lax.linalg import ge as ge
from jax.lax.linalg import gt as gt
from jax.lax.linalg import igamma as igamma
from jax.lax.linalg import igammac as igammac
from jax.lax.linalg import imag as imag
from jax.lax.linalg import index_in_dim as index_in_dim
from jax.lax.linalg import index_take as index_take
from jax.lax.linalg import integer_pow as integer_pow
from jax.lax.linalg import iota as iota
from jax.lax.linalg import is_finite as is_finite
from jax.lax.linalg import le as le
from jax.lax.linalg import lgamma as lgamma
from jax.lax.linalg import log as log
from jax.lax.linalg import log1p as log1p
from jax.lax.linalg import logistic as logistic
from jax.lax.linalg import lt as lt
from jax.lax.linalg import max as max
from jax.lax.linalg import min as min
from jax.lax.linalg import mul as mul
from jax.lax.linalg import neg as neg
from jax.lax.linalg import nextafter as nextafter
from jax.lax.linalg import pad as pad
from jax.lax.linalg import polygamma as polygamma
from jax.lax.linalg import population_count as population_count
from jax.lax.linalg import pow as pow
from jax.lax.linalg import random_gamma_grad as random_gamma_grad
from jax.lax.linalg import real as real
from jax.lax.linalg import reciprocal as reciprocal
from jax.lax.linalg import reduce as reduce
from jax.lax.linalg import reduce_precision as reduce_precision
from jax.lax.linalg import reduce_window as reduce_window
from jax.lax.linalg import rem as rem
from jax.lax.linalg import reshape as reshape
from jax.lax.linalg import rev as rev
from jax.lax.linalg import rng_bit_generator as rng_bit_generator
from jax.lax.linalg import rng_uniform as rng_uniform
from jax.lax.linalg import round as round
from jax.lax.linalg import rsqrt as rsqrt
from jax.lax.linalg import scatter as scatter
from jax.lax.linalg import scatter_add as scatter_add
from jax.lax.linalg import scatter_apply as scatter_apply
from jax.lax.linalg import scatter_max as scatter_max
from jax.lax.linalg import scatter_min as scatter_min
from jax.lax.linalg import scatter_mul as scatter_mul
from jax.lax.linalg import shift_left as shift_left
from jax.lax.linalg import shift_right_arithmetic as shift_right_arithmetic
from jax.lax.linalg import shift_right_logical as shift_right_logical
from jax.lax.linalg import sign as sign
from jax.lax.linalg import sin as sin
from jax.lax.linalg import sinh as sinh
from jax.lax.linalg import slice as slice
from jax.lax.linalg import slice_in_dim as slice_in_dim
from jax.lax.linalg import sort as sort
from jax.lax.linalg import sort_key_val as sort_key_val
from jax.lax.linalg import sqrt as sqrt
from jax.lax.linalg import square as square
from jax.lax.linalg import squeeze as squeeze
from jax.lax.linalg import sub as sub
from jax.lax.linalg import tan as tan
from jax.lax.linalg import tanh as tanh
from jax.lax.linalg import top_k as top_k
from jax.lax.linalg import transpose as transpose
from jax.lax.linalg import zeros_like_array as zeros_like_array
from jax.lax.linalg import zeta as zeta
from jax.lax import (
abs as abs,
acos as acos,
acosh as acosh,
add as add,
approx_max_k as approx_max_k,
approx_min_k as approx_min_k,
argmax as argmax,
argmin as argmin,
asin as asin,
asinh as asinh,
atan as atan,
atan2 as atan2,
atanh as atanh,
batch_matmul as batch_matmul,
bessel_i0e as bessel_i0e,
bessel_i1e as bessel_i1e,
betainc as betainc,
bitcast_convert_type as bitcast_convert_type,
bitwise_and as bitwise_and,
bitwise_not as bitwise_not,
bitwise_or as bitwise_or,
bitwise_xor as bitwise_xor,
broadcast as broadcast,
broadcast_in_dim as broadcast_in_dim,
broadcast_shapes as broadcast_shapes,
broadcast_to_rank as broadcast_to_rank,
broadcasted_iota as broadcasted_iota,
cbrt as cbrt,
ceil as ceil,
clamp as clamp,
clz as clz,
collapse as collapse,
complex as complex,
concatenate as concatenate,
conj as conj,
conv as conv,
conv_dimension_numbers as conv_dimension_numbers,
conv_general_dilated as conv_general_dilated,
conv_general_dilated_local as conv_general_dilated_local,
conv_general_dilated_patches as conv_general_dilated_patches,
conv_transpose as conv_transpose,
conv_with_general_padding as conv_with_general_padding,
convert_element_type as convert_element_type,
cos as cos,
cosh as cosh,
cumlogsumexp as cumlogsumexp,
cummax as cummax,
cummin as cummin,
cumprod as cumprod,
cumsum as cumsum,
digamma as digamma,
div as div,
dot as dot,
dot_general as dot_general,
dynamic_index_in_dim as dynamic_index_in_dim,
dynamic_slice as dynamic_slice,
dynamic_slice_in_dim as dynamic_slice_in_dim,
dynamic_update_index_in_dim as dynamic_update_index_in_dim,
dynamic_update_slice as dynamic_update_slice,
dynamic_update_slice_in_dim as dynamic_update_slice_in_dim,
eq as eq,
erf as erf,
erf_inv as erf_inv,
erfc as erfc,
exp as exp,
expand_dims as expand_dims,
expm1 as expm1,
fft as fft,
floor as floor,
full as full,
full_like as full_like,
gather as gather,
ge as ge,
gt as gt,
igamma as igamma,
igammac as igammac,
imag as imag,
index_in_dim as index_in_dim,
index_take as index_take,
integer_pow as integer_pow,
iota as iota,
is_finite as is_finite,
le as le,
lgamma as lgamma,
log as log,
log1p as log1p,
logistic as logistic,
lt as lt,
max as max,
min as min,
mul as mul,
ne as ne,
neg as neg,
nextafter as nextafter,
pad as pad,
polygamma as polygamma,
population_count as population_count,
pow as pow,
random_gamma_grad as random_gamma_grad,
real as real,
reciprocal as reciprocal,
reduce as reduce,
reduce_precision as reduce_precision,
reduce_window as reduce_window,
rem as rem,
reshape as reshape,
rev as rev,
rng_bit_generator as rng_bit_generator,
rng_uniform as rng_uniform,
round as round,
rsqrt as rsqrt,
scatter as scatter,
scatter_add as scatter_add,
scatter_apply as scatter_apply,
scatter_max as scatter_max,
scatter_min as scatter_min,
scatter_mul as scatter_mul,
shift_left as shift_left,
shift_right_arithmetic as shift_right_arithmetic,
shift_right_logical as shift_right_logical,
sign as sign,
sin as sin,
sinh as sinh,
slice as slice,
slice_in_dim as slice_in_dim,
sort as sort,
sort_key_val as sort_key_val,
sqrt as sqrt,
square as square,
squeeze as squeeze,
sub as sub,
tan as tan,
tanh as tanh,
top_k as top_k,
transpose as transpose,
zeros_like_array as zeros_like_array,
zeta as zeta,
)

# ----- Control Flow Operators -----
# isort: split
from jax.lax.linalg import associative_scan as associative_scan
from jax.lax.linalg import cond as cond
from jax.lax.linalg import fori_loop as fori_loop
from jax.lax.linalg import map as map
from jax.lax.linalg import scan as scan
from jax.lax.linalg import select as select
from jax.lax.linalg import select_n as select_n
from jax.lax.linalg import switch as switch
from jax.lax.linalg import while_loop as while_loop
from jax.lax import (
associative_scan as associative_scan,
cond as cond,
fori_loop as fori_loop,
map as map,
scan as scan,
select as select,
select_n as select_n,
switch as switch,
while_loop as while_loop,
)

# ----- Custom Gradient Operators -----
# isort: split
from jax.lax.linalg import custom_linear_solve as custom_linear_solve
from jax.lax.linalg import custom_root as custom_root
from jax.lax.linalg import stop_gradient as stop_gradient
from jax.lax import (
custom_linear_solve as custom_linear_solve,
custom_root as custom_root,
stop_gradient as stop_gradient,
)

# ----- Parallel Operators -----
# isort: split
from jax.lax.linalg import all_gather as all_gather
from jax.lax.linalg import all_to_all as all_to_all
from jax.lax.linalg import axis_index as axis_index
from jax.lax.linalg import pmax as pmax
from jax.lax.linalg import pmean as pmean
from jax.lax.linalg import pmin as pmin
from jax.lax.linalg import ppermute as ppermute
from jax.lax.linalg import pshuffle as pshuffle
from jax.lax.linalg import psum as psum
from jax.lax.linalg import psum_scatter as psum_scatter
from jax.lax.linalg import pswapaxes as pswapaxes
from jax.lax import (
all_gather as all_gather,
all_to_all as all_to_all,
axis_index as axis_index,
pmax as pmax,
pmean as pmean,
pmin as pmin,
ppermute as ppermute,
pshuffle as pshuffle,
psum as psum,
psum_scatter as psum_scatter,
pswapaxes as pswapaxes,
)

# ----- Sharding-related Operators -----
# isort: split
from jax.lax.linalg import with_sharding_constraint as with_sharding_constraint
from jax.lax import with_sharding_constraint as with_sharding_constraint

# ----- Linear Algebra Operators -----
# isort: split
from jax.lax.linalg import linalg as linalg
from jax.lax import linalg as linalg

# ----- Argument classes -----
# isort: split
from jax.lax import (
ConvDimensionNumbers as ConvDimensionNumbers,
ConvGeneralDilatedDimensionNumbers as ConvGeneralDilatedDimensionNumbers,
DotAlgorithm as DotAlgorithm,
GatherDimensionNumbers as GatherDimensionNumbers,
GatherScatterMode as GatherScatterMode,
Precision as Precision,
PrecisionLike as PrecisionLike,
RandomAlgorithm as RandomAlgorithm,
RoundingMethod as RoundingMethod,
ScatterDimensionNumbers as ScatterDimensionNumbers,
)
Loading

0 comments on commit b3f2e7c

Please sign in to comment.