Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Additional API updates #103

Open
wants to merge 3 commits into
base: redesign-staging
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 1 addition & 4 deletions uvdat/core/admin.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,10 +88,7 @@ def get_dataset_name(self, obj):


class NetworkAdmin(admin.ModelAdmin):
list_display = ['id', 'category', 'get_dataset_name']

def get_dataset_name(self, obj):
return obj.dataset.name
list_display = ['id', 'category']


class NetworkEdgeAdmin(admin.ModelAdmin):
Expand Down
2 changes: 2 additions & 0 deletions uvdat/core/rest/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from .data import RasterDataViewSet, VectorDataViewSet
from .dataset import DatasetViewSet
from .layer import LayerFrameViewSet, LayerViewSet
from .networks import NetworkViewSet
from .project import ProjectViewSet
from .regions import RegionViewSet
from .simulations import SimulationViewSet
Expand All @@ -15,6 +16,7 @@
RasterDataViewSet,
VectorDataViewSet,
DatasetViewSet,
NetworkViewSet,
RegionViewSet,
SimulationViewSet,
UserViewSet,
Expand Down
61 changes: 7 additions & 54 deletions uvdat/core/rest/dataset.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,16 @@
from drf_yasg.utils import swagger_auto_schema
from rest_framework import serializers
from rest_framework.decorators import action
from rest_framework.response import Response
from rest_framework.viewsets import ReadOnlyModelViewSet

from uvdat.core.models import Dataset, Network, NetworkEdge, NetworkNode
from uvdat.core.models import Dataset
from uvdat.core.rest.access_control import GuardianFilter, GuardianPermission
from uvdat.core.rest.serializers import (
DatasetSerializer,
LayerSerializer,
NetworkEdgeSerializer,
NetworkNodeSerializer,
NetworkSerializer,
RasterDataSerializer,
VectorDataSerializer,
)
from uvdat.core.tasks.chart import add_gcc_chart_datum


class GCCQueryParamSerializer(serializers.Serializer):
project = serializers.IntegerField()
exclude_nodes = serializers.RegexField(r'^\d+(,\s?\d+)*$')


class DatasetViewSet(ReadOnlyModelViewSet):
Expand Down Expand Up @@ -56,47 +47,9 @@ def data(self, request, **kwargs):
return Response(data, status=200)

@action(detail=True, methods=['get'])
def network(self, request, **kwargs):
dataset = self.get_object()
networks = []
for network in dataset.get_networks().all():
networks.append(
{
'nodes': [
NetworkNodeSerializer(n).data
for n in NetworkNode.objects.filter(network=network)
],
'edges': [
NetworkEdgeSerializer(e).data
for e in NetworkEdge.objects.filter(network=network)
],
}
)
return Response(networks, status=200)

@swagger_auto_schema(query_serializer=GCCQueryParamSerializer)
@action(detail=True, methods=['get'])
def gcc(self, request, **kwargs):
def networks(self, request, **kwargs):
dataset = self.get_object()

# Validate and de-serialize query params
serializer = GCCQueryParamSerializer(data=request.query_params)
serializer.is_valid(raise_exception=True)
project_id = serializer.validated_data['project']
exclude_nodes = [int(n) for n in serializer.validated_data['exclude_nodes'].split(',')]

if not dataset.get_networks().exists():
return Response(data='No networks exist in selected dataset', status=400)

# Find the GCC for each network in the dataset
network_gccs: list[list[int]] = []
for network in dataset.get_networks().all():
network: Network
network_gccs.append(network.get_gcc(excluded_nodes=exclude_nodes))

# TODO: improve this for datasets with multiple networks.
# This currently returns the gcc for the network with the most excluded nodes
gcc = max(network_gccs, key=len)

add_gcc_chart_datum(dataset, project_id, exclude_nodes, len(gcc))
return Response(gcc, status=200)
return Response(
[NetworkSerializer(network).data for network in dataset.get_networks().all()],
status=200,
)
34 changes: 34 additions & 0 deletions uvdat/core/rest/networks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
from drf_yasg.utils import swagger_auto_schema
from rest_framework import serializers
from rest_framework.decorators import action
from rest_framework.response import Response
from rest_framework.viewsets import ModelViewSet

from uvdat.core.models import Network
from uvdat.core.rest.access_control import GuardianFilter, GuardianPermission
from uvdat.core.rest.serializers import NetworkSerializer


class GCCQueryParamSerializer(serializers.Serializer):
exclude_nodes = serializers.RegexField(r'^\d+(,\s?\d+)*$')


class NetworkViewSet(ModelViewSet):
queryset = Network.objects.all()
serializer_class = NetworkSerializer
permission_classes = [GuardianPermission]
filter_backends = [GuardianFilter]
lookup_field = 'id'

@swagger_auto_schema(query_serializer=GCCQueryParamSerializer)
@action(detail=True, methods=['get'])
def gcc(self, request, **kwargs):
network = self.get_object()

# Validate and de-serialize query params
serializer = GCCQueryParamSerializer(data=request.query_params)
serializer.is_valid(raise_exception=True)
exclude_nodes = [int(n) for n in serializer.validated_data['exclude_nodes'].split(',')]

gcc = network.get_gcc(excluded_nodes=exclude_nodes)
return Response(gcc, status=200)
Comment on lines +33 to +34
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's probably best to use a serializer here, but I understand this is just relocated code. Not a huge deal imo.

15 changes: 9 additions & 6 deletions uvdat/core/rest/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,12 +142,6 @@ def get_dump_object(self, obj):
return val


class NetworkSerializer(serializers.ModelSerializer):
class Meta:
model = Network
fields = '__all__'


class NetworkNodeSerializer(serializers.ModelSerializer):
class Meta:
model = NetworkNode
Expand All @@ -160,6 +154,15 @@ class Meta:
fields = '__all__'


class NetworkSerializer(serializers.ModelSerializer):
nodes = NetworkNodeSerializer(many=True, read_only=True)
edges = NetworkEdgeSerializer(many=True, read_only=True)

class Meta:
model = Network
fields = '__all__'


class SimulationResultSerializer(serializers.ModelSerializer):
name = serializers.SerializerMethodField('get_name')

Expand Down
3 changes: 3 additions & 0 deletions uvdat/core/rest/tokenauth.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@

class TokenAuth(OAuth2Authentication):
def authenticate(self, request):
auth = super().authenticate(request)
if auth is not None:
return auth
token = request.query_params.get('token')
if token is None:
token_string = request.headers.get('Authorization')
Expand Down
1 change: 1 addition & 0 deletions uvdat/core/tasks/networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ def create_network(vector_data, network_options):
dataset = vector_data.dataset
Network.objects.filter(vector_data__dataset=dataset).delete()
network = Network.objects.create(
name=dataset.name + ' Network',
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are we currently constraining Networks to be one-to-one with datasets?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, do you think that we should add numbering to this name string to account for multi-network datasets?

category=dataset.category,
vector_data=vector_data,
metadata={'source': 'Parsed from GeoJSON.'},
Expand Down
97 changes: 1 addition & 96 deletions uvdat/core/tests/test_dataset.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
import itertools

import pytest

from uvdat.core.models.networks import Network, NetworkNode
from uvdat.core.models.project import Dataset, Project
from uvdat.core.models.project import Dataset


@pytest.mark.django_db
Expand All @@ -22,74 +19,6 @@ def test_rest_dataset_list_retrieve(authenticated_api_client, dataset: Dataset):
assert resp.json()['id'] == dataset.id


@pytest.mark.django_db
def test_rest_dataset_gcc_no_networks(authenticated_api_client, dataset: Dataset, project: Project):
project.datasets.add(dataset)
resp = authenticated_api_client.get(
f'/api/v1/datasets/{dataset.id}/gcc/?project={project.id}&exclude_nodes=1'
)
assert resp.status_code == 400


@pytest.mark.django_db
def test_rest_dataset_gcc_empty_network(
authenticated_api_client, project: Project, network: Network
):
dataset = network.vector_data.dataset
project.datasets.add(dataset)
resp = authenticated_api_client.get(
f'/api/v1/datasets/{dataset.id}/gcc/?project={project.id}&exclude_nodes=1'
)

assert resp.status_code == 200
assert resp.json() == []


@pytest.mark.parametrize('group_sizes', [(3, 2), (20, 3)])
@pytest.mark.django_db
def test_rest_dataset_gcc(
authenticated_api_client,
project: Project,
network: Network,
network_edge_factory,
network_node_factory,
group_sizes,
):
group_a_size, group_b_size = group_sizes

# Create two groups of nodes that fully connected
group_a = [network_node_factory(network=network) for _ in range(group_a_size)]
for from_node, to_node in itertools.combinations(group_a, 2):
network_edge_factory(network=network, from_node=from_node, to_node=to_node)

group_b = [network_node_factory(network=network) for _ in range(group_b_size)]
for from_node, to_node in itertools.combinations(group_b, 2):
network_edge_factory(network=network, from_node=from_node, to_node=to_node)

# Join these two groups by a single node
connecting_node: NetworkNode = network_node_factory(network=network)
network_edge_factory(network=network, from_node=group_a[0], to_node=connecting_node)
network_edge_factory(network=network, from_node=group_b[0], to_node=connecting_node)

# Network should look like this
# * *
# | |
# * ---- * ---- *
# |
# *

dataset = network.vector_data.dataset
project.datasets.add(dataset)
resp = authenticated_api_client.get(
f'/api/v1/datasets/{dataset.id}/gcc/'
f'?project={project.id}&exclude_nodes={connecting_node.id}'
)

larger_group: list[NetworkNode] = max(group_a, group_b, key=len)
assert resp.status_code == 200
assert sorted(resp.json()) == sorted([n.id for n in larger_group])


@pytest.mark.django_db
def test_rest_dataset_layers(
authenticated_api_client,
Expand Down Expand Up @@ -131,27 +60,3 @@ def test_rest_dataset_data_objects(
print(data)
# Assert these lists are the same objects
assert sorted([x['id'] for x in data]) == sorted([x.id for x in data_objects])


@pytest.mark.django_db
def test_rest_dataset_network_no_network(authenticated_api_client, dataset: Dataset):
resp = authenticated_api_client.get(f'/api/v1/datasets/{dataset.id}/network/')
assert resp.status_code == 200
assert not resp.json()


@pytest.mark.django_db
def test_rest_dataset_network(authenticated_api_client, network_edge):
network = network_edge.network
dataset = network.vector_data.dataset
assert network_edge.from_node != network_edge.to_node

resp = authenticated_api_client.get(f'/api/v1/datasets/{dataset.id}/network/')
assert resp.status_code == 200

data: list[dict] = resp.json()
assert len(data) == 1

data: dict = data[0]
assert len(data['nodes']) == 2
assert len(data['edges']) == 1
90 changes: 90 additions & 0 deletions uvdat/core/tests/test_network.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
import itertools

import pytest

from uvdat.core.models import Dataset, Network, NetworkNode, Project


@pytest.mark.django_db
def test_rest_dataset_networks_no_network(
authenticated_api_client, dataset: Dataset, project: Project
):
project.datasets.add(dataset)
resp = authenticated_api_client.get(f'/api/v1/datasets/{dataset.id}/networks/')
assert resp.status_code == 200
assert not resp.json()


@pytest.mark.django_db
def test_rest_dataset_networks(authenticated_api_client, project: Project, network_edge):
network = network_edge.network
dataset = network.vector_data.dataset
project.datasets.add(dataset)
assert network_edge.from_node != network_edge.to_node

resp = authenticated_api_client.get(f'/api/v1/datasets/{dataset.id}/networks/')
assert resp.status_code == 200

data: list[dict] = resp.json()
assert len(data) == 1

data: dict = data[0]
assert len(data['nodes']) == 2
assert len(data['edges']) == 1


@pytest.mark.django_db
def test_rest_network_gcc_empty(authenticated_api_client, user, project: Project, network: Network):
dataset = network.vector_data.dataset
project.set_owner(user)
project.datasets.add(dataset)
resp = authenticated_api_client.get(f'/api/v1/networks/{network.id}/gcc/?exclude_nodes=1')

assert resp.status_code == 200
assert resp.json() == []


@pytest.mark.parametrize('group_sizes', [(3, 2), (20, 3)])
@pytest.mark.django_db
def test_rest_network_gcc(
authenticated_api_client,
user,
project: Project,
network: Network,
network_edge_factory,
network_node_factory,
group_sizes,
):
dataset = network.vector_data.dataset
project.set_owner(user)
project.datasets.add(dataset)
group_a_size, group_b_size = group_sizes

# Create two groups of nodes that fully connected
group_a = [network_node_factory(network=network) for _ in range(group_a_size)]
for from_node, to_node in itertools.combinations(group_a, 2):
network_edge_factory(network=network, from_node=from_node, to_node=to_node)

group_b = [network_node_factory(network=network) for _ in range(group_b_size)]
for from_node, to_node in itertools.combinations(group_b, 2):
network_edge_factory(network=network, from_node=from_node, to_node=to_node)

# Join these two groups by a single node
connecting_node: NetworkNode = network_node_factory(network=network)
network_edge_factory(network=network, from_node=group_a[0], to_node=connecting_node)
network_edge_factory(network=network, from_node=group_b[0], to_node=connecting_node)

# Network should look like this
# * *
# | |
# * ---- * ---- *
# |
# *

resp = authenticated_api_client.get(
f'/api/v1/networks/{network.id}/gcc/?exclude_nodes={connecting_node.id}'
)

larger_group: list[NetworkNode] = max(group_a, group_b, key=len)
assert resp.status_code == 200
assert sorted(resp.json()) == sorted([n.id for n in larger_group])
Loading