diff --git a/tensorflow_recommenders/layers/embedding/tpu_embedding_layer.py b/tensorflow_recommenders/layers/embedding/tpu_embedding_layer.py index 83cad9db..fc190957 100644 --- a/tensorflow_recommenders/layers/embedding/tpu_embedding_layer.py +++ b/tensorflow_recommenders/layers/embedding/tpu_embedding_layer.py @@ -67,6 +67,12 @@ TPUEmbeddingType = ( TPUEmbeddingType | tf.tpu.experimental.embedding.TPUEmbeddingV2 ) +if hasattr(tf.tpu.experimental.embedding, "SparseCoreEmbeddingConfig"): + SparseCoreEmbeddingConfig = ( + tf.tpu.experimental.embedding.SparseCoreEmbeddingConfig + ) +else: + SparseCoreEmbeddingConfig = None # pylint: disable=invalid-name def _normalize_and_prepare_optimizer(optimizer): @@ -597,7 +603,8 @@ def __init__( tf.tpu.experimental.embedding.FTRL]], pipeline_execution_with_tensor_core: bool = False, batch_size: Optional[int] = None, - embedding_feature: Optional[EmbeddingFeature] = None): + embedding_feature: Optional[EmbeddingFeature] = None, + sparse_core_embedding_config: Optional[SparseCoreEmbeddingConfig] = None): """A Keras layer for accelerated embedding lookups on TPU. Args: @@ -617,6 +624,8 @@ def __init__( compatibility. embedding_feature: EmbeddingFeature enum, inidicating which version of TPU hardware the layer should run on. + sparse_core_embedding_config: SparseCoreEmbeddingConfig, inidicating + configuration for sparse core embedding when using TPUEmbedding V2 """ super().__init__() self._feature_config, self._table_config_map = ( @@ -654,6 +663,7 @@ def __init__( self._using_tpu, self._embedding_feature, pipeline_execution_with_tensor_core, + sparse_core_embedding_config ) self.batch_size = batch_size self._tpu_call_id = 0 @@ -663,6 +673,7 @@ def _create_tpu_embedding_mid_level_api( using_tpu: bool, embedding_feature: Optional[EmbeddingFeature], pipeline_execution_with_tensor_core: bool, + sparse_core_embedding_config: Optional[SparseCoreEmbeddingConfig], ) -> TPUEmbeddingType: """Creates TPU Embedding mid level API instance based on settings. @@ -674,6 +685,8 @@ def _create_tpu_embedding_mid_level_api( computations will overlap with the TensorCore computations (and hence will be one step old with potential correctness drawbacks). Only used when the embedding feature is set to be v1. + sparse_core_embedding_config: SparseCoreEmbeddingConfig used by TPU + ` Embedding V2 Returns: Instance of the TPUEmbedding mid level API. @@ -699,6 +712,7 @@ def _create_tpu_embedding_mid_level_api( self._feature_config, self._optimizer, pipeline_execution_with_tensor_core, + sparse_core_embedding_config, ) else: raise ValueError("TPUEmbeddingV2 is not supported in TF.")