diff --git a/pyproject.toml b/pyproject.toml index 48684a5..693b43d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,9 +26,10 @@ "jax>=0.4.3", "jaxlib>=0.4.3", "jaxtyping>=0.2.34", + "optype>=0.8.0", "plum-dispatch>=2.5.2", "quax>=0.0.5", - ] +] [project.urls] Homepage = "https://github.com/GalacticDynamics/quaxed" @@ -170,6 +171,10 @@ "TD003", # Missing issue link on the line following this TODO ] + [tool.ruff.lint.isort] + combine-as-imports = true + extra-standard-library = ["typing_extensions"] + [tool.ruff.lint.per-file-ignores] "src/quaxed/**" = ["A004"] "tests/**" = ["ANN", "INP001", "PLR0913", "PLR2004", "S101", "T20", "TID252"] diff --git a/src/quaxed/lax/__init__.pyi b/src/quaxed/lax/__init__.pyi index a769d6c..ab650f7 100644 --- a/src/quaxed/lax/__init__.pyi +++ b/src/quaxed/lax/__init__.pyi @@ -1,178 +1,202 @@ # ----- Operators ----- # isort: split -from jax.lax.linalg import abs as abs -from jax.lax.linalg import acos as acos -from jax.lax.linalg import acosh as acosh -from jax.lax.linalg import add as add -from jax.lax.linalg import approx_max_k as approx_max_k -from jax.lax.linalg import approx_min_k as approx_min_k -from jax.lax.linalg import argmax as argmax -from jax.lax.linalg import argmin as argmin -from jax.lax.linalg import asin as asin -from jax.lax.linalg import asinh as asinh -from jax.lax.linalg import atan as atan -from jax.lax.linalg import atan2 as atan2 -from jax.lax.linalg import atanh as atanh -from jax.lax.linalg import batch_matmul as batch_matmul -from jax.lax.linalg import bessel_i0e as bessel_i0e -from jax.lax.linalg import bessel_i1e as bessel_i1e -from jax.lax.linalg import betainc as betainc -from jax.lax.linalg import bitcast_convert_type as bitcast_convert_type -from jax.lax.linalg import bitwise_and as bitwise_and -from jax.lax.linalg import bitwise_not as bitwise_not -from jax.lax.linalg import bitwise_or as bitwise_or -from jax.lax.linalg import bitwise_xor as bitwise_xor -from jax.lax.linalg import broadcast as broadcast -from jax.lax.linalg import broadcast_in_dim as broadcast_in_dim -from jax.lax.linalg import broadcast_shapes as broadcast_shapes -from jax.lax.linalg import broadcast_to_rank as broadcast_to_rank -from jax.lax.linalg import broadcasted_iota as broadcasted_iota -from jax.lax.linalg import cbrt as cbrt -from jax.lax.linalg import ceil as ceil -from jax.lax.linalg import clamp as clamp -from jax.lax.linalg import clz as clz -from jax.lax.linalg import collapse as collapse -from jax.lax.linalg import complex as complex -from jax.lax.linalg import concatenate as concatenate -from jax.lax.linalg import conj as conj -from jax.lax.linalg import conv as conv -from jax.lax.linalg import conv_dimension_numbers as conv_dimension_numbers -from jax.lax.linalg import conv_general_dilated as conv_general_dilated -from jax.lax.linalg import conv_general_dilated_local as conv_general_dilated_local -from jax.lax.linalg import conv_general_dilated_patches as conv_general_dilated_patches -from jax.lax.linalg import conv_transpose as conv_transpose -from jax.lax.linalg import conv_with_general_padding as conv_with_general_padding -from jax.lax.linalg import convert_element_type as convert_element_type -from jax.lax.linalg import cos as cos -from jax.lax.linalg import cosh as cosh -from jax.lax.linalg import cumlogsumexp as cumlogsumexp -from jax.lax.linalg import cummax as cummax -from jax.lax.linalg import cummin as cummin -from jax.lax.linalg import cumprod as cumprod -from jax.lax.linalg import cumsum as cumsum -from jax.lax.linalg import digamma as digamma -from jax.lax.linalg import div as div -from jax.lax.linalg import dot as dot -from jax.lax.linalg import dot_general as dot_general -from jax.lax.linalg import dynamic_index_in_dim as dynamic_index_in_dim -from jax.lax.linalg import dynamic_slice as dynamic_slice -from jax.lax.linalg import dynamic_slice_in_dim as dynamic_slice_in_dim -from jax.lax.linalg import dynamic_update_index_in_dim as dynamic_update_index_in_dim -from jax.lax.linalg import dynamic_update_slice as dynamic_update_slice -from jax.lax.linalg import dynamic_update_slice_in_dim as dynamic_update_slice_in_dim -from jax.lax.linalg import eq as eq -from jax.lax.linalg import erf as erf -from jax.lax.linalg import erf_inv as erf_inv -from jax.lax.linalg import erfc as erfc -from jax.lax.linalg import exp as exp -from jax.lax.linalg import expand_dims as expand_dims -from jax.lax.linalg import expm1 as expm1 -from jax.lax.linalg import fft as fft -from jax.lax.linalg import floor as floor -from jax.lax.linalg import full as full -from jax.lax.linalg import full_like as full_like -from jax.lax.linalg import gather as gather -from jax.lax.linalg import ge as ge -from jax.lax.linalg import gt as gt -from jax.lax.linalg import igamma as igamma -from jax.lax.linalg import igammac as igammac -from jax.lax.linalg import imag as imag -from jax.lax.linalg import index_in_dim as index_in_dim -from jax.lax.linalg import index_take as index_take -from jax.lax.linalg import integer_pow as integer_pow -from jax.lax.linalg import iota as iota -from jax.lax.linalg import is_finite as is_finite -from jax.lax.linalg import le as le -from jax.lax.linalg import lgamma as lgamma -from jax.lax.linalg import log as log -from jax.lax.linalg import log1p as log1p -from jax.lax.linalg import logistic as logistic -from jax.lax.linalg import lt as lt -from jax.lax.linalg import max as max -from jax.lax.linalg import min as min -from jax.lax.linalg import mul as mul -from jax.lax.linalg import neg as neg -from jax.lax.linalg import nextafter as nextafter -from jax.lax.linalg import pad as pad -from jax.lax.linalg import polygamma as polygamma -from jax.lax.linalg import population_count as population_count -from jax.lax.linalg import pow as pow -from jax.lax.linalg import random_gamma_grad as random_gamma_grad -from jax.lax.linalg import real as real -from jax.lax.linalg import reciprocal as reciprocal -from jax.lax.linalg import reduce as reduce -from jax.lax.linalg import reduce_precision as reduce_precision -from jax.lax.linalg import reduce_window as reduce_window -from jax.lax.linalg import rem as rem -from jax.lax.linalg import reshape as reshape -from jax.lax.linalg import rev as rev -from jax.lax.linalg import rng_bit_generator as rng_bit_generator -from jax.lax.linalg import rng_uniform as rng_uniform -from jax.lax.linalg import round as round -from jax.lax.linalg import rsqrt as rsqrt -from jax.lax.linalg import scatter as scatter -from jax.lax.linalg import scatter_add as scatter_add -from jax.lax.linalg import scatter_apply as scatter_apply -from jax.lax.linalg import scatter_max as scatter_max -from jax.lax.linalg import scatter_min as scatter_min -from jax.lax.linalg import scatter_mul as scatter_mul -from jax.lax.linalg import shift_left as shift_left -from jax.lax.linalg import shift_right_arithmetic as shift_right_arithmetic -from jax.lax.linalg import shift_right_logical as shift_right_logical -from jax.lax.linalg import sign as sign -from jax.lax.linalg import sin as sin -from jax.lax.linalg import sinh as sinh -from jax.lax.linalg import slice as slice -from jax.lax.linalg import slice_in_dim as slice_in_dim -from jax.lax.linalg import sort as sort -from jax.lax.linalg import sort_key_val as sort_key_val -from jax.lax.linalg import sqrt as sqrt -from jax.lax.linalg import square as square -from jax.lax.linalg import squeeze as squeeze -from jax.lax.linalg import sub as sub -from jax.lax.linalg import tan as tan -from jax.lax.linalg import tanh as tanh -from jax.lax.linalg import top_k as top_k -from jax.lax.linalg import transpose as transpose -from jax.lax.linalg import zeros_like_array as zeros_like_array -from jax.lax.linalg import zeta as zeta +from jax.lax import ( + abs as abs, + acos as acos, + acosh as acosh, + add as add, + approx_max_k as approx_max_k, + approx_min_k as approx_min_k, + argmax as argmax, + argmin as argmin, + asin as asin, + asinh as asinh, + atan as atan, + atan2 as atan2, + atanh as atanh, + batch_matmul as batch_matmul, + bessel_i0e as bessel_i0e, + bessel_i1e as bessel_i1e, + betainc as betainc, + bitcast_convert_type as bitcast_convert_type, + bitwise_and as bitwise_and, + bitwise_not as bitwise_not, + bitwise_or as bitwise_or, + bitwise_xor as bitwise_xor, + broadcast as broadcast, + broadcast_in_dim as broadcast_in_dim, + broadcast_shapes as broadcast_shapes, + broadcast_to_rank as broadcast_to_rank, + broadcasted_iota as broadcasted_iota, + cbrt as cbrt, + ceil as ceil, + clamp as clamp, + clz as clz, + collapse as collapse, + complex as complex, + concatenate as concatenate, + conj as conj, + conv as conv, + conv_dimension_numbers as conv_dimension_numbers, + conv_general_dilated as conv_general_dilated, + conv_general_dilated_local as conv_general_dilated_local, + conv_general_dilated_patches as conv_general_dilated_patches, + conv_transpose as conv_transpose, + conv_with_general_padding as conv_with_general_padding, + convert_element_type as convert_element_type, + cos as cos, + cosh as cosh, + cumlogsumexp as cumlogsumexp, + cummax as cummax, + cummin as cummin, + cumprod as cumprod, + cumsum as cumsum, + digamma as digamma, + div as div, + dot as dot, + dot_general as dot_general, + dynamic_index_in_dim as dynamic_index_in_dim, + dynamic_slice as dynamic_slice, + dynamic_slice_in_dim as dynamic_slice_in_dim, + dynamic_update_index_in_dim as dynamic_update_index_in_dim, + dynamic_update_slice as dynamic_update_slice, + dynamic_update_slice_in_dim as dynamic_update_slice_in_dim, + eq as eq, + erf as erf, + erf_inv as erf_inv, + erfc as erfc, + exp as exp, + expand_dims as expand_dims, + expm1 as expm1, + fft as fft, + floor as floor, + full as full, + full_like as full_like, + gather as gather, + ge as ge, + gt as gt, + igamma as igamma, + igammac as igammac, + imag as imag, + index_in_dim as index_in_dim, + index_take as index_take, + integer_pow as integer_pow, + iota as iota, + is_finite as is_finite, + le as le, + lgamma as lgamma, + log as log, + log1p as log1p, + logistic as logistic, + lt as lt, + max as max, + min as min, + mul as mul, + ne as ne, + neg as neg, + nextafter as nextafter, + pad as pad, + polygamma as polygamma, + population_count as population_count, + pow as pow, + random_gamma_grad as random_gamma_grad, + real as real, + reciprocal as reciprocal, + reduce as reduce, + reduce_precision as reduce_precision, + reduce_window as reduce_window, + rem as rem, + reshape as reshape, + rev as rev, + rng_bit_generator as rng_bit_generator, + rng_uniform as rng_uniform, + round as round, + rsqrt as rsqrt, + scatter as scatter, + scatter_add as scatter_add, + scatter_apply as scatter_apply, + scatter_max as scatter_max, + scatter_min as scatter_min, + scatter_mul as scatter_mul, + shift_left as shift_left, + shift_right_arithmetic as shift_right_arithmetic, + shift_right_logical as shift_right_logical, + sign as sign, + sin as sin, + sinh as sinh, + slice as slice, + slice_in_dim as slice_in_dim, + sort as sort, + sort_key_val as sort_key_val, + sqrt as sqrt, + square as square, + squeeze as squeeze, + sub as sub, + tan as tan, + tanh as tanh, + top_k as top_k, + transpose as transpose, + zeros_like_array as zeros_like_array, + zeta as zeta, +) # ----- Control Flow Operators ----- # isort: split -from jax.lax.linalg import associative_scan as associative_scan -from jax.lax.linalg import cond as cond -from jax.lax.linalg import fori_loop as fori_loop -from jax.lax.linalg import map as map -from jax.lax.linalg import scan as scan -from jax.lax.linalg import select as select -from jax.lax.linalg import select_n as select_n -from jax.lax.linalg import switch as switch -from jax.lax.linalg import while_loop as while_loop +from jax.lax import ( + associative_scan as associative_scan, + cond as cond, + fori_loop as fori_loop, + map as map, + scan as scan, + select as select, + select_n as select_n, + switch as switch, + while_loop as while_loop, +) # ----- Custom Gradient Operators ----- # isort: split -from jax.lax.linalg import custom_linear_solve as custom_linear_solve -from jax.lax.linalg import custom_root as custom_root -from jax.lax.linalg import stop_gradient as stop_gradient +from jax.lax import ( + custom_linear_solve as custom_linear_solve, + custom_root as custom_root, + stop_gradient as stop_gradient, +) # ----- Parallel Operators ----- # isort: split -from jax.lax.linalg import all_gather as all_gather -from jax.lax.linalg import all_to_all as all_to_all -from jax.lax.linalg import axis_index as axis_index -from jax.lax.linalg import pmax as pmax -from jax.lax.linalg import pmean as pmean -from jax.lax.linalg import pmin as pmin -from jax.lax.linalg import ppermute as ppermute -from jax.lax.linalg import pshuffle as pshuffle -from jax.lax.linalg import psum as psum -from jax.lax.linalg import psum_scatter as psum_scatter -from jax.lax.linalg import pswapaxes as pswapaxes +from jax.lax import ( + all_gather as all_gather, + all_to_all as all_to_all, + axis_index as axis_index, + pmax as pmax, + pmean as pmean, + pmin as pmin, + ppermute as ppermute, + pshuffle as pshuffle, + psum as psum, + psum_scatter as psum_scatter, + pswapaxes as pswapaxes, +) # ----- Sharding-related Operators ----- # isort: split -from jax.lax.linalg import with_sharding_constraint as with_sharding_constraint +from jax.lax import with_sharding_constraint as with_sharding_constraint # ----- Linear Algebra Operators ----- # isort: split -from jax.lax.linalg import linalg as linalg +from jax.lax import linalg as linalg + +# ----- Argument classes ----- +# isort: split +from jax.lax import ( + ConvDimensionNumbers as ConvDimensionNumbers, + ConvGeneralDilatedDimensionNumbers as ConvGeneralDilatedDimensionNumbers, + DotAlgorithm as DotAlgorithm, + GatherDimensionNumbers as GatherDimensionNumbers, + GatherScatterMode as GatherScatterMode, + Precision as Precision, + PrecisionLike as PrecisionLike, + RandomAlgorithm as RandomAlgorithm, + RoundingMethod as RoundingMethod, + ScatterDimensionNumbers as ScatterDimensionNumbers, +) diff --git a/src/quaxed/lax/linalg.pyi b/src/quaxed/lax/linalg.pyi index d7987c3..da2caca 100644 --- a/src/quaxed/lax/linalg.pyi +++ b/src/quaxed/lax/linalg.pyi @@ -1,13 +1,15 @@ -from jax.lax.linalg import cholesky as cholesky -from jax.lax.linalg import eig as eig -from jax.lax.linalg import eigh as eigh -from jax.lax.linalg import hessenberg as hessenberg -from jax.lax.linalg import householder_product as householder_product -from jax.lax.linalg import lu as lu -from jax.lax.linalg import qdwh as qdwh -from jax.lax.linalg import qr as qr -from jax.lax.linalg import shur as shur -from jax.lax.linalg import svd as svd -from jax.lax.linalg import triangular_solve as triangular_solve -from jax.lax.linalg import tridiagonal as tridiagonal -from jax.lax.linalg import tridiagonal_solve as tridiagonal_solve +from jax.lax.linalg import ( + cholesky as cholesky, + eig as eig, + eigh as eigh, + hessenberg as hessenberg, + householder_product as householder_product, + lu as lu, + qdwh as qdwh, + qr as qr, + shur as shur, + svd as svd, + triangular_solve as triangular_solve, + tridiagonal as tridiagonal, + tridiagonal_solve as tridiagonal_solve, +) diff --git a/src/quaxed/numpy/__init__.pyi b/src/quaxed/numpy/__init__.pyi index 8980a03..237336f 100644 --- a/src/quaxed/numpy/__init__.pyi +++ b/src/quaxed/numpy/__init__.pyi @@ -1,405 +1,408 @@ # Modules -from jax.numpy import fft as fft -from jax.numpy import linalg as linalg +from jax.numpy import fft as fft, linalg as linalg # _core.py # isort: split -from jax.numpy import abs as abs -from jax.numpy import absolute as absolute -from jax.numpy import acos as acos -from jax.numpy import acosh as acosh -from jax.numpy import add as add -from jax.numpy import all as all -from jax.numpy import allclose as allclose -from jax.numpy import amax as amax -from jax.numpy import amin as amin -from jax.numpy import angle as angle -from jax.numpy import any as any -from jax.numpy import append as append -from jax.numpy import apply_along_axis as apply_along_axis -from jax.numpy import apply_over_axes as apply_over_axes -from jax.numpy import arccos as arccos -from jax.numpy import arccosh as arccosh -from jax.numpy import arcsin as arcsin -from jax.numpy import arcsinh as arcsinh -from jax.numpy import arctan as arctan -from jax.numpy import arctan2 as arctan2 -from jax.numpy import arctanh as arctanh -from jax.numpy import argmax as argmax -from jax.numpy import argmin as argmin -from jax.numpy import argpartition as argpartition -from jax.numpy import argsort as argsort -from jax.numpy import argwhere as argwhere -from jax.numpy import around as around -from jax.numpy import array as array -from jax.numpy import array_equal as array_equal -from jax.numpy import array_equiv as array_equiv -from jax.numpy import array_repr as array_repr -from jax.numpy import array_split as array_split -from jax.numpy import asin as asin -from jax.numpy import asinh as asinh -from jax.numpy import astype as astype -from jax.numpy import atan as atan -from jax.numpy import atan2 as atan2 -from jax.numpy import atanh as atanh -from jax.numpy import atleast_1d as atleast_1d -from jax.numpy import atleast_2d as atleast_2d -from jax.numpy import atleast_3d as atleast_3d -from jax.numpy import average as average -from jax.numpy import bartlett as bartlett -from jax.numpy import bfloat16 as bfloat16 -from jax.numpy import bincount as bincount -from jax.numpy import bitwise_and as bitwise_and -from jax.numpy import bitwise_count as bitwise_count -from jax.numpy import bitwise_invert as bitwise_invert -from jax.numpy import bitwise_left_shift as bitwise_left_shift -from jax.numpy import bitwise_not as bitwise_not -from jax.numpy import bitwise_or as bitwise_or -from jax.numpy import bitwise_right_shift as bitwise_right_shift -from jax.numpy import bitwise_xor as bitwise_xor -from jax.numpy import blackman as blackman -from jax.numpy import block as block -from jax.numpy import bool as bool -from jax.numpy import bool_ as bool_ -from jax.numpy import broadcast_arrays as broadcast_arrays -from jax.numpy import broadcast_shapes as broadcast_shapes -from jax.numpy import broadcast_to as broadcast_to -from jax.numpy import c_ as c_ -from jax.numpy import can_cast as can_cast -from jax.numpy import cbrt as cbrt -from jax.numpy import cdouble as cdouble -from jax.numpy import ceil as ceil -from jax.numpy import character as character -from jax.numpy import choose as choose -from jax.numpy import clip as clip -from jax.numpy import column_stack as column_stack -from jax.numpy import complex64 as complex64 -from jax.numpy import complex128 as complex128 -from jax.numpy import complex_ as complex_ -from jax.numpy import complexfloating as complexfloating -from jax.numpy import compress as compress -from jax.numpy import concat as concat -from jax.numpy import concatenate as concatenate -from jax.numpy import conj as conj -from jax.numpy import conjugate as conjugate -from jax.numpy import convolve as convolve -from jax.numpy import copy as copy -from jax.numpy import copysign as copysign -from jax.numpy import corrcoef as corrcoef -from jax.numpy import correlate as correlate -from jax.numpy import cos as cos -from jax.numpy import cosh as cosh -from jax.numpy import count_nonzero as count_nonzero -from jax.numpy import cov as cov -from jax.numpy import cross as cross -from jax.numpy import csingle as csingle -from jax.numpy import cumprod as cumprod -from jax.numpy import cumsum as cumsum -from jax.numpy import deg2rad as deg2rad -from jax.numpy import degrees as degrees -from jax.numpy import delete as delete -from jax.numpy import diag as diag -from jax.numpy import diag_indices as diag_indices -from jax.numpy import diag_indices_from as diag_indices_from -from jax.numpy import diagflat as diagflat -from jax.numpy import diagonal as diagonal -from jax.numpy import diff as diff -from jax.numpy import digitize as digitize -from jax.numpy import divide as divide -from jax.numpy import divmod as divmod -from jax.numpy import dot as dot -from jax.numpy import double as double -from jax.numpy import dsplit as dsplit -from jax.numpy import dstack as dstack -from jax.numpy import dtype as dtype -from jax.numpy import e as e -from jax.numpy import ediff1d as ediff1d -from jax.numpy import einsum as einsum -from jax.numpy import einsum_path as einsum_path -from jax.numpy import empty as empty -from jax.numpy import equal as equal -from jax.numpy import euler_gamma as euler_gamma -from jax.numpy import exp as exp -from jax.numpy import exp2 as exp2 -from jax.numpy import expand_dims as expand_dims -from jax.numpy import expm1 as expm1 -from jax.numpy import extract as extract -from jax.numpy import eye as eye -from jax.numpy import fabs as fabs -from jax.numpy import fill_diagonal as fill_diagonal -from jax.numpy import finfo as finfo -from jax.numpy import fix as fix -from jax.numpy import flatnonzero as flatnonzero -from jax.numpy import flexible as flexible -from jax.numpy import flip as flip -from jax.numpy import fliplr as fliplr -from jax.numpy import flipud as flipud -from jax.numpy import float8_e4m3b11fnuz as float8_e4m3b11fnuz -from jax.numpy import float8_e4m3fn as float8_e4m3fn -from jax.numpy import float8_e4m3fnuz as float8_e4m3fnuz -from jax.numpy import float8_e5m2 as float8_e5m2 -from jax.numpy import float8_e5m2fnuz as float8_e5m2fnuz -from jax.numpy import float16 as float16 -from jax.numpy import float32 as float32 -from jax.numpy import float64 as float64 -from jax.numpy import float_ as float_ -from jax.numpy import float_power as float_power -from jax.numpy import floating as floating -from jax.numpy import floor as floor -from jax.numpy import floor_divide as floor_divide -from jax.numpy import fmax as fmax -from jax.numpy import fmin as fmin -from jax.numpy import fmod as fmod -from jax.numpy import frexp as frexp -from jax.numpy import from_dlpack as from_dlpack -from jax.numpy import frombuffer as frombuffer -from jax.numpy import fromfunction as fromfunction -from jax.numpy import fromiter as fromiter -from jax.numpy import frompyfunc as frompyfunc -from jax.numpy import fromstring as fromstring -from jax.numpy import gcd as gcd -from jax.numpy import generic as generic -from jax.numpy import geomspace as geomspace -from jax.numpy import get_printoptions as get_printoptions -from jax.numpy import gradient as gradient -from jax.numpy import greater as greater -from jax.numpy import greater_equal as greater_equal -from jax.numpy import hamming as hamming -from jax.numpy import hanning as hanning -from jax.numpy import heaviside as heaviside -from jax.numpy import histogram as histogram -from jax.numpy import histogram2d as histogram2d -from jax.numpy import histogram_bin_edges as histogram_bin_edges -from jax.numpy import histogramdd as histogramdd -from jax.numpy import hsplit as hsplit -from jax.numpy import hstack as hstack -from jax.numpy import hypot as hypot -from jax.numpy import i0 as i0 -from jax.numpy import identity as identity -from jax.numpy import iinfo as iinfo -from jax.numpy import imag as imag -from jax.numpy import index_exp as index_exp -from jax.numpy import indices as indices -from jax.numpy import inexact as inexact -from jax.numpy import inf as inf -from jax.numpy import inner as inner -from jax.numpy import insert as insert -from jax.numpy import int4 as int4 -from jax.numpy import int8 as int8 -from jax.numpy import int16 as int16 -from jax.numpy import int32 as int32 -from jax.numpy import int64 as int64 -from jax.numpy import int_ as int_ -from jax.numpy import integer as integer -from jax.numpy import interp as interp -from jax.numpy import intersect1d as intersect1d -from jax.numpy import invert as invert -from jax.numpy import isclose as isclose -from jax.numpy import iscomplex as iscomplex -from jax.numpy import iscomplexobj as iscomplexobj -from jax.numpy import isdtype as isdtype -from jax.numpy import isfinite as isfinite -from jax.numpy import isin as isin -from jax.numpy import isinf as isinf -from jax.numpy import isnan as isnan -from jax.numpy import isneginf as isneginf -from jax.numpy import isposinf as isposinf -from jax.numpy import isreal as isreal -from jax.numpy import isrealobj as isrealobj -from jax.numpy import isscalar as isscalar -from jax.numpy import issubdtype as issubdtype -from jax.numpy import iterable as iterable -from jax.numpy import ix_ as ix_ -from jax.numpy import kaiser as kaiser -from jax.numpy import kron as kron -from jax.numpy import lcm as lcm -from jax.numpy import ldexp as ldexp -from jax.numpy import left_shift as left_shift -from jax.numpy import less as less -from jax.numpy import less_equal as less_equal -from jax.numpy import lexsort as lexsort -from jax.numpy import load as load -from jax.numpy import log as log -from jax.numpy import log1p as log1p -from jax.numpy import log2 as log2 -from jax.numpy import log10 as log10 -from jax.numpy import logaddexp as logaddexp -from jax.numpy import logaddexp2 as logaddexp2 -from jax.numpy import logical_and as logical_and -from jax.numpy import logical_not as logical_not -from jax.numpy import logical_or as logical_or -from jax.numpy import logical_xor as logical_xor -from jax.numpy import logspace as logspace -from jax.numpy import mask_indices as mask_indices -from jax.numpy import matmul as matmul -from jax.numpy import matrix_transpose as matrix_transpose -from jax.numpy import max as max -from jax.numpy import maximum as maximum -from jax.numpy import mean as mean -from jax.numpy import median as median -from jax.numpy import mgrid as mgrid -from jax.numpy import min as min -from jax.numpy import minimum as minimum -from jax.numpy import mod as mod -from jax.numpy import modf as modf -from jax.numpy import moveaxis as moveaxis -from jax.numpy import multiply as multiply -from jax.numpy import nan as nan -from jax.numpy import nan_to_num as nan_to_num -from jax.numpy import nanargmax as nanargmax -from jax.numpy import nanargmin as nanargmin -from jax.numpy import nancumprod as nancumprod -from jax.numpy import nancumsum as nancumsum -from jax.numpy import nanmax as nanmax -from jax.numpy import nanmean as nanmean -from jax.numpy import nanmedian as nanmedian -from jax.numpy import nanmin as nanmin -from jax.numpy import nanpercentile as nanpercentile -from jax.numpy import nanprod as nanprod -from jax.numpy import nanquantile as nanquantile -from jax.numpy import nanstd as nanstd -from jax.numpy import nansum as nansum -from jax.numpy import nanvar as nanvar -from jax.numpy import ndarray as ndarray -from jax.numpy import ndim as ndim -from jax.numpy import negative as negative -from jax.numpy import newaxis as newaxis -from jax.numpy import nextafter as nextafter -from jax.numpy import nonzero as nonzero -from jax.numpy import not_equal as not_equal -from jax.numpy import number as number -from jax.numpy import object_ as object_ -from jax.numpy import ogrid as ogrid -from jax.numpy import ones as ones -from jax.numpy import outer as outer -from jax.numpy import packbits as packbits -from jax.numpy import pad as pad -from jax.numpy import partition as partition -from jax.numpy import percentile as percentile -from jax.numpy import permute_dims as permute_dims -from jax.numpy import pi as pi -from jax.numpy import piecewise as piecewise -from jax.numpy import place as place -from jax.numpy import poly as poly -from jax.numpy import polyadd as polyadd -from jax.numpy import polyder as polyder -from jax.numpy import polydiv as polydiv -from jax.numpy import polyfit as polyfit -from jax.numpy import polyint as polyint -from jax.numpy import polymul as polymul -from jax.numpy import polysub as polysub -from jax.numpy import polyval as polyval -from jax.numpy import positive as positive -from jax.numpy import pow as pow -from jax.numpy import power as power -from jax.numpy import printoptions as printoptions -from jax.numpy import prod as prod -from jax.numpy import promote_types as promote_types -from jax.numpy import ptp as ptp -from jax.numpy import put as put -from jax.numpy import quantile as quantile -from jax.numpy import r_ as r_ -from jax.numpy import rad2deg as rad2deg -from jax.numpy import radians as radians -from jax.numpy import ravel as ravel -from jax.numpy import ravel_multi_index as ravel_multi_index -from jax.numpy import real as real -from jax.numpy import reciprocal as reciprocal -from jax.numpy import remainder as remainder -from jax.numpy import repeat as repeat -from jax.numpy import reshape as reshape -from jax.numpy import resize as resize -from jax.numpy import result_type as result_type -from jax.numpy import right_shift as right_shift -from jax.numpy import rint as rint -from jax.numpy import roll as roll -from jax.numpy import rollaxis as rollaxis -from jax.numpy import roots as roots -from jax.numpy import rot90 as rot90 -from jax.numpy import round as round -from jax.numpy import round_ as round_ -from jax.numpy import s_ as s_ -from jax.numpy import save as save -from jax.numpy import savez as savez -from jax.numpy import searchsorted as searchsorted -from jax.numpy import select as select -from jax.numpy import set_printoptions as set_printoptions -from jax.numpy import setdiff1d as setdiff1d -from jax.numpy import setxor1d as setxor1d -from jax.numpy import shape as shape -from jax.numpy import sign as sign -from jax.numpy import signbit as signbit -from jax.numpy import signedinteger as signedinteger -from jax.numpy import sin as sin -from jax.numpy import sinc as sinc -from jax.numpy import single as single -from jax.numpy import sinh as sinh -from jax.numpy import size as size -from jax.numpy import sort as sort -from jax.numpy import sort_complex as sort_complex -from jax.numpy import split as split -from jax.numpy import sqrt as sqrt -from jax.numpy import square as square -from jax.numpy import squeeze as squeeze -from jax.numpy import stack as stack -from jax.numpy import std as std -from jax.numpy import subtract as subtract -from jax.numpy import sum as sum -from jax.numpy import swapaxes as swapaxes -from jax.numpy import take as take -from jax.numpy import take_along_axis as take_along_axis -from jax.numpy import tan as tan -from jax.numpy import tanh as tanh -from jax.numpy import tensordot as tensordot -from jax.numpy import tile as tile -from jax.numpy import trace as trace -from jax.numpy import transpose as transpose -from jax.numpy import tri as tri -from jax.numpy import tril_indices as tril_indices -from jax.numpy import tril_indices_from as tril_indices_from -from jax.numpy import trim_zeros as trim_zeros -from jax.numpy import triu_indices as triu_indices -from jax.numpy import triu_indices_from as triu_indices_from -from jax.numpy import true_divide as true_divide -from jax.numpy import trunc as trunc -from jax.numpy import uint as uint -from jax.numpy import uint4 as uint4 -from jax.numpy import uint8 as uint8 -from jax.numpy import uint16 as uint16 -from jax.numpy import uint32 as uint32 -from jax.numpy import uint64 as uint64 -from jax.numpy import union1d as union1d -from jax.numpy import unique as unique -from jax.numpy import unique_all as unique_all -from jax.numpy import unique_counts as unique_counts -from jax.numpy import unique_inverse as unique_inverse -from jax.numpy import unique_values as unique_values -from jax.numpy import unpackbits as unpackbits -from jax.numpy import unravel_index as unravel_index -from jax.numpy import unsignedinteger as unsignedinteger -from jax.numpy import unwrap as unwrap -from jax.numpy import vander as vander -from jax.numpy import var as var -from jax.numpy import vdot as vdot -from jax.numpy import vecdot as vecdot -from jax.numpy import vsplit as vsplit -from jax.numpy import vstack as vstack -from jax.numpy import where as where -from jax.numpy import zeros as zeros +from jax.numpy import ( + abs as abs, + absolute as absolute, + acos as acos, + acosh as acosh, + add as add, + all as all, + allclose as allclose, + amax as amax, + amin as amin, + angle as angle, + any as any, + append as append, + apply_along_axis as apply_along_axis, + apply_over_axes as apply_over_axes, + arccos as arccos, + arccosh as arccosh, + arcsin as arcsin, + arcsinh as arcsinh, + arctan as arctan, + arctan2 as arctan2, + arctanh as arctanh, + argmax as argmax, + argmin as argmin, + argpartition as argpartition, + argsort as argsort, + argwhere as argwhere, + around as around, + array as array, + array_equal as array_equal, + array_equiv as array_equiv, + array_repr as array_repr, + array_split as array_split, + asin as asin, + asinh as asinh, + astype as astype, + atan as atan, + atan2 as atan2, + atanh as atanh, + atleast_1d as atleast_1d, + atleast_2d as atleast_2d, + atleast_3d as atleast_3d, + average as average, + bartlett as bartlett, + bfloat16 as bfloat16, + bincount as bincount, + bitwise_and as bitwise_and, + bitwise_count as bitwise_count, + bitwise_invert as bitwise_invert, + bitwise_left_shift as bitwise_left_shift, + bitwise_not as bitwise_not, + bitwise_or as bitwise_or, + bitwise_right_shift as bitwise_right_shift, + bitwise_xor as bitwise_xor, + blackman as blackman, + block as block, + bool as bool, + bool_ as bool_, + broadcast_arrays as broadcast_arrays, + broadcast_shapes as broadcast_shapes, + broadcast_to as broadcast_to, + c_ as c_, + can_cast as can_cast, + cbrt as cbrt, + cdouble as cdouble, + ceil as ceil, + character as character, + choose as choose, + clip as clip, + column_stack as column_stack, + complex64 as complex64, + complex128 as complex128, + complex_ as complex_, + complexfloating as complexfloating, + compress as compress, + concat as concat, + concatenate as concatenate, + conj as conj, + conjugate as conjugate, + convolve as convolve, + copy as copy, + copysign as copysign, + corrcoef as corrcoef, + correlate as correlate, + cos as cos, + cosh as cosh, + count_nonzero as count_nonzero, + cov as cov, + cross as cross, + csingle as csingle, + cumprod as cumprod, + cumsum as cumsum, + deg2rad as deg2rad, + degrees as degrees, + delete as delete, + diag as diag, + diag_indices as diag_indices, + diag_indices_from as diag_indices_from, + diagflat as diagflat, + diagonal as diagonal, + diff as diff, + digitize as digitize, + divide as divide, + divmod as divmod, + dot as dot, + double as double, + dsplit as dsplit, + dstack as dstack, + dtype as dtype, + e as e, + ediff1d as ediff1d, + einsum as einsum, + einsum_path as einsum_path, + empty as empty, + equal as equal, + euler_gamma as euler_gamma, + exp as exp, + exp2 as exp2, + expand_dims as expand_dims, + expm1 as expm1, + extract as extract, + eye as eye, + fabs as fabs, + fill_diagonal as fill_diagonal, + finfo as finfo, + fix as fix, + flatnonzero as flatnonzero, + flexible as flexible, + flip as flip, + fliplr as fliplr, + flipud as flipud, + float8_e4m3b11fnuz as float8_e4m3b11fnuz, + float8_e4m3fn as float8_e4m3fn, + float8_e4m3fnuz as float8_e4m3fnuz, + float8_e5m2 as float8_e5m2, + float8_e5m2fnuz as float8_e5m2fnuz, + float16 as float16, + float32 as float32, + float64 as float64, + float_ as float_, + float_power as float_power, + floating as floating, + floor as floor, + floor_divide as floor_divide, + fmax as fmax, + fmin as fmin, + fmod as fmod, + frexp as frexp, + from_dlpack as from_dlpack, + frombuffer as frombuffer, + fromfunction as fromfunction, + fromiter as fromiter, + frompyfunc as frompyfunc, + fromstring as fromstring, + gcd as gcd, + generic as generic, + geomspace as geomspace, + get_printoptions as get_printoptions, + gradient as gradient, + greater as greater, + greater_equal as greater_equal, + hamming as hamming, + hanning as hanning, + heaviside as heaviside, + histogram as histogram, + histogram2d as histogram2d, + histogram_bin_edges as histogram_bin_edges, + histogramdd as histogramdd, + hsplit as hsplit, + hstack as hstack, + hypot as hypot, + i0 as i0, + identity as identity, + iinfo as iinfo, + imag as imag, + index_exp as index_exp, + indices as indices, + inexact as inexact, + inf as inf, + inner as inner, + insert as insert, + int4 as int4, + int8 as int8, + int16 as int16, + int32 as int32, + int64 as int64, + int_ as int_, + integer as integer, + interp as interp, + intersect1d as intersect1d, + invert as invert, + isclose as isclose, + iscomplex as iscomplex, + iscomplexobj as iscomplexobj, + isdtype as isdtype, + isfinite as isfinite, + isin as isin, + isinf as isinf, + isnan as isnan, + isneginf as isneginf, + isposinf as isposinf, + isreal as isreal, + isrealobj as isrealobj, + isscalar as isscalar, + issubdtype as issubdtype, + iterable as iterable, + ix_ as ix_, + kaiser as kaiser, + kron as kron, + lcm as lcm, + ldexp as ldexp, + left_shift as left_shift, + less as less, + less_equal as less_equal, + lexsort as lexsort, + load as load, + log as log, + log1p as log1p, + log2 as log2, + log10 as log10, + logaddexp as logaddexp, + logaddexp2 as logaddexp2, + logical_and as logical_and, + logical_not as logical_not, + logical_or as logical_or, + logical_xor as logical_xor, + logspace as logspace, + mask_indices as mask_indices, + matmul as matmul, + matrix_transpose as matrix_transpose, + max as max, + maximum as maximum, + mean as mean, + median as median, + mgrid as mgrid, + min as min, + minimum as minimum, + mod as mod, + modf as modf, + moveaxis as moveaxis, + multiply as multiply, + nan as nan, + nan_to_num as nan_to_num, + nanargmax as nanargmax, + nanargmin as nanargmin, + nancumprod as nancumprod, + nancumsum as nancumsum, + nanmax as nanmax, + nanmean as nanmean, + nanmedian as nanmedian, + nanmin as nanmin, + nanpercentile as nanpercentile, + nanprod as nanprod, + nanquantile as nanquantile, + nanstd as nanstd, + nansum as nansum, + nanvar as nanvar, + ndarray as ndarray, + ndim as ndim, + negative as negative, + newaxis as newaxis, + nextafter as nextafter, + nonzero as nonzero, + not_equal as not_equal, + number as number, + object_ as object_, + ogrid as ogrid, + ones as ones, + outer as outer, + packbits as packbits, + pad as pad, + partition as partition, + percentile as percentile, + permute_dims as permute_dims, + pi as pi, + piecewise as piecewise, + place as place, + poly as poly, + polyadd as polyadd, + polyder as polyder, + polydiv as polydiv, + polyfit as polyfit, + polyint as polyint, + polymul as polymul, + polysub as polysub, + polyval as polyval, + positive as positive, + pow as pow, + power as power, + printoptions as printoptions, + prod as prod, + promote_types as promote_types, + ptp as ptp, + put as put, + quantile as quantile, + r_ as r_, + rad2deg as rad2deg, + radians as radians, + ravel as ravel, + ravel_multi_index as ravel_multi_index, + real as real, + reciprocal as reciprocal, + remainder as remainder, + repeat as repeat, + reshape as reshape, + resize as resize, + result_type as result_type, + right_shift as right_shift, + rint as rint, + roll as roll, + rollaxis as rollaxis, + roots as roots, + rot90 as rot90, + round as round, + round_ as round_, + s_ as s_, + save as save, + savez as savez, + searchsorted as searchsorted, + select as select, + set_printoptions as set_printoptions, + setdiff1d as setdiff1d, + setxor1d as setxor1d, + shape as shape, + sign as sign, + signbit as signbit, + signedinteger as signedinteger, + sin as sin, + sinc as sinc, + single as single, + sinh as sinh, + size as size, + sort as sort, + sort_complex as sort_complex, + split as split, + sqrt as sqrt, + square as square, + squeeze as squeeze, + stack as stack, + std as std, + subtract as subtract, + sum as sum, + swapaxes as swapaxes, + take as take, + take_along_axis as take_along_axis, + tan as tan, + tanh as tanh, + tensordot as tensordot, + tile as tile, + trace as trace, + transpose as transpose, + tri as tri, + tril_indices as tril_indices, + tril_indices_from as tril_indices_from, + trim_zeros as trim_zeros, + triu_indices as triu_indices, + triu_indices_from as triu_indices_from, + true_divide as true_divide, + trunc as trunc, + uint as uint, + uint4 as uint4, + uint8 as uint8, + uint16 as uint16, + uint32 as uint32, + uint64 as uint64, + union1d as union1d, + unique as unique, + unique_all as unique_all, + unique_counts as unique_counts, + unique_inverse as unique_inverse, + unique_values as unique_values, + unpackbits as unpackbits, + unravel_index as unravel_index, + unsignedinteger as unsignedinteger, + unwrap as unwrap, + vander as vander, + var as var, + vdot as vdot, + vecdot as vecdot, + vsplit as vsplit, + vstack as vstack, + where as where, + zeros as zeros, +) # _creation_functions.py # isort: split -from jax.numpy import arange as arange -from jax.numpy import asarray as asarray -from jax.numpy import empty_like as empty_like -from jax.numpy import full as full -from jax.numpy import full_like as full_like -from jax.numpy import linspace as linspace -from jax.numpy import meshgrid as meshgrid -from jax.numpy import ones_like as ones_like -from jax.numpy import tril as tril -from jax.numpy import triu as triu -from jax.numpy import zeros_like as zeros_like +from jax.numpy import ( + arange as arange, + asarray as asarray, + empty_like as empty_like, + full as full, + full_like as full_like, + linspace as linspace, + meshgrid as meshgrid, + ones_like as ones_like, + tril as tril, + triu as triu, + zeros_like as zeros_like, +) # _higher_order.py # isort: split diff --git a/src/quaxed/numpy/_higher_order.py b/src/quaxed/numpy/_higher_order.py index 1e7cc9e..763a261 100644 --- a/src/quaxed/numpy/_higher_order.py +++ b/src/quaxed/numpy/_higher_order.py @@ -15,8 +15,7 @@ _parse_input_dimensions, ) -from ._core import asarray, squeeze -from ._core import expand_dims as _expand_dims +from ._core import asarray, expand_dims as _expand_dims, squeeze T = TypeVar("T") diff --git a/src/quaxed/operator.pyi b/src/quaxed/operator.pyi index 9cec5f5..76eb4bc 100644 --- a/src/quaxed/operator.pyi +++ b/src/quaxed/operator.pyi @@ -1,55 +1,57 @@ -from operator import abs as abs -from operator import add as add -from operator import and_ as and_ -from operator import attrgetter as attrgetter -from operator import call as call -from operator import concat as concat -from operator import contains as contains -from operator import countOf as countOf -from operator import delitem as delitem -from operator import eq as eq -from operator import floordiv as floordiv -from operator import ge as ge -from operator import getitem as getitem -from operator import gt as gt -from operator import iadd as iadd -from operator import iand as iand -from operator import iconcat as iconcat -from operator import ifloordiv as ifloordiv -from operator import ilshift as ilshift -from operator import imatmul as imatmul -from operator import imod as imod -from operator import imul as imul -from operator import index as index -from operator import indexOf as indexOf -from operator import inv as inv -from operator import invert as invert -from operator import ior as ior -from operator import ipow as ipow -from operator import irshift as irshift -from operator import is_ as is_ -from operator import is_not as is_not -from operator import isub as isub -from operator import itemgetter as itemgetter -from operator import itruediv as itruediv -from operator import ixor as ixor -from operator import le as le -from operator import length_hint as length_hint -from operator import lshift as lshift -from operator import lt as lt -from operator import matmul as matmul -from operator import methodcaller as methodcaller -from operator import mod as mod -from operator import mul as mul -from operator import ne as ne -from operator import neg as neg -from operator import not_ as not_ -from operator import or_ as or_ -from operator import pos as pos -from operator import pow as pow -from operator import rshift as rshift -from operator import setitem as setitem -from operator import sub as sub -from operator import truediv as truediv -from operator import truth as truth -from operator import xor as xor +from operator import ( + abs as abs, + add as add, + and_ as and_, + attrgetter as attrgetter, + call as call, + concat as concat, + contains as contains, + countOf as countOf, + delitem as delitem, + eq as eq, + floordiv as floordiv, + ge as ge, + getitem as getitem, + gt as gt, + iadd as iadd, + iand as iand, + iconcat as iconcat, + ifloordiv as ifloordiv, + ilshift as ilshift, + imatmul as imatmul, + imod as imod, + imul as imul, + index as index, + indexOf as indexOf, + inv as inv, + invert as invert, + ior as ior, + ipow as ipow, + irshift as irshift, + is_ as is_, + is_not as is_not, + isub as isub, + itemgetter as itemgetter, + itruediv as itruediv, + ixor as ixor, + le as le, + length_hint as length_hint, + lshift as lshift, + lt as lt, + matmul as matmul, + methodcaller as methodcaller, + mod as mod, + mul as mul, + ne as ne, + neg as neg, + not_ as not_, + or_ as or_, + pos as pos, + pow as pow, + rshift as rshift, + setitem as setitem, + sub as sub, + truediv as truediv, + truth as truth, + xor as xor, +) diff --git a/src/quaxed/scipy/special.pyi b/src/quaxed/scipy/special.pyi index c10e3a4..837a512 100644 --- a/src/quaxed/scipy/special.pyi +++ b/src/quaxed/scipy/special.pyi @@ -1,41 +1,43 @@ -from jax.scipy.special import bernoulli as bernoulli -from jax.scipy.special import bessel_jn as bessel_jn -from jax.scipy.special import beta as beta -from jax.scipy.special import betainc as betainc -from jax.scipy.special import betaln as betaln -from jax.scipy.special import digamma as digamma -from jax.scipy.special import entr as entr -from jax.scipy.special import erf as erf -from jax.scipy.special import erfc as erfc -from jax.scipy.special import erfinv as erfinv -from jax.scipy.special import exp1 as exp1 -from jax.scipy.special import expi as expi -from jax.scipy.special import expit as expit -from jax.scipy.special import expn as expn -from jax.scipy.special import factorial as factorial -from jax.scipy.special import gamma as gamma -from jax.scipy.special import gammainc as gammainc -from jax.scipy.special import gammaincc as gammaincc -from jax.scipy.special import gammaln as gammaln -from jax.scipy.special import hyp1f1 as hyp1f1 -from jax.scipy.special import i0 as i0 -from jax.scipy.special import i0e as i0e -from jax.scipy.special import i1 as i1 -from jax.scipy.special import i1e as i1e -from jax.scipy.special import kl_div as kl_div -from jax.scipy.special import log_ndtr as log_ndtr -from jax.scipy.special import logit as logit -from jax.scipy.special import logsumexp as logsumexp -from jax.scipy.special import lpmn as lpmn -from jax.scipy.special import lpmn_values as lpmn_values -from jax.scipy.special import multigammaln as multigammaln -from jax.scipy.special import ndtr as ndtr -from jax.scipy.special import ndtri as ndtri -from jax.scipy.special import poch as poch -from jax.scipy.special import polygamma as polygamma -from jax.scipy.special import rel_entr as rel_entr -from jax.scipy.special import spence as spence -from jax.scipy.special import sph_harm as sph_harm -from jax.scipy.special import xlog1py as xlog1py -from jax.scipy.special import xlogy as xlogy -from jax.scipy.special import zeta as zeta +from jax.scipy.special import ( + bernoulli as bernoulli, + bessel_jn as bessel_jn, + beta as beta, + betainc as betainc, + betaln as betaln, + digamma as digamma, + entr as entr, + erf as erf, + erfc as erfc, + erfinv as erfinv, + exp1 as exp1, + expi as expi, + expit as expit, + expn as expn, + factorial as factorial, + gamma as gamma, + gammainc as gammainc, + gammaincc as gammaincc, + gammaln as gammaln, + hyp1f1 as hyp1f1, + i0 as i0, + i0e as i0e, + i1 as i1, + i1e as i1e, + kl_div as kl_div, + log_ndtr as log_ndtr, + logit as logit, + logsumexp as logsumexp, + lpmn as lpmn, + lpmn_values as lpmn_values, + multigammaln as multigammaln, + ndtr as ndtr, + ndtri as ndtri, + poch as poch, + polygamma as polygamma, + rel_entr as rel_entr, + spence as spence, + sph_harm as sph_harm, + xlog1py as xlog1py, + xlogy as xlogy, + zeta as zeta, +) diff --git a/uv.lock b/uv.lock index 8656737..b689bb9 100644 --- a/uv.lock +++ b/uv.lock @@ -189,7 +189,7 @@ name = "click" version = "8.1.7" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "colorama", marker = "platform_system == 'Windows'" }, + { name = "colorama", marker = "sys_platform == 'win32'" }, ] sdist = { url = "https://files.pythonhosted.org/packages/96/d3/f04c7bfcf5c1862a2a5b845c6b2b360488cf47af55dfa79c98f6a6bf98b5/click-8.1.7.tar.gz", hash = "sha256:ca9853ad459e787e2192211578cc907e7594e294c7ccc834310722b41b9ca6de", size = 336121 } wheels = [ @@ -634,7 +634,7 @@ version = "1.6.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "click" }, - { name = "colorama", marker = "platform_system == 'Windows'" }, + { name = "colorama", marker = "sys_platform == 'win32'" }, { name = "ghp-import" }, { name = "jinja2" }, { name = "markdown" }, @@ -910,6 +910,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/23/cd/066e86230ae37ed0be70aae89aabf03ca8d9f39c8aea0dec8029455b5540/opt_einsum-3.4.0-py3-none-any.whl", hash = "sha256:69bb92469f86a1565195ece4ac0323943e83477171b91d24c35afe028a90d7cd", size = 71932 }, ] +[[package]] +name = "optype" +version = "0.8.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "typing-extensions", marker = "python_full_version < '3.13'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/89/42/543e02c72aba7ebe78adb76bbfbed1bc1314eba633ad453984948e5a5f46/optype-0.8.0.tar.gz", hash = "sha256:8cbfd452d6f06c7c70502048f38a0d5451bc601054d3a577dd09c7d6363950e1", size = 85295 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/86/ff/604be975eb0e9fd02358cdacf496f4411db97bffc27279ce260e8f50aba4/optype-0.8.0-py3-none-any.whl", hash = "sha256:90a7760177f2e7feae379a60445fceec37b932b75a00c3d96067497573c5e84d", size = 74228 }, +] + [[package]] name = "packaging" version = "24.1" @@ -1230,12 +1242,13 @@ wheels = [ [[package]] name = "quaxed" -version = "0.6.9.dev2+gc21cb53" +version = "0.7.2.dev2+g1c9e9d5.d20250120" source = { editable = "." } dependencies = [ { name = "jax" }, { name = "jaxlib" }, { name = "jaxtyping" }, + { name = "optype" }, { name = "plum-dispatch" }, { name = "quax" }, ] @@ -1284,6 +1297,7 @@ requires-dist = [ { name = "jax", specifier = ">=0.4.3" }, { name = "jaxlib", specifier = ">=0.4.3" }, { name = "jaxtyping", specifier = ">=0.2.34" }, + { name = "optype", specifier = ">=0.8.0" }, { name = "plum-dispatch", specifier = ">=2.5.2" }, { name = "quax", specifier = ">=0.0.5" }, ]