Skip to content

Commit

Permalink
Wrap loading torch ops into a separate function. Do not raise error i…
Browse files Browse the repository at this point in the history
…f running from sphinx

Summary: This is necessary if we want to run building docs in github workflow

Differential Revision: D63497765

fbshipit-source-id: 018157c205a66584dd040882124588e499439893
  • Loading branch information
podgorskiy authored and facebook-github-bot committed Sep 27, 2024
1 parent b3060c7 commit dfe3fdf
Show file tree
Hide file tree
Showing 9 changed files with 39 additions and 17 deletions.
5 changes: 2 additions & 3 deletions drtk/edge_grad_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,11 @@

import torch as th
import torch.nn.functional as thf
from drtk import edge_grad_ext
from drtk.interpolate import interpolate
from drtk.utils import index
from drtk.utils import index, load_torch_ops


th.ops.load_library(edge_grad_ext.__file__)
load_torch_ops("drtk.edge_grad_ext")


@th.compiler.disable
Expand Down
4 changes: 2 additions & 2 deletions drtk/grid_scatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@

import torch as th
import torch.nn.functional as thf
from drtk import grid_scatter_ext
from drtk.utils import load_torch_ops

th.ops.load_library(grid_scatter_ext.__file__)
load_torch_ops("drtk.grid_scatter_ext")


@th.compiler.disable
Expand Down
4 changes: 2 additions & 2 deletions drtk/interpolate.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@
"""

import torch as th
from drtk import interpolate_ext
from drtk.utils import load_torch_ops

th.ops.load_library(interpolate_ext.__file__)
load_torch_ops("drtk.interpolate_ext")


@th.compiler.disable
Expand Down
4 changes: 2 additions & 2 deletions drtk/mipmap_grid_sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@

import torch as th
import torch.nn.functional as thf
from drtk import mipmap_grid_sampler_ext
from drtk.utils import load_torch_ops

th.ops.load_library(mipmap_grid_sampler_ext.__file__)
load_torch_ops("drtk.mipmap_grid_sampler_ext")


@th.compiler.disable
Expand Down
4 changes: 2 additions & 2 deletions drtk/msi.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@
# LICENSE file in the root directory of this source tree.

import torch as th
from drtk import msi_ext
from drtk.utils import load_torch_ops

th.ops.load_library(msi_ext.__file__)
load_torch_ops("drtk.msi_ext")


@th.compiler.disable
Expand Down
5 changes: 2 additions & 3 deletions drtk/rasterize.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,9 @@
from typing import Tuple

import torch as th
from drtk.utils import load_torch_ops

from drtk import rasterize_ext

th.ops.load_library(rasterize_ext.__file__)
load_torch_ops("drtk.rasterize_ext")


@th.compiler.disable
Expand Down
5 changes: 2 additions & 3 deletions drtk/render.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,9 @@
from typing import Tuple

import torch as th
from drtk.utils import load_torch_ops

from drtk import render_ext

th.ops.load_library(render_ext.__file__)
load_torch_ops("drtk.render_ext")


@th.compiler.disable
Expand Down
1 change: 1 addition & 0 deletions drtk/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
vert_normals, # noqa
)
from drtk.utils.indexing import index # noqa
from drtk.utils.load_torch_ops import load_torch_ops # noqa
from drtk.utils.projection import ( # noqa
DISTORTION_MODES, # noqa
project_points, # noqa
Expand Down
24 changes: 24 additions & 0 deletions drtk/utils/load_torch_ops.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.


import importlib

import torch as th


def load_torch_ops(extension: str) -> None:
try:
module = importlib.import_module(extension)
th.ops.load_library(module.__file__)
except ImportError as e:
import sys

# If running in sphinx, don't raise an error. That way we can build documentation without
# building extensions
if "sphinx" in sys.modules:
return

raise e

0 comments on commit dfe3fdf

Please sign in to comment.