diff --git a/README.md b/README.md index a78f996561..44d25088e2 100644 --- a/README.md +++ b/README.md @@ -394,7 +394,7 @@ access to a number of our interpretability algorithms. To analyze a sample model on CIFAR10 via Captum Insights run ``` -python -m captum.insights.example +python -m captum.insights.attr_vis.example ``` and navigate to the URL specified in the output. diff --git a/captum/insights/attr_vis/example.py b/captum/insights/attr_vis/example.py index be20e44c40..cb7c071b7c 100644 --- a/captum/insights/attr_vis/example.py +++ b/captum/insights/attr_vis/example.py @@ -10,9 +10,6 @@ import torchvision.transforms as transforms from captum.insights import AttributionVisualizer, Batch -# pyre-fixme[21]: Could not find module -# `captum.insights.attr_vis.example.get_pretrained_model`. -from captum.insights.attr_vis.example.get_pretrained_model import Net from captum.insights.attr_vis.features import ImageFeature @@ -32,31 +29,32 @@ def get_classes() -> List[str]: return classes -def get_pretrained_model() -> Net: - class Net(nn.Module): - def __init__(self) -> None: - super(Net, self).__init__() - self.conv1 = nn.Conv2d(3, 6, 5) - self.pool1 = nn.MaxPool2d(2, 2) - self.pool2 = nn.MaxPool2d(2, 2) - self.conv2 = nn.Conv2d(6, 16, 5) - self.fc1 = nn.Linear(16 * 5 * 5, 120) - self.fc2 = nn.Linear(120, 84) - self.fc3 = nn.Linear(84, 10) - self.relu1 = nn.ReLU() - self.relu2 = nn.ReLU() - self.relu3 = nn.ReLU() - self.relu4 = nn.ReLU() - - def forward(self, x): - x = self.pool1(self.relu1(self.conv1(x))) - x = self.pool2(self.relu2(self.conv2(x))) - x = x.view(-1, 16 * 5 * 5) - x = self.relu3(self.fc1(x)) - x = self.relu4(self.fc2(x)) - x = self.fc3(x) - return x +class Net(nn.Module): + def __init__(self) -> None: + super(Net, self).__init__() + self.conv1 = nn.Conv2d(3, 6, 5) + self.pool1 = nn.MaxPool2d(2, 2) + self.pool2 = nn.MaxPool2d(2, 2) + self.conv2 = nn.Conv2d(6, 16, 5) + self.fc1 = nn.Linear(16 * 5 * 5, 120) + self.fc2 = nn.Linear(120, 84) + self.fc3 = nn.Linear(84, 10) + self.relu1 = nn.ReLU() + self.relu2 = nn.ReLU() + self.relu3 = nn.ReLU() + self.relu4 = nn.ReLU() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.pool1(self.relu1(self.conv1(x))) + x = self.pool2(self.relu2(self.conv2(x))) + x = x.view(-1, 16 * 5 * 5) + x = self.relu3(self.fc1(x)) + x = self.relu4(self.fc2(x)) + x = self.fc3(x) + return x + +def get_pretrained_model() -> Net: net = Net() pt_path = os.path.abspath( os.path.join(os.path.dirname(__file__), "models/cifar_torchvision.pt") diff --git a/captum/insights/attr_vis/server.py b/captum/insights/attr_vis/server.py index 6d13d81835..5edbd0eb26 100644 --- a/captum/insights/attr_vis/server.py +++ b/captum/insights/attr_vis/server.py @@ -2,7 +2,6 @@ # pyre-strict import logging -import os import socket import threading from time import sleep @@ -108,7 +107,6 @@ def start_server( global port if port is None: - os.environ["WERKZEUG_RUN_MAIN"] = "true" # hides starting message if not debug: log = logging.getLogger("werkzeug") log.disabled = True