Skip to content

Commit

Permalink
Patch for newer versions of TensorFlow: `smart_cond` changed location.
Browse files Browse the repository at this point in the history
Added docs on how to install and use.
  • Loading branch information
Stephen Tridgell committed Jan 17, 2022
1 parent b5f5db0 commit 8ea837e
Show file tree
Hide file tree
Showing 5 changed files with 27 additions and 14 deletions.
18 changes: 16 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,28 @@ Working in the docker image is annoying and this code should be standalone.

## Install

```
Build from source:

```bash
python3 setup.py bdist_wheel
pip install dist/vitis_quantizer-0.1.0-py3-none-any.whl
```

## Usage

TODO
Usage is the same as with Vitis AI models.

```python
import tensorflow as tf
import vitis_quantizer

# Train/Get/Make a keras model somehow
model = tf.keras.models.load_model("/path/to/keras/model")
quantizer = vitis_quantizer.VitisQuantizer(model)
with vitis_quantizer.quantize_scope():
quantized_model = quantizer.quantize_model(calib_dataset=dataset)
quantized_model.save("/path/to/save/quantized/model")
```

After saving the quantized model, compile it with the Vitis `compile.sh` script.

3 changes: 2 additions & 1 deletion vitis_quantizer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,5 @@
# limitations under the License.
# ==============================================================================

from vitis_quantizer.vitis_quantize import VitisQuantizer
from vitis_quantizer.vitis_quantize import VitisQuantizer
from vitis_quantizer.vitis_quantize import quantize_scope
6 changes: 3 additions & 3 deletions vitis_quantizer/common/vitis_quantize_aware_activation.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import tensorflow as tf

# TODO(b/139939526): move to public API.
from tensorflow.python.keras.utils import tf_utils
from tensorflow.python.framework.smart_cond import smart_cond
from tensorflow.python.keras.utils.generic_utils import register_keras_serializable
from vitis_quantizer.utils import common_utils

Expand Down Expand Up @@ -210,7 +210,7 @@ def quantizer_fn():

x = inputs
if self._should_pre_quantize():
x = tf_utils.smart_cond(
x = smart_cond(
self._training,
make_quantizer_fn(
self.quantizer, x, True, self.mode, self._pre_activation_vars
Expand All @@ -223,7 +223,7 @@ def quantizer_fn():
x = self.activation(x, *args, **kwargs)

if self._should_post_quantize():
x = tf_utils.smart_cond(
x = smart_cond(
self._training,
make_quantizer_fn(
self.quantizer, x, True, self.mode, self._post_activation_vars
Expand Down
8 changes: 4 additions & 4 deletions vitis_quantizer/common/vitis_quantize_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
import tensorflow as tf

# TODO(b/139939526): move to public API.
from tensorflow.python.keras.utils import tf_utils
from tensorflow.python.framework.smart_cond import smart_cond
from tensorflow.python.util import tf_inspect
from tensorflow.python.keras.utils.generic_utils import register_keras_serializable
from vitis_quantizer.common import vitis_quantize_aware_activation
Expand Down Expand Up @@ -222,7 +222,7 @@ def call(self, inputs, training=None):
# Quantize all weights, and replace them in the underlying layer.
quantized_weights = []
for unquantized_weight, quantizer, quantizer_vars in self._weight_vars:
quantized_weight = tf_utils.smart_cond(
quantized_weight = smart_cond(
training,
self._make_quantizer_fn(
quantizer, unquantized_weight, True, self.mode, quantizer_vars
Expand Down Expand Up @@ -262,7 +262,7 @@ def call(self, inputs, training=None):
output_quantizer,
output_quantizer_vars,
) in self._output_quantizer_vars:
quantized_outputs[output_id] = tf_utils.smart_cond(
quantized_outputs[output_id] = smart_cond(
training,
self._make_quantizer_fn(
output_quantizer,
Expand All @@ -287,7 +287,7 @@ def call(self, inputs, training=None):
output_quantizer,
output_quantizer_vars,
) = self._output_quantizer_vars[0]
return tf_utils.smart_cond(
return smart_cond(
training,
self._make_quantizer_fn(
output_quantizer, outputs, True, self.mode, output_quantizer_vars
Expand Down
6 changes: 2 additions & 4 deletions vitis_quantizer/layers/vitis_quantize.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

import tensorflow as tf

from tensorflow.python.keras.utils import tf_utils
from tensorflow.python.framework.smart_cond import smart_cond
from tensorflow.python.keras.utils.generic_utils import register_keras_serializable
from vitis_quantizer.base import quantizer as quantizer_mod
from vitis_quantizer.utils import common_utils
Expand Down Expand Up @@ -78,9 +78,7 @@ def quantizer_fn():

return quantizer_fn

return tf_utils.smart_cond(
training, _make_quantizer_fn(True), _make_quantizer_fn(False)
)
return smart_cond(training, _make_quantizer_fn(True), _make_quantizer_fn(False))

def get_config(self):
base_config = super(VitisQuantize, self).get_config()
Expand Down

0 comments on commit 8ea837e

Please sign in to comment.