-
Notifications
You must be signed in to change notification settings - Fork 19
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #783 from openedx/saleem-latif/ENT-8316
feat: Added a new REST endpoint to get AI Curation.
- Loading branch information
Showing
17 changed files
with
238 additions
and
0 deletions.
There are no files selected for viewing
Empty file.
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
""" | ||
Serializers for the AI Curation API. | ||
""" | ||
from rest_framework import serializers | ||
|
||
|
||
class AICurationSerializer(serializers.Serializer): # pylint: disable=abstract-method | ||
""" | ||
Serializer for AI Curation. | ||
""" | ||
query = serializers.CharField() | ||
catalog_id = serializers.UUIDField() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
""" | ||
Throttle classes for AI Curation API. | ||
""" | ||
from rest_framework.throttling import AnonRateThrottle | ||
|
||
|
||
class PostAICurationThrottle(AnonRateThrottle): | ||
rate = '10/minute' | ||
|
||
def allow_request(self, request, view): | ||
if request.method == "POST": | ||
return super().allow_request(request, view) | ||
return True | ||
|
||
|
||
class GetAICurationThrottle(AnonRateThrottle): | ||
rate = '1/second' | ||
|
||
def allow_request(self, request, view): | ||
if request.method == "GET": | ||
return super().allow_request(request, view) | ||
return True |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
""" | ||
URL definitions for AI Curations API. | ||
""" | ||
from django.urls import include, path | ||
|
||
from enterprise_catalog.apps.ai_curation.api.v1 import urls as v1_urls | ||
|
||
|
||
app_name = 'ai_curation' | ||
urlpatterns = [ | ||
path('v1/', include(v1_urls)), | ||
] |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
""" | ||
URL definitions for AI Curations API version 1. | ||
""" | ||
from django.urls import path | ||
|
||
from enterprise_catalog.apps.ai_curation.api.v1.views import AICurationView | ||
|
||
|
||
urlpatterns = [ | ||
path('ai-curation', AICurationView.as_view(), name='ai-curation'), | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
""" | ||
REST API for AI Curation. | ||
""" | ||
import json | ||
from uuid import UUID | ||
|
||
from django_celery_results.models import TaskResult | ||
from rest_framework import status | ||
from rest_framework.response import Response | ||
from rest_framework.views import APIView | ||
|
||
from enterprise_catalog.apps.ai_curation.api.serializers import ( | ||
AICurationSerializer, | ||
) | ||
from enterprise_catalog.apps.ai_curation.api.throttle import ( | ||
GetAICurationThrottle, | ||
PostAICurationThrottle, | ||
) | ||
from enterprise_catalog.apps.ai_curation.enums import AICurationStatus | ||
from enterprise_catalog.apps.ai_curation.tasks import trigger_ai_curations | ||
|
||
|
||
class AICurationView(APIView): | ||
""" | ||
View for AI Curation. | ||
""" | ||
authentication_classes = [] | ||
permission_classes = [] | ||
throttle_classes = [GetAICurationThrottle, PostAICurationThrottle] | ||
|
||
def get(self, request): | ||
""" | ||
Return details (status and response) of the given task. | ||
""" | ||
task_id = request.GET.get('task_id', None) | ||
if not task_id: | ||
return Response({'error': 'task_id is required.'}, status=status.HTTP_400_BAD_REQUEST) | ||
try: | ||
curation_task = TaskResult.objects.get(task_id=UUID(task_id)) | ||
return Response({ | ||
'status': curation_task.status, | ||
'result': json.loads(curation_task.result or '{}'), | ||
}) | ||
except TaskResult.DoesNotExist: | ||
return Response({'error': 'Task not found.'}, status=status.HTTP_404_NOT_FOUND) | ||
except ValueError: | ||
return Response({'error': 'Invalid task_id.'}, status=status.HTTP_400_BAD_REQUEST) | ||
|
||
def post(self, request): | ||
""" | ||
Trigger the AI curation process. | ||
This will first validate the payload, the following fields are required: | ||
1. query (str): User query that was input in the search bar. | ||
2. catalog_id (uuid): The catalog id for which the AI curation is being triggered. | ||
If the payload is valid, it will trigger the `trigger_ai_curations` celery task and return the task_id. | ||
""" | ||
serializer = AICurationSerializer(data=request.data) | ||
serializer.is_valid(raise_exception=True) | ||
task = trigger_ai_curations.delay(**serializer.validated_data) | ||
return Response({'task_id': str(task.task_id), 'status': AICurationStatus.PENDING}) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
from django.apps import AppConfig | ||
|
||
|
||
class AICurationConfig(AppConfig): | ||
default_auto_field = 'django.db.models.BigAutoField' | ||
name = 'ai_curation' | ||
default = False |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
""" | ||
Enums for AI Curation. | ||
""" | ||
from django.db import models | ||
|
||
|
||
class AICurationStatus(models.TextChoices): | ||
""" | ||
Enum for AI Curation status. | ||
""" | ||
PENDING = 'PENDING' | ||
IN_PROGRESS = 'IN_PROGRESS' | ||
COMPLETED = 'COMPLETED' | ||
FAILED = 'FAILED' | ||
CANCELLED = 'CANCELLED' |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
""" | ||
Definitions of Celery tasks for the ai_curation app. | ||
""" | ||
import uuid | ||
|
||
from celery import shared_task | ||
from celery_utils.logged_task import LoggedTask | ||
|
||
|
||
@shared_task( | ||
base=LoggedTask, | ||
bind=True, | ||
) | ||
def trigger_ai_curations(self, query: str, catalog_id: uuid.UUID): # pylint: disable=unused-argument | ||
""" | ||
Triggers the AI curation process. | ||
""" | ||
# TODO: Implement the AI curation process here and return the response. | ||
# TODO: Replace return value with correct data. | ||
return {'query': query, 'catalog_id': str(catalog_id)} |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
""" | ||
Factories for AI Curation app. | ||
""" | ||
from uuid import uuid4 | ||
|
||
import factory | ||
from django_celery_results.models import TaskResult | ||
|
||
|
||
class TaskResultFactory(factory.django.DjangoModelFactory): | ||
""" | ||
Test factory for the `AICurationTask` model | ||
""" | ||
class Meta: | ||
model = TaskResult | ||
|
||
task_id = factory.LazyFunction(uuid4) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
""" | ||
Tests for the views of the ai_curation app. | ||
""" | ||
from unittest.mock import MagicMock, patch | ||
from uuid import uuid4 | ||
|
||
from django.test import Client, TestCase | ||
from django.urls import reverse | ||
from rest_framework import status | ||
|
||
from enterprise_catalog.apps.ai_curation.api.throttle import ( | ||
GetAICurationThrottle, | ||
) | ||
from enterprise_catalog.apps.ai_curation.enums import AICurationStatus | ||
from enterprise_catalog.apps.ai_curation.tests import factories | ||
|
||
|
||
class TestAICurationView(TestCase): | ||
def setUp(self): | ||
""" | ||
Set up the test data. | ||
""" | ||
super().setUp() | ||
self.client = Client() | ||
self.url = reverse('ai_curation:ai-curation') | ||
self.task = factories.TaskResultFactory.create() | ||
|
||
def test_get(self): | ||
""" | ||
Verify that the get method returns the correct data. | ||
""" | ||
GetAICurationThrottle.allow_request = MagicMock(return_value=True) | ||
response = self.client.get(self.url, {'task_id': self.task.task_id}) | ||
self.assertEqual(response.status_code, status.HTTP_200_OK) | ||
self.assertEqual(response.data['status'], self.task.status) | ||
|
||
response = self.client.get(self.url, {'task_id': str(uuid4())}) | ||
self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) | ||
|
||
response = self.client.get(self.url) | ||
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) | ||
|
||
response = self.client.get(self.url, {'task_id': 'invalid'}) | ||
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) | ||
|
||
@patch('enterprise_catalog.apps.ai_curation.api.v1.views.trigger_ai_curations') | ||
def test_post(self, mock_trigger_ai_curations): | ||
""" | ||
Verify that the job calls the trigger_ai_curations with the test data | ||
""" | ||
data = {'query': 'Give all courses from edX org.', 'catalog_id': str(uuid4())} | ||
response = self.client.post(self.url, data) | ||
self.assertEqual(response.status_code, status.HTTP_200_OK) | ||
self.assertIn('task_id', response.data) | ||
self.assertEqual(response.data['status'], AICurationStatus.PENDING) | ||
|
||
mock_trigger_ai_curations.delay.assert_called_once() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters