Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: change jax backend .astype method calls to jnp.astype #28796

Merged
merged 1 commit into from
Aug 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions ivy/functional/backends/jax/activations.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,18 +68,18 @@ def softplus(
res = (
jnp.add(
jnp.log1p(jnp.exp(-jnp.abs(x_beta))),
jnp.maximum(x_beta, 0).astype(x.dtype),
jnp.astype(jnp.maximum(x_beta, 0), x.dtype),
)
) / beta
else:
x_beta = x
res = jnp.add(
jnp.log1p(jnp.exp(-jnp.abs(x_beta))),
jnp.maximum(x_beta, 0).astype(x.dtype),
jnp.astype(jnp.maximum(x_beta, 0), x.dtype),
)
if threshold is not None:
return jnp.where(x_beta > threshold, x, res).astype(x.dtype)
return res.astype(x.dtype)
return jnp.astype(jnp.where(x_beta > threshold, x, res), x.dtype)
return jnp.astype(res, x.dtype)


# Softsign
Expand Down
6 changes: 3 additions & 3 deletions ivy/functional/backends/jax/creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,9 @@ def arange(
res = jnp.arange(start, stop, step, dtype=dtype)
if not dtype:
if res.dtype == jnp.float64:
return res.astype(jnp.float32)
return jnp.astype(res, jnp.float32)
elif res.dtype == jnp.int64:
return res.astype(jnp.int32)
return jnp.astype(res, jnp.int32)
return res


Expand Down Expand Up @@ -202,7 +202,7 @@ def linspace(
# but can lead to rounding errors for integer outputs.
real_dtype = jnp.finfo(computation_dtype).dtype
step = jnp.reshape(jax.lax.iota(real_dtype, div), iota_shape) / div
step = step.astype(computation_dtype)
step = jnp.astype(step, computation_dtype)
start_reshaped = jnp.reshape(broadcast_start, bounds_shape)
end_reshaped = jnp.reshape(broadcast_stop, bounds_shape)
out = start_reshaped + step * (end_reshaped - start_reshaped)
Expand Down
2 changes: 1 addition & 1 deletion ivy/functional/backends/jax/data_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ def astype(
ivy.utils.assertions._check_jax_x64_flag(dtype)
if x.dtype == dtype:
return jnp.copy(x) if copy else x
return x.astype(dtype)
return jnp.astype(x, dtype)


def broadcast_arrays(*arrays: JaxArray) -> List[JaxArray]:
Expand Down
16 changes: 7 additions & 9 deletions ivy/functional/backends/jax/elementwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ def floor_divide(
out: Optional[JaxArray] = None,
) -> JaxArray:
x1, x2 = ivy.promote_types_of_inputs(x1, x2)
return jnp.floor(jnp.divide(x1, x2)).astype(x1.dtype)
return jnp.astype(jnp.floor(jnp.divide(x1, x2)), x1.dtype)


def fmin(
Expand Down Expand Up @@ -331,8 +331,8 @@ def logaddexp2(
) -> JaxArray:
x1, x2 = promote_types_of_inputs(x1, x2)
if not is_float_dtype(x1):
x1 = x1.astype(default_float_dtype(as_native=True))
x2 = x2.astype(default_float_dtype(as_native=True))
x1 = jnp.astype(x1, default_float_dtype(as_native=True))
x2 = jnp.astype(x2, default_float_dtype(as_native=True))
return jnp.logaddexp2(x1, x2)


Expand Down Expand Up @@ -424,11 +424,9 @@ def pow(
else:
fill_value = jnp.finfo(x1.dtype).min
ret = jnp.float_power(x1, x2)
return jnp.where(jnp.bitwise_and(x1 == 0, x2 < 0), fill_value, ret).astype(
x1.dtype
)
return jnp.astype(jnp.where(jnp.bitwise_and(x1 == 0, x2 < 0), fill_value, ret), x1.dtype)
if ivy.is_int_dtype(x1) and ivy.any(x2 < 0):
return jnp.float_power(x1, x2).astype(x1.dtype)
return jnp.astype(jnp.float_power(x1, x2), x1.dtype)
return jnp.power(x1, x2)


Expand All @@ -447,7 +445,7 @@ def remainder(
res_floored = jnp.where(res >= 0, jnp.floor(res), jnp.ceil(res))
diff = res - res_floored
diff, x2 = ivy.promote_types_of_inputs(diff, x2)
return jnp.round(diff * x2).astype(x1.dtype)
return jnp.astype(jnp.round(diff * x2), x1.dtype)
return jnp.remainder(x1, x2)


Expand All @@ -472,7 +470,7 @@ def sign(
) -> JaxArray:
if "complex" in str(x.dtype):
return jnp.sign(x) if np_variant else _abs_variant_sign(x)
return jnp.where(x == -0.0, 0.0, jnp.sign(x)).astype(x.dtype)
return jnp.astype(jnp.where(x == -0.0, 0.0, jnp.sign(x)), x.dtype)


def sin(x: JaxArray, /, *, out: Optional[JaxArray] = None) -> JaxArray:
Expand Down
26 changes: 13 additions & 13 deletions ivy/functional/backends/jax/experimental/activations.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def custom_grad_func(x_and_grad, one):

new_func = ivy.bind_custom_gradient_function(relu6_func, custom_grad_func)

return new_func(x).astype(x.dtype)
return jnp.astype(new_func(x), x.dtype)


def thresholded_relu(
Expand All @@ -50,7 +50,7 @@ def thresholded_relu(
threshold: Union[int, float] = 0,
out: Optional[JaxArray] = None,
) -> JaxArray:
return jnp.where(x > threshold, x, 0).astype(x.dtype)
return jnp.astype(jnp.where(x > threshold, x, 0), x.dtype)


def logsigmoid(
Expand All @@ -60,16 +60,16 @@ def logsigmoid(


def selu(x: JaxArray, /, *, out: Optional[JaxArray] = None) -> JaxArray:
ret = jax.nn.selu(x).astype(x.dtype)
ret = jnp.astype(jax.nn.selu(x), x.dtype)
if ivy.exists(out):
return ivy.inplace_update(out, ret).astype(x.dtype)
return jnp.astype(ivy.inplace_update(out, ret), x.dtype)
return ret


def silu(x: JaxArray, /, *, out: Optional[JaxArray] = None) -> JaxArray:
ret = jax.nn.silu(x)
if ivy.exists(out):
return ivy.inplace_update(out, ret).astype(x.dtype)
return jnp.astype(ivy.inplace_update(out, ret), x.dtype)
return ret


Expand All @@ -79,7 +79,7 @@ def elu(
) -> JaxArray:
ret = jax.nn.elu(x, alpha)
if ivy.exists(out):
return ivy.inplace_update(out, ret).astype(x.dtype)
return jnp.astype(ivy.inplace_update(out, ret), x.dtype)
return ret


Expand All @@ -105,14 +105,14 @@ def hardtanh(
) -> JaxArray:
ret = jnp.where(x > max_val, max_val, jnp.where(x < min_val, min_val, x))
if ivy.exists(out):
return ivy.inplace_update(out, ret).astype(x.dtype)
return ivy.astype(ivy.inplace_update(out, ret), x.dtype)
return ivy.astype(ret, x.dtype)


def tanhshrink(x: JaxArray, /, *, out: Optional[JaxArray] = None) -> JaxArray:
ret = jnp.subtract(x, jax.nn.tanh(x))
if ivy.exists(out):
return ivy.inplace_update(out, ret).astype(x.dtype)
return jnp.astype(ivy.inplace_update(out, ret), x.dtype)
return ret


Expand All @@ -124,9 +124,9 @@ def threshold(
value: Union[int, float],
out: Optional[JaxArray] = None,
) -> JaxArray:
ret = jnp.where(x > threshold, x, value).astype(x.dtype)
ret = jnp.astype(jnp.where(x > threshold, x, value), x.dtype)
if ivy.exists(out):
return ivy.inplace_update(out, ret).astype(x.dtype) # type: ignore
return jnp.astype(ivy.inplace_update(out, ret), x.dtype) # type: ignore
return ret


Expand All @@ -136,7 +136,7 @@ def softshrink(
) -> JaxArray:
ret = jnp.where(x > lambd, x - lambd, jnp.where(x < -lambd, x + lambd, 0))
if ivy.exists(out):
return ivy.inplace_update(out, ret).astype(x.dtype)
return jnp.astype(ivy.inplace_update(out, ret), x.dtype)
return ret


Expand All @@ -158,13 +158,13 @@ def hardshrink(
) -> JaxArray:
ret = jnp.where(x > lambd, x, jnp.where(x < -lambd, x, 0))
if ivy.exists(out):
return ivy.inplace_update(out, ret).astype(x.dtype)
return jnp.astype(ivy.inplace_update(out, ret), x.dtype)
return ret


@with_unsupported_dtypes({"0.4.16 and below": ("float16", "bfloat16")}, backend_version)
def hardsilu(x: JaxArray, /, *, out: Optional[JaxArray] = None) -> JaxArray:
ret = jax.nn.hard_silu(x)
if ivy.exists(out):
return ivy.inplace_update(out, ret).astype(x.dtype)
return jnp.astype(ivy.inplace_update(out, ret), x.dtype)
return ret
13 changes: 7 additions & 6 deletions ivy/functional/backends/jax/experimental/creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def hann_window(
count = jnp.arange(size) / size
else:
count = jnp.linspace(start=0, stop=size, num=size)
return (0.5 - 0.5 * jnp.cos(2 * jnp.pi * count)).astype(dtype)
return jnp.astype((0.5 - 0.5 * jnp.cos(2 * jnp.pi * count)), dtype)


def kaiser_window(
Expand All @@ -61,9 +61,9 @@ def kaiser_window(
if window_length < 2:
return jnp.ones([window_length], dtype=dtype)
if periodic is False:
return jnp.kaiser(M=window_length, beta=beta).astype(dtype)
return jnp.astype(jnp.kaiser(M=window_length, beta=beta), dtype)
else:
return jnp.kaiser(M=window_length + 1, beta=beta)[:-1].astype(dtype)
return jnp.astype(jnp.kaiser(M=window_length + 1, beta=beta)[:-1], dtype)


def tril_indices(
Expand Down Expand Up @@ -118,10 +118,11 @@ def blackman_window(
count = jnp.arange(size) / size
else:
count = jnp.linspace(start=0, stop=size, num=size)
return (
return jnp.astype(
(0.42 - 0.5 * jnp.cos(2 * jnp.pi * count))
+ (0.08 * jnp.cos(2 * jnp.pi * 2 * count))
).astype(dtype)
+ (0.08 * jnp.cos(2 * jnp.pi * 2 * count)),
dtype
)


def trilu(
Expand Down
10 changes: 5 additions & 5 deletions ivy/functional/backends/jax/experimental/elementwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def float_power(
out_dtype = jnp.complex128
else:
out_dtype = jnp.float64
return jnp.float_power(x1, x2).astype(out_dtype)
return jnp.astype(jnp.float_power(x1, x2), out_dtype)


def copysign(
Expand All @@ -91,8 +91,8 @@ def copysign(
) -> JaxArray:
x1, x2 = promote_types_of_inputs(x1, x2)
if not is_float_dtype(x1):
x1 = x1.astype(default_float_dtype(as_native=True))
x2 = x2.astype(default_float_dtype(as_native=True))
x1 = jnp.astype(x1, default_float_dtype(as_native=True))
x2 = jnp.astype(x2, default_float_dtype(as_native=True))
return jnp.copysign(x1, x2)


Expand Down Expand Up @@ -307,7 +307,7 @@ def gradient(
if jnp.issubdtype(distances.dtype, jnp.integer):
# Convert numpy integer types to float64 to avoid modular
# arithmetic in np.diff(distances).
distances = distances.astype(jnp.float64)
distances = jnp.astype(distances, jnp.float64)
diffx = jnp.diff(distances)
# if distances are constant reduce to the scalar case
# since it brings a consistent speedup
Expand All @@ -333,7 +333,7 @@ def gradient(

otype = f.dtype
if jnp.issubdtype(otype, jnp.integer):
f = f.astype(jnp.float64)
f = jnp.astype(f, jnp.float64)

for axis, ax_dx in zip(axes, dx):
if f.shape[axis] < edge_order + 1:
Expand Down
18 changes: 9 additions & 9 deletions ivy/functional/backends/jax/experimental/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,8 +149,8 @@ def general_pool(
if not ivy.is_array(init):
init = jnp.array(init, dtype=inputs.dtype)
promoted_type = jnp.promote_types(inputs.dtype, init.dtype)
inputs = inputs.astype(promoted_type)
init = init.astype(promoted_type)
inputs = jnp.astype(inputs, promoted_type)
init = jnp.astype(init, promoted_type)
y = jlax.reduce_window(
inputs, init, reduce_fn, dims, strides, pad_list, window_dilation=dilation
)
Expand Down Expand Up @@ -238,7 +238,7 @@ def max_pool2d(
if data_format == "NCHW":
res = jnp.transpose(res, (0, 3, 1, 2))

return res.astype(odtype)
return jnp.astype(res, odtype)


def max_pool3d(
Expand Down Expand Up @@ -331,7 +331,7 @@ def avg_pool1d(
if data_format in ("NCW", "NCL"):
res = jnp.transpose(res, (0, 2, 1))
if x.dtype == "float16":
res = res.astype("float16")
res = jnp.astype(res, "float16")

return res

Expand Down Expand Up @@ -765,7 +765,7 @@ def reduce_window(
padding = _to_nested_tuple(padding)
return jlax.reduce_window(
operand,
jnp.array(init_value).astype(operand.dtype),
jnp.astype(jnp.array(init_value), operand.dtype),
computation,
window_dimensions,
window_strides,
Expand Down Expand Up @@ -807,7 +807,7 @@ def fft2(
raise ivy.utils.exceptions.IvyError(
f"Invalid data points {s}, expecting s points larger than 1"
)
return jnp.fft.fft2(x, s, dim, norm).astype(jnp.complex128)
return jnp.astype(jnp.fft.fft2(x, s, dim, norm), jnp.complex128)


def ifftn(
Expand Down Expand Up @@ -859,12 +859,12 @@ def rfft(
) -> JaxArray:
x = x.real
if x.dtype == jnp.float16:
x = x.astype(jnp.float32)
x = jnp.astype(x, jnp.float32)

ret = jnp.fft.rfft(x, n=n, axis=axis, norm=norm)

if x.dtype != jnp.float64:
ret = ret.astype(jnp.complex64)
ret = jnp.astype(ret, jnp.complex64)
if ivy.exists(out):
return ivy.inplace_update(out, ret)
return ret
Expand Down Expand Up @@ -900,7 +900,7 @@ def rfftn(
)
if norm not in {"backward", "ortho", "forward"}:
raise ivy.utils.exceptions.IvyError(f"Unrecognized normalization mode {norm}")
return jnp.fft.rfftn(x, s, axes, norm).astype(jnp.complex128)
return jnp.astype(jnp.fft.rfftn(x, s, axes, norm), jnp.complex128)


# stft
Expand Down
2 changes: 1 addition & 1 deletion ivy/functional/backends/jax/experimental/linear_algebra.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def diagflat(
mode="constant",
)

ret = output_array.astype(x.dtype)
ret = jnp.astype(output_array, x.dtype)
if ivy.exists(out):
ivy.inplace_update(out, ret)

Expand Down
8 changes: 4 additions & 4 deletions ivy/functional/backends/jax/experimental/manipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ def pad(
mode=mode,
)
if jnp.issubdtype(input_dtype, jnp.integer) and mode in ["mean", "median"]:
ret = jnp.round(ret).astype(input_dtype)
ret = jnp.astype(jnp.round(ret), input_dtype)
return ret


Expand Down Expand Up @@ -389,7 +389,7 @@ def unique_consecutive(
if x_shape:
inverse_indices = jnp.reshape(inverse_indices, x_shape)
return Results(
output.astype(x.dtype),
jnp.astype(output, x.dtype),
inverse_indices,
counts,
)
Expand All @@ -411,7 +411,7 @@ def fill_diagonal(
else:
step = 1 + (jnp.cumprod(shape[:-1])).sum()
a = jnp.reshape(a, (-1,))
a = a.at[:end:step].set(jnp.array(v).astype(a.dtype))
a = a.at[:end:step].set(jnp.astype(jnp.array(v), a.dtype))
a = jnp.reshape(a, shape)
return a

Expand All @@ -435,7 +435,7 @@ def take(
if not isinstance(indices, JaxArray):
indices = jnp.array(indices)
if jnp.issubdtype(indices.dtype, jnp.floating):
indices = indices.astype(jnp.int64)
indices = jnp.astype(indices, jnp.int64)

# raise
if mode == "raise":
Expand Down
Loading
Loading