fix(layers): Fix compute_dtype handling for complex-valued inputs in Dense layer #20823

Open · wants to merge 4 commits into base: master
2 changes: 2 additions & 0 deletions keras/api/_tf_keras/keras/random/__init__.py
@@ -7,8 +7,10 @@
from keras.src.random.random import beta
from keras.src.random.random import binomial
from keras.src.random.random import categorical
from keras.src.random.random import complex_uniform
from keras.src.random.random import dropout
from keras.src.random.random import gamma
from keras.src.random.random import initializer_for_complex
from keras.src.random.random import normal
from keras.src.random.random import randint
from keras.src.random.random import shuffle
2 changes: 2 additions & 0 deletions keras/api/random/__init__.py
@@ -7,8 +7,10 @@
from keras.src.random.random import beta
from keras.src.random.random import binomial
from keras.src.random.random import categorical
from keras.src.random.random import complex_uniform
from keras.src.random.random import dropout
from keras.src.random.random import gamma
from keras.src.random.random import initializer_for_complex
from keras.src.random.random import normal
from keras.src.random.random import randint
from keras.src.random.random import shuffle
31 changes: 25 additions & 6 deletions keras/src/layers/core/dense.py
@@ -1,11 +1,13 @@
import ml_dtypes

import keras.src.backend as backend
from keras.src import activations
from keras.src import constraints
from keras.src import dtype_policies
from keras.src import initializers
from keras.src import ops
from keras.src import quantizers
from keras.src import random
from keras.src import regularizers
from keras.src.api_export import keras_export
from keras.src.layers.input_spec import InputSpec
@@ -84,6 +86,7 @@ def __init__(
lora_rank=None,
**kwargs,
):
self._compute_dtype = kwargs.pop("compute_dtype", None)
super().__init__(activity_regularizer=activity_regularizer, **kwargs)
self.units = units
self.activation = activations.get(activation)
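
Usage sketch (an assumption, not part of the diff): because compute_dtype is popped before the remaining kwargs reach the base Layer, it can be passed directly at construction time.

    layer = layers.Dense(4, dtype="complex64", compute_dtype="complex64")
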
@@ -104,12 +107,13 @@ def build(self, input_shape):
if self.quantization_mode:
self.quantized_build(input_shape, mode=self.quantization_mode)
if self.quantization_mode != "int8":
# If the layer is quantized to int8, `self._kernel` will be added
# in `self._int8_build`. Therefore, we skip it here.
wrapped_initializer = random.initializer_for_complex(
self.kernel_initializer
)
self._kernel = self.add_weight(
name="kernel",
shape=(input_dim, self.units),
initializer=self.kernel_initializer,
initializer=wrapped_initializer,
regularizer=self.kernel_regularizer,
constraint=self.kernel_constraint,
)
@@ -141,9 +145,24 @@ def kernel(self):
return self._kernel

def call(self, inputs, training=None):
x = ops.matmul(inputs, self.kernel)
if self.bias is not None:
x = ops.add(x, self.bias)
compute_dtype = self._compute_dtype or backend.standardize_dtype(
inputs.dtype
)
input_dtype = backend.standardize_dtype(inputs.dtype)
# If an explicit compute_dtype was provided, promote inputs and weights
# to a common dtype before the matmul.
if self._compute_dtype is not None:
promoted_dtype = backend.result_type(compute_dtype, input_dtype)
inputs = ops.cast(inputs, promoted_dtype)
kernel = ops.cast(self.kernel, promoted_dtype)
x = ops.matmul(inputs, kernel)
if self.bias is not None:
bias = ops.cast(self.bias, promoted_dtype)
x = ops.add(x, bias)
# Otherwise, fall back to the original computation path.
else:
x = ops.matmul(inputs, self.kernel)
if self.bias is not None:
x = ops.add(x, self.bias)
if self.activation is not None:
x = self.activation(x)
return x
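A minimal end-to-end sketch (an assumption, not part of the diff) of how the updated layer is intended to be used once this PR is applied; the names come from the diff above and the expected dtypes mirror the new tests below.

    import numpy as np
    from keras import layers, ops

    # Explicit compute_dtype: inputs and weights are promoted via result_type.
    layer = layers.Dense(4, dtype="complex64", compute_dtype="complex64")
    x = ops.convert_to_tensor(np.array([[1 + 1j, 2 + 2j]], dtype=np.complex64))
    y = layer(x)  # expected: complex64 output of shape (1, 4)

    # No explicit compute_dtype: call() falls back to the original path.
    plain = layers.Dense(4)
    y_float = plain(np.array([[1.0, 2.0]], dtype=np.float32))  # float32 output
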
88 changes: 88 additions & 0 deletions keras/src/layers/core/dense_test.py
@@ -56,6 +56,94 @@ def test_dense_basics(self):
supports_masking=True,
)

# Tests for complex-valued inputs:
def test_dense_complex_input_basic(self):
dense = layers.Dense(units=4)
complex_input = ops.convert_to_tensor(
np.array([[1 + 1j, 2 + 2j], [3 + 3j, 4 + 4j]], dtype=np.complex64)
)
output = dense(complex_input)
self.assertEqual(output.dtype, complex_input.dtype)
self.assertEqual(ops.shape(output), (2, 4))

def test_dense_complex_compute_dtype(self):
dense = layers.Dense(
units=3, dtype="complex64", compute_dtype="complex64"
)
complex_input = ops.convert_to_tensor(
np.array([[1 + 1j, 2 + 2j], [3 + 3j, 4 + 4j]], dtype=np.complex64)
)
output = dense(complex_input)
self.assertEqual(backend.standardize_dtype(output.dtype), "complex64")

def test_dense_complex_weights(self):
dense = layers.Dense(units=2)
complex_input = ops.convert_to_tensor(
np.array([[1 + 1j, 2 + 2j]], dtype=np.complex64)
)
dense.build(ops.shape(complex_input))
self.assertIsNotNone(dense.kernel)
self.assertIsNotNone(dense.bias)
output = dense(complex_input)
self.assertEqual(ops.shape(output), (1, 2))

@pytest.mark.requires_trainable_backend
def test_dense_complex_training(self):
dense = layers.Dense(units=1)
complex_input = ops.convert_to_tensor(
np.array([[1 + 1j]], dtype=np.complex64)
)
target = ops.convert_to_tensor(np.array([[2 + 2j]], dtype=np.complex64))

if backend.backend() == "tensorflow":
import tensorflow as tf

with tf.GradientTape() as tape:
output = dense(complex_input)
loss = ops.mean(ops.abs(output - target))
gradients = tape.gradient(loss, dense.trainable_variables)
elif backend.backend() == "jax":
import jax

dense.build((1,))

def stateless_loss(trainable_vars, x, y_true):
y_pred = dense.stateless_call(
trainable_vars, [], x, training=True
)[0]
return ops.mean(ops.abs(y_pred - y_true))

grad_fn = jax.grad(stateless_loss)
trainable_vars = [v.value for v in dense.trainable_variables]
gradients = grad_fn(trainable_vars, complex_input, target)
elif backend.backend() == "torch":
output = dense(complex_input)
loss = ops.mean(ops.abs(output - target))
loss.backward()
gradients = [v.value.grad for v in dense.trainable_variables]
else:
raise ValueError(f"Unsupported backend: {backend.backend()}")

self.assertIsNotNone(gradients)
self.assertTrue(all(g is not None for g in gradients))

def test_dense_mixed_dtype_computation(self):
dense = layers.Dense(units=2)
real_input = ops.convert_to_tensor(
np.array([[1.0, 2.0]], dtype=np.float32)
)
real_output = dense(real_input)
self.assertEqual(
backend.standardize_dtype(real_output.dtype), "float32"
)
complex_input = ops.convert_to_tensor(
np.array([[1 + 1j, 2 + 2j]], dtype=np.complex64)
)
complex_output = dense(complex_input)
self.assertEqual(
backend.standardize_dtype(complex_output.dtype), "complex64"
)

def test_dense_correctness(self):
# With bias and activation.
layer = layers.Dense(units=2, activation="relu")
2 changes: 2 additions & 0 deletions keras/src/random/__init__.py
@@ -1,6 +1,8 @@
from keras.src.random.random import categorical
from keras.src.random.random import complex_uniform
from keras.src.random.random import dropout
from keras.src.random.random import gamma
from keras.src.random.random import initializer_for_complex
from keras.src.random.random import normal
from keras.src.random.random import randint
from keras.src.random.random import shuffle
99 changes: 99 additions & 0 deletions keras/src/random/random.py
@@ -1,3 +1,5 @@
import numpy as np

from keras.src import backend
from keras.src.api_export import keras_export

@@ -128,6 +130,103 @@ def uniform(shape, minval=0.0, maxval=1.0, dtype=None, seed=None):
)


@keras_export("keras.random.complex_uniform")
def complex_uniform(shape, minval=0.0, maxval=1.0, dtype=None, seed=None):
"""Draw random complex numbers from a uniform distribution.

The generated values follow a uniform distribution for both real and
imaginary components in the range `[minval, maxval)`. The lower bound
`minval` is included in the range, while the upper bound `maxval` is
excluded.

Args:
shape: The shape of the random values to generate.
minval: Float, defaults to 0. Lower bound of the range of
random values to generate (inclusive) for real and imaginary parts.
maxval: Float, defaults to 1. Upper bound of the range of
random values to generate (exclusive) for real and imaginary parts.
dtype: Optional dtype of the tensor. Only complex types are
supported ('complex64' or 'complex128'). If not specified,
'complex64' will be used.
seed: Optional Python integer or instance of
`keras.random.SeedGenerator`.
Used for both real and imaginary parts with different folded seeds.

Returns:
A complex tensor of the specified shape and dtype.

Raises:
ValueError: If dtype is not complex64 or complex128.
"""
if dtype is None:
dtype = "complex64"
if dtype not in ("complex64", "complex128"):
raise ValueError(f"dtype must be complex64 or complex128, got {dtype}")

float_dtype = "float32" if dtype == "complex64" else "float64"

if seed is not None:
if isinstance(seed, int):
seed = backend.random.SeedGenerator(seed)
seed_real = seed.next()
seed_imag = seed.next()
else:
seed_real = None
seed_imag = None

real_part = backend.random.uniform(
shape=shape,
minval=minval,
maxval=maxval,
dtype=float_dtype,
seed=seed_real,
)
imag_part = backend.random.uniform(
shape=shape,
minval=minval,
maxval=maxval,
dtype=float_dtype,
seed=seed_imag,
)

return backend.cast(real_part, dtype) + backend.cast(imag_part, dtype) * 1j
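
# Usage sketch (an assumption, not part of the diff): drawing complex values
# with a SeedGenerator so the real and imaginary parts use distinct derived
# seeds, as described in the docstring above.
#
#     from keras import random
#
#     gen = random.SeedGenerator(42)
#     z = random.complex_uniform(
#         (2, 3), minval=-1.0, maxval=1.0, dtype="complex64", seed=gen
#     )
#     # Real and imaginary parts each lie in [-1.0, 1.0); z has dtype complex64.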


@keras_export("keras.random.initializer_for_complex")
def initializer_for_complex(initializer):
"""Wraps an initializer to handle complex dtypes.

When the requested dtype is complex, the returned initializer generates
complex values using complex_uniform with a scale derived from the fan-in.
Otherwise, it delegates to the original initializer.

Args:
initializer: A Keras initializer function to be wrapped.

Returns:
A wrapped initializer that handles complex dtypes.
"""

def wrapped_initializer(shape, dtype=None, **kwargs):
dtype = backend.standardize_dtype(dtype)
if dtype is not None and dtype in ("complex64", "complex128"):
fan_in = float(shape[-1])
limit = backend.cast(1.0 / np.sqrt(fan_in), dtype="float32")
return complex_uniform(
shape, minval=-limit, maxval=limit, dtype=dtype, **kwargs
)
return initializer(shape, dtype=dtype, **kwargs)

if hasattr(initializer, "__name__"):
wrapped_initializer.__name__ = f"complex_{initializer.__name__}"
if hasattr(initializer, "__doc__"):
wrapped_initializer.__doc__ = (
f"Complex number variant of:\n\n{initializer.__doc__}"
)

return wrapped_initializer
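
# Usage sketch (an assumption, not part of the diff): wrapping an existing
# initializer so that complex variable dtypes are handled while float dtypes
# still delegate to the original initializer unchanged.
#
#     from keras import initializers, random
#
#     base = initializers.GlorotUniform()
#     wrapped = random.initializer_for_complex(base)
#     w_real = wrapped((8, 4), dtype="float32")    # delegates to GlorotUniform
#     w_cplx = wrapped((8, 4), dtype="complex64")  # complex_uniform, limit 1/sqrt(4)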


@keras_export("keras.random.randint")
def randint(shape, minval, maxval, dtype="int32", seed=None):
"""Draw random integers from a uniform distribution.