diff --git a/tensorflow_graphics/projects/rimls/README.md b/tensorflow_graphics/projects/rimls/README.md new file mode 100644 index 000000000..25d65e241 --- /dev/null +++ b/tensorflow_graphics/projects/rimls/README.md @@ -0,0 +1,3 @@ +# Robust Implicit MLS + +[paper](https://hal.inria.fr/inria-00354969/document) diff --git a/tensorflow_graphics/projects/rimls/utils.py b/tensorflow_graphics/projects/rimls/utils.py new file mode 100644 index 000000000..f1c666d89 --- /dev/null +++ b/tensorflow_graphics/projects/rimls/utils.py @@ -0,0 +1,33 @@ +# Copyright 2020 The TensorFlow Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Utility functions.""" +import tensorflow as tf + + +class NearestNeighbors(object): + """API class for nearest neighbors.""" + + def __init__(self, points, k): + tf.debugging.assert_less_equal(k, points.shape[0]) + self.k = k + self.points = points + + def kneighbors(self, queries): + expanded_queries = tf.expand_dims(queries, 1) + expanded_points = tf.expand_dims(self.points, 0) + dist = (expanded_queries - expanded_points) ** 2 + dist = tf.reduce_sum(dist, -1) + dist = tf.sqrt(dist) + sorted_indices = tf.argsort(dist, axis=-1) + return sorted_indices[..., :self.k] diff --git a/tensorflow_graphics/projects/rimls/utils_test.py b/tensorflow_graphics/projects/rimls/utils_test.py new file mode 100644 index 000000000..3aab1a7bc --- /dev/null +++ b/tensorflow_graphics/projects/rimls/utils_test.py @@ -0,0 +1,46 @@ +# Copyright 2020 The TensorFlow Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for utils.""" + +import tensorflow as tf +from tensorflow_graphics.projects.rimls import utils +from google3.testing.pybase import googletest + + +class TestNearestNeighbors(googletest.TestCase): + + def test_correct_neighbors(self): + points = tf.convert_to_tensor([ + [0, 0], + [1, 0], + [1, 1], + ], dtype=tf.float32) + + queries = tf.convert_to_tensor([ + [0, 1], + [1, 0], + [1, 1.5], + ], dtype=tf.float32) + + nn_indices = utils.NearestNeighbors(points, k=2).kneighbors(queries) + # Neighbors need not be ordered. + nn_indices = tf.sort(nn_indices).numpy().tolist() + + self.assertEqual(nn_indices[0], [0, 2]) + self.assertTrue(nn_indices[1] == [0, 1] or nn_indices[1] == [1, 2]) + self.assertEqual(nn_indices[2], [1, 2]) + + +if __name__ == '__main__': + googletest.main()