Skip to content

Commit

Permalink
Merge pull request #783 from openedx/saleem-latif/ENT-8316
Browse files Browse the repository at this point in the history
feat: Added a new REST endpoint to get AI Curation.
  • Loading branch information
saleem-latif authored Feb 29, 2024
2 parents 3a2c6ce + 1fcb8a5 commit b7c34c0
Show file tree
Hide file tree
Showing 17 changed files with 238 additions and 0 deletions.
Empty file.
Empty file.
12 changes: 12 additions & 0 deletions enterprise_catalog/apps/ai_curation/api/serializers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
"""
Serializers for the AI Curation API.
"""
from rest_framework import serializers


class AICurationSerializer(serializers.Serializer): # pylint: disable=abstract-method
"""
Serializer for AI Curation.
"""
query = serializers.CharField()
catalog_id = serializers.UUIDField()
22 changes: 22 additions & 0 deletions enterprise_catalog/apps/ai_curation/api/throttle.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
"""
Throttle classes for AI Curation API.
"""
from rest_framework.throttling import AnonRateThrottle


class PostAICurationThrottle(AnonRateThrottle):
rate = '10/minute'

def allow_request(self, request, view):
if request.method == "POST":
return super().allow_request(request, view)
return True


class GetAICurationThrottle(AnonRateThrottle):
rate = '1/second'

def allow_request(self, request, view):
if request.method == "GET":
return super().allow_request(request, view)
return True
12 changes: 12 additions & 0 deletions enterprise_catalog/apps/ai_curation/api/urls.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
"""
URL definitions for AI Curations API.
"""
from django.urls import include, path

from enterprise_catalog.apps.ai_curation.api.v1 import urls as v1_urls


app_name = 'ai_curation'
urlpatterns = [
path('v1/', include(v1_urls)),
]
Empty file.
11 changes: 11 additions & 0 deletions enterprise_catalog/apps/ai_curation/api/v1/urls.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
"""
URL definitions for AI Curations API version 1.
"""
from django.urls import path

from enterprise_catalog.apps.ai_curation.api.v1.views import AICurationView


urlpatterns = [
path('ai-curation', AICurationView.as_view(), name='ai-curation'),
]
62 changes: 62 additions & 0 deletions enterprise_catalog/apps/ai_curation/api/v1/views.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
"""
REST API for AI Curation.
"""
import json
from uuid import UUID

from django_celery_results.models import TaskResult
from rest_framework import status
from rest_framework.response import Response
from rest_framework.views import APIView

from enterprise_catalog.apps.ai_curation.api.serializers import (
AICurationSerializer,
)
from enterprise_catalog.apps.ai_curation.api.throttle import (
GetAICurationThrottle,
PostAICurationThrottle,
)
from enterprise_catalog.apps.ai_curation.enums import AICurationStatus
from enterprise_catalog.apps.ai_curation.tasks import trigger_ai_curations


class AICurationView(APIView):
"""
View for AI Curation.
"""
authentication_classes = []
permission_classes = []
throttle_classes = [GetAICurationThrottle, PostAICurationThrottle]

def get(self, request):
"""
Return details (status and response) of the given task.
"""
task_id = request.GET.get('task_id', None)
if not task_id:
return Response({'error': 'task_id is required.'}, status=status.HTTP_400_BAD_REQUEST)
try:
curation_task = TaskResult.objects.get(task_id=UUID(task_id))
return Response({
'status': curation_task.status,
'result': json.loads(curation_task.result or '{}'),
})
except TaskResult.DoesNotExist:
return Response({'error': 'Task not found.'}, status=status.HTTP_404_NOT_FOUND)
except ValueError:
return Response({'error': 'Invalid task_id.'}, status=status.HTTP_400_BAD_REQUEST)

def post(self, request):
"""
Trigger the AI curation process.
This will first validate the payload, the following fields are required:
1. query (str): User query that was input in the search bar.
2. catalog_id (uuid): The catalog id for which the AI curation is being triggered.
If the payload is valid, it will trigger the `trigger_ai_curations` celery task and return the task_id.
"""
serializer = AICurationSerializer(data=request.data)
serializer.is_valid(raise_exception=True)
task = trigger_ai_curations.delay(**serializer.validated_data)
return Response({'task_id': str(task.task_id), 'status': AICurationStatus.PENDING})
7 changes: 7 additions & 0 deletions enterprise_catalog/apps/ai_curation/apps.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from django.apps import AppConfig


class AICurationConfig(AppConfig):
default_auto_field = 'django.db.models.BigAutoField'
name = 'ai_curation'
default = False
15 changes: 15 additions & 0 deletions enterprise_catalog/apps/ai_curation/enums.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
"""
Enums for AI Curation.
"""
from django.db import models


class AICurationStatus(models.TextChoices):
"""
Enum for AI Curation status.
"""
PENDING = 'PENDING'
IN_PROGRESS = 'IN_PROGRESS'
COMPLETED = 'COMPLETED'
FAILED = 'FAILED'
CANCELLED = 'CANCELLED'
Empty file.
20 changes: 20 additions & 0 deletions enterprise_catalog/apps/ai_curation/tasks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
"""
Definitions of Celery tasks for the ai_curation app.
"""
import uuid

from celery import shared_task
from celery_utils.logged_task import LoggedTask


@shared_task(
base=LoggedTask,
bind=True,
)
def trigger_ai_curations(self, query: str, catalog_id: uuid.UUID): # pylint: disable=unused-argument
"""
Triggers the AI curation process.
"""
# TODO: Implement the AI curation process here and return the response.
# TODO: Replace return value with correct data.
return {'query': query, 'catalog_id': str(catalog_id)}
Empty file.
17 changes: 17 additions & 0 deletions enterprise_catalog/apps/ai_curation/tests/factories.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
"""
Factories for AI Curation app.
"""
from uuid import uuid4

import factory
from django_celery_results.models import TaskResult


class TaskResultFactory(factory.django.DjangoModelFactory):
"""
Test factory for the `AICurationTask` model
"""
class Meta:
model = TaskResult

task_id = factory.LazyFunction(uuid4)
57 changes: 57 additions & 0 deletions enterprise_catalog/apps/ai_curation/tests/test_views.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
"""
Tests for the views of the ai_curation app.
"""
from unittest.mock import MagicMock, patch
from uuid import uuid4

from django.test import Client, TestCase
from django.urls import reverse
from rest_framework import status

from enterprise_catalog.apps.ai_curation.api.throttle import (
GetAICurationThrottle,
)
from enterprise_catalog.apps.ai_curation.enums import AICurationStatus
from enterprise_catalog.apps.ai_curation.tests import factories


class TestAICurationView(TestCase):
def setUp(self):
"""
Set up the test data.
"""
super().setUp()
self.client = Client()
self.url = reverse('ai_curation:ai-curation')
self.task = factories.TaskResultFactory.create()

def test_get(self):
"""
Verify that the get method returns the correct data.
"""
GetAICurationThrottle.allow_request = MagicMock(return_value=True)
response = self.client.get(self.url, {'task_id': self.task.task_id})
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.data['status'], self.task.status)

response = self.client.get(self.url, {'task_id': str(uuid4())})
self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)

response = self.client.get(self.url)
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)

response = self.client.get(self.url, {'task_id': 'invalid'})
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)

@patch('enterprise_catalog.apps.ai_curation.api.v1.views.trigger_ai_curations')
def test_post(self, mock_trigger_ai_curations):
"""
Verify that the job calls the trigger_ai_curations with the test data
"""
data = {'query': 'Give all courses from edX org.', 'catalog_id': str(uuid4())}
response = self.client.post(self.url, data)
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertIn('task_id', response.data)
self.assertEqual(response.data['status'], AICurationStatus.PENDING)

mock_trigger_ai_curations.delay.assert_called_once()
1 change: 1 addition & 0 deletions enterprise_catalog/settings/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@
'enterprise_catalog.apps.curation',
'enterprise_catalog.apps.api',
'enterprise_catalog.apps.academy',
'enterprise_catalog.apps.ai_curation',
)

INSTALLED_APPS += THIRD_PARTY_APPS
Expand Down
2 changes: 2 additions & 0 deletions enterprise_catalog/urls.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
SpectacularSwaggerView,
)

from enterprise_catalog.apps.ai_curation.api import urls as ai_curation_urls
from enterprise_catalog.apps.api import urls as api_urls
from enterprise_catalog.apps.core import views as core_views

Expand All @@ -49,6 +50,7 @@
path('admin/clearcache/', include('clearcache.urls')),
path('admin/', admin.site.urls),
path('api/', include(api_urls), name='api'),
path('api/', include(ai_curation_urls), name='api'),
# Use the same auth views for all logins, including those originating from the browseable API.
path('auto_auth/', core_views.AutoAuth.as_view(), name='auto_auth'),
path('health/', core_views.health, name='health'),
Expand Down

0 comments on commit b7c34c0

Please sign in to comment.