-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Signed-off-by: Nathaniel Starkman <nstarman@users.noreply.github.com>
- Loading branch information
Showing
8 changed files
with
722 additions
and
671 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,178 +1,202 @@ | ||
# ----- Operators ----- | ||
# isort: split | ||
from jax.lax.linalg import abs as abs | ||
from jax.lax.linalg import acos as acos | ||
from jax.lax.linalg import acosh as acosh | ||
from jax.lax.linalg import add as add | ||
from jax.lax.linalg import approx_max_k as approx_max_k | ||
from jax.lax.linalg import approx_min_k as approx_min_k | ||
from jax.lax.linalg import argmax as argmax | ||
from jax.lax.linalg import argmin as argmin | ||
from jax.lax.linalg import asin as asin | ||
from jax.lax.linalg import asinh as asinh | ||
from jax.lax.linalg import atan as atan | ||
from jax.lax.linalg import atan2 as atan2 | ||
from jax.lax.linalg import atanh as atanh | ||
from jax.lax.linalg import batch_matmul as batch_matmul | ||
from jax.lax.linalg import bessel_i0e as bessel_i0e | ||
from jax.lax.linalg import bessel_i1e as bessel_i1e | ||
from jax.lax.linalg import betainc as betainc | ||
from jax.lax.linalg import bitcast_convert_type as bitcast_convert_type | ||
from jax.lax.linalg import bitwise_and as bitwise_and | ||
from jax.lax.linalg import bitwise_not as bitwise_not | ||
from jax.lax.linalg import bitwise_or as bitwise_or | ||
from jax.lax.linalg import bitwise_xor as bitwise_xor | ||
from jax.lax.linalg import broadcast as broadcast | ||
from jax.lax.linalg import broadcast_in_dim as broadcast_in_dim | ||
from jax.lax.linalg import broadcast_shapes as broadcast_shapes | ||
from jax.lax.linalg import broadcast_to_rank as broadcast_to_rank | ||
from jax.lax.linalg import broadcasted_iota as broadcasted_iota | ||
from jax.lax.linalg import cbrt as cbrt | ||
from jax.lax.linalg import ceil as ceil | ||
from jax.lax.linalg import clamp as clamp | ||
from jax.lax.linalg import clz as clz | ||
from jax.lax.linalg import collapse as collapse | ||
from jax.lax.linalg import complex as complex | ||
from jax.lax.linalg import concatenate as concatenate | ||
from jax.lax.linalg import conj as conj | ||
from jax.lax.linalg import conv as conv | ||
from jax.lax.linalg import conv_dimension_numbers as conv_dimension_numbers | ||
from jax.lax.linalg import conv_general_dilated as conv_general_dilated | ||
from jax.lax.linalg import conv_general_dilated_local as conv_general_dilated_local | ||
from jax.lax.linalg import conv_general_dilated_patches as conv_general_dilated_patches | ||
from jax.lax.linalg import conv_transpose as conv_transpose | ||
from jax.lax.linalg import conv_with_general_padding as conv_with_general_padding | ||
from jax.lax.linalg import convert_element_type as convert_element_type | ||
from jax.lax.linalg import cos as cos | ||
from jax.lax.linalg import cosh as cosh | ||
from jax.lax.linalg import cumlogsumexp as cumlogsumexp | ||
from jax.lax.linalg import cummax as cummax | ||
from jax.lax.linalg import cummin as cummin | ||
from jax.lax.linalg import cumprod as cumprod | ||
from jax.lax.linalg import cumsum as cumsum | ||
from jax.lax.linalg import digamma as digamma | ||
from jax.lax.linalg import div as div | ||
from jax.lax.linalg import dot as dot | ||
from jax.lax.linalg import dot_general as dot_general | ||
from jax.lax.linalg import dynamic_index_in_dim as dynamic_index_in_dim | ||
from jax.lax.linalg import dynamic_slice as dynamic_slice | ||
from jax.lax.linalg import dynamic_slice_in_dim as dynamic_slice_in_dim | ||
from jax.lax.linalg import dynamic_update_index_in_dim as dynamic_update_index_in_dim | ||
from jax.lax.linalg import dynamic_update_slice as dynamic_update_slice | ||
from jax.lax.linalg import dynamic_update_slice_in_dim as dynamic_update_slice_in_dim | ||
from jax.lax.linalg import eq as eq | ||
from jax.lax.linalg import erf as erf | ||
from jax.lax.linalg import erf_inv as erf_inv | ||
from jax.lax.linalg import erfc as erfc | ||
from jax.lax.linalg import exp as exp | ||
from jax.lax.linalg import expand_dims as expand_dims | ||
from jax.lax.linalg import expm1 as expm1 | ||
from jax.lax.linalg import fft as fft | ||
from jax.lax.linalg import floor as floor | ||
from jax.lax.linalg import full as full | ||
from jax.lax.linalg import full_like as full_like | ||
from jax.lax.linalg import gather as gather | ||
from jax.lax.linalg import ge as ge | ||
from jax.lax.linalg import gt as gt | ||
from jax.lax.linalg import igamma as igamma | ||
from jax.lax.linalg import igammac as igammac | ||
from jax.lax.linalg import imag as imag | ||
from jax.lax.linalg import index_in_dim as index_in_dim | ||
from jax.lax.linalg import index_take as index_take | ||
from jax.lax.linalg import integer_pow as integer_pow | ||
from jax.lax.linalg import iota as iota | ||
from jax.lax.linalg import is_finite as is_finite | ||
from jax.lax.linalg import le as le | ||
from jax.lax.linalg import lgamma as lgamma | ||
from jax.lax.linalg import log as log | ||
from jax.lax.linalg import log1p as log1p | ||
from jax.lax.linalg import logistic as logistic | ||
from jax.lax.linalg import lt as lt | ||
from jax.lax.linalg import max as max | ||
from jax.lax.linalg import min as min | ||
from jax.lax.linalg import mul as mul | ||
from jax.lax.linalg import neg as neg | ||
from jax.lax.linalg import nextafter as nextafter | ||
from jax.lax.linalg import pad as pad | ||
from jax.lax.linalg import polygamma as polygamma | ||
from jax.lax.linalg import population_count as population_count | ||
from jax.lax.linalg import pow as pow | ||
from jax.lax.linalg import random_gamma_grad as random_gamma_grad | ||
from jax.lax.linalg import real as real | ||
from jax.lax.linalg import reciprocal as reciprocal | ||
from jax.lax.linalg import reduce as reduce | ||
from jax.lax.linalg import reduce_precision as reduce_precision | ||
from jax.lax.linalg import reduce_window as reduce_window | ||
from jax.lax.linalg import rem as rem | ||
from jax.lax.linalg import reshape as reshape | ||
from jax.lax.linalg import rev as rev | ||
from jax.lax.linalg import rng_bit_generator as rng_bit_generator | ||
from jax.lax.linalg import rng_uniform as rng_uniform | ||
from jax.lax.linalg import round as round | ||
from jax.lax.linalg import rsqrt as rsqrt | ||
from jax.lax.linalg import scatter as scatter | ||
from jax.lax.linalg import scatter_add as scatter_add | ||
from jax.lax.linalg import scatter_apply as scatter_apply | ||
from jax.lax.linalg import scatter_max as scatter_max | ||
from jax.lax.linalg import scatter_min as scatter_min | ||
from jax.lax.linalg import scatter_mul as scatter_mul | ||
from jax.lax.linalg import shift_left as shift_left | ||
from jax.lax.linalg import shift_right_arithmetic as shift_right_arithmetic | ||
from jax.lax.linalg import shift_right_logical as shift_right_logical | ||
from jax.lax.linalg import sign as sign | ||
from jax.lax.linalg import sin as sin | ||
from jax.lax.linalg import sinh as sinh | ||
from jax.lax.linalg import slice as slice | ||
from jax.lax.linalg import slice_in_dim as slice_in_dim | ||
from jax.lax.linalg import sort as sort | ||
from jax.lax.linalg import sort_key_val as sort_key_val | ||
from jax.lax.linalg import sqrt as sqrt | ||
from jax.lax.linalg import square as square | ||
from jax.lax.linalg import squeeze as squeeze | ||
from jax.lax.linalg import sub as sub | ||
from jax.lax.linalg import tan as tan | ||
from jax.lax.linalg import tanh as tanh | ||
from jax.lax.linalg import top_k as top_k | ||
from jax.lax.linalg import transpose as transpose | ||
from jax.lax.linalg import zeros_like_array as zeros_like_array | ||
from jax.lax.linalg import zeta as zeta | ||
from jax.lax import ( | ||
abs as abs, | ||
acos as acos, | ||
acosh as acosh, | ||
add as add, | ||
approx_max_k as approx_max_k, | ||
approx_min_k as approx_min_k, | ||
argmax as argmax, | ||
argmin as argmin, | ||
asin as asin, | ||
asinh as asinh, | ||
atan as atan, | ||
atan2 as atan2, | ||
atanh as atanh, | ||
batch_matmul as batch_matmul, | ||
bessel_i0e as bessel_i0e, | ||
bessel_i1e as bessel_i1e, | ||
betainc as betainc, | ||
bitcast_convert_type as bitcast_convert_type, | ||
bitwise_and as bitwise_and, | ||
bitwise_not as bitwise_not, | ||
bitwise_or as bitwise_or, | ||
bitwise_xor as bitwise_xor, | ||
broadcast as broadcast, | ||
broadcast_in_dim as broadcast_in_dim, | ||
broadcast_shapes as broadcast_shapes, | ||
broadcast_to_rank as broadcast_to_rank, | ||
broadcasted_iota as broadcasted_iota, | ||
cbrt as cbrt, | ||
ceil as ceil, | ||
clamp as clamp, | ||
clz as clz, | ||
collapse as collapse, | ||
complex as complex, | ||
concatenate as concatenate, | ||
conj as conj, | ||
conv as conv, | ||
conv_dimension_numbers as conv_dimension_numbers, | ||
conv_general_dilated as conv_general_dilated, | ||
conv_general_dilated_local as conv_general_dilated_local, | ||
conv_general_dilated_patches as conv_general_dilated_patches, | ||
conv_transpose as conv_transpose, | ||
conv_with_general_padding as conv_with_general_padding, | ||
convert_element_type as convert_element_type, | ||
cos as cos, | ||
cosh as cosh, | ||
cumlogsumexp as cumlogsumexp, | ||
cummax as cummax, | ||
cummin as cummin, | ||
cumprod as cumprod, | ||
cumsum as cumsum, | ||
digamma as digamma, | ||
div as div, | ||
dot as dot, | ||
dot_general as dot_general, | ||
dynamic_index_in_dim as dynamic_index_in_dim, | ||
dynamic_slice as dynamic_slice, | ||
dynamic_slice_in_dim as dynamic_slice_in_dim, | ||
dynamic_update_index_in_dim as dynamic_update_index_in_dim, | ||
dynamic_update_slice as dynamic_update_slice, | ||
dynamic_update_slice_in_dim as dynamic_update_slice_in_dim, | ||
eq as eq, | ||
erf as erf, | ||
erf_inv as erf_inv, | ||
erfc as erfc, | ||
exp as exp, | ||
expand_dims as expand_dims, | ||
expm1 as expm1, | ||
fft as fft, | ||
floor as floor, | ||
full as full, | ||
full_like as full_like, | ||
gather as gather, | ||
ge as ge, | ||
gt as gt, | ||
igamma as igamma, | ||
igammac as igammac, | ||
imag as imag, | ||
index_in_dim as index_in_dim, | ||
index_take as index_take, | ||
integer_pow as integer_pow, | ||
iota as iota, | ||
is_finite as is_finite, | ||
le as le, | ||
lgamma as lgamma, | ||
log as log, | ||
log1p as log1p, | ||
logistic as logistic, | ||
lt as lt, | ||
max as max, | ||
min as min, | ||
mul as mul, | ||
ne as ne, | ||
neg as neg, | ||
nextafter as nextafter, | ||
pad as pad, | ||
polygamma as polygamma, | ||
population_count as population_count, | ||
pow as pow, | ||
random_gamma_grad as random_gamma_grad, | ||
real as real, | ||
reciprocal as reciprocal, | ||
reduce as reduce, | ||
reduce_precision as reduce_precision, | ||
reduce_window as reduce_window, | ||
rem as rem, | ||
reshape as reshape, | ||
rev as rev, | ||
rng_bit_generator as rng_bit_generator, | ||
rng_uniform as rng_uniform, | ||
round as round, | ||
rsqrt as rsqrt, | ||
scatter as scatter, | ||
scatter_add as scatter_add, | ||
scatter_apply as scatter_apply, | ||
scatter_max as scatter_max, | ||
scatter_min as scatter_min, | ||
scatter_mul as scatter_mul, | ||
shift_left as shift_left, | ||
shift_right_arithmetic as shift_right_arithmetic, | ||
shift_right_logical as shift_right_logical, | ||
sign as sign, | ||
sin as sin, | ||
sinh as sinh, | ||
slice as slice, | ||
slice_in_dim as slice_in_dim, | ||
sort as sort, | ||
sort_key_val as sort_key_val, | ||
sqrt as sqrt, | ||
square as square, | ||
squeeze as squeeze, | ||
sub as sub, | ||
tan as tan, | ||
tanh as tanh, | ||
top_k as top_k, | ||
transpose as transpose, | ||
zeros_like_array as zeros_like_array, | ||
zeta as zeta, | ||
) | ||
|
||
# ----- Control Flow Operators ----- | ||
# isort: split | ||
from jax.lax.linalg import associative_scan as associative_scan | ||
from jax.lax.linalg import cond as cond | ||
from jax.lax.linalg import fori_loop as fori_loop | ||
from jax.lax.linalg import map as map | ||
from jax.lax.linalg import scan as scan | ||
from jax.lax.linalg import select as select | ||
from jax.lax.linalg import select_n as select_n | ||
from jax.lax.linalg import switch as switch | ||
from jax.lax.linalg import while_loop as while_loop | ||
from jax.lax import ( | ||
associative_scan as associative_scan, | ||
cond as cond, | ||
fori_loop as fori_loop, | ||
map as map, | ||
scan as scan, | ||
select as select, | ||
select_n as select_n, | ||
switch as switch, | ||
while_loop as while_loop, | ||
) | ||
|
||
# ----- Custom Gradient Operators ----- | ||
# isort: split | ||
from jax.lax.linalg import custom_linear_solve as custom_linear_solve | ||
from jax.lax.linalg import custom_root as custom_root | ||
from jax.lax.linalg import stop_gradient as stop_gradient | ||
from jax.lax import ( | ||
custom_linear_solve as custom_linear_solve, | ||
custom_root as custom_root, | ||
stop_gradient as stop_gradient, | ||
) | ||
|
||
# ----- Parallel Operators ----- | ||
# isort: split | ||
from jax.lax.linalg import all_gather as all_gather | ||
from jax.lax.linalg import all_to_all as all_to_all | ||
from jax.lax.linalg import axis_index as axis_index | ||
from jax.lax.linalg import pmax as pmax | ||
from jax.lax.linalg import pmean as pmean | ||
from jax.lax.linalg import pmin as pmin | ||
from jax.lax.linalg import ppermute as ppermute | ||
from jax.lax.linalg import pshuffle as pshuffle | ||
from jax.lax.linalg import psum as psum | ||
from jax.lax.linalg import psum_scatter as psum_scatter | ||
from jax.lax.linalg import pswapaxes as pswapaxes | ||
from jax.lax import ( | ||
all_gather as all_gather, | ||
all_to_all as all_to_all, | ||
axis_index as axis_index, | ||
pmax as pmax, | ||
pmean as pmean, | ||
pmin as pmin, | ||
ppermute as ppermute, | ||
pshuffle as pshuffle, | ||
psum as psum, | ||
psum_scatter as psum_scatter, | ||
pswapaxes as pswapaxes, | ||
) | ||
|
||
# ----- Sharding-related Operators ----- | ||
# isort: split | ||
from jax.lax.linalg import with_sharding_constraint as with_sharding_constraint | ||
from jax.lax import with_sharding_constraint as with_sharding_constraint | ||
|
||
# ----- Linear Algebra Operators ----- | ||
# isort: split | ||
from jax.lax.linalg import linalg as linalg | ||
from jax.lax import linalg as linalg | ||
|
||
# ----- Argument classes ----- | ||
# isort: split | ||
from jax.lax import ( | ||
ConvDimensionNumbers as ConvDimensionNumbers, | ||
ConvGeneralDilatedDimensionNumbers as ConvGeneralDilatedDimensionNumbers, | ||
DotAlgorithm as DotAlgorithm, | ||
GatherDimensionNumbers as GatherDimensionNumbers, | ||
GatherScatterMode as GatherScatterMode, | ||
Precision as Precision, | ||
PrecisionLike as PrecisionLike, | ||
RandomAlgorithm as RandomAlgorithm, | ||
RoundingMethod as RoundingMethod, | ||
ScatterDimensionNumbers as ScatterDimensionNumbers, | ||
) |
Oops, something went wrong.