Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add utils to communicate with openai #795

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion enterprise_catalog/apps/ai_curation/api/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,5 +8,5 @@ class AICurationSerializer(serializers.Serializer): # pylint: disable=abstract-
"""
Serializer for AI Curation.
"""
query = serializers.CharField()
query = serializers.CharField(max_length=300)
catalog_id = serializers.UUIDField()
17 changes: 17 additions & 0 deletions enterprise_catalog/apps/ai_curation/errors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
"""
AI curation errors
"""

# Generic user-facing fallback shown whenever an AI-curation call fails.
USER_MESSAGE = (
    "Something went wrong. Please wait a minute and try again. "
    "If the issue persists, please reach out to your contact at edX."
)


class AICurationError(Exception):
    """
    Base exception for AI curation failures.

    Attributes:
        message (str): User-friendly message, safe to surface in the UI.
        dev_message (str): Developer-facing detail; falls back to ``message``.
        status_code (int): HTTP status code from the upstream API, if any.
    """

    def __init__(self, message=USER_MESSAGE, dev_message=None, status_code=None):
        super().__init__(message)
        self.message = message
        self.dev_message = dev_message if dev_message else message
        self.status_code = status_code


class InvalidJSONResponseError(AICurationError):
    """Raised when the upstream response is not valid JSON of the expected type."""
112 changes: 112 additions & 0 deletions enterprise_catalog/apps/ai_curation/openai_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
import functools
import logging

import backoff
import simplejson
from django.conf import settings
from openai import (
APIConnectionError,
APIError,
APITimeoutError,
InternalServerError,
OpenAI,
RateLimitError,
)

from enterprise_catalog.apps.ai_curation.errors import (
AICurationError,
InvalidJSONResponseError,
)


LOGGER = logging.getLogger(__name__)

client = OpenAI(api_key=settings.OPENAI_API_KEY)


def api_error_handler(func):
    """
    Decorator that activates when the API continues to raise persistent errors even after retries.

    Raises a custom exception with the following attributes:
        - message (str): A user-friendly message
        - dev_message (str): The actual error message returned by the API
        - status_code (int): The actual error code returned by the API
    """
    @functools.wraps(func)
    def _wrapped(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except (APIError, AICurationError) as exc:
            LOGGER.exception('[AI_CURATION] API Error: Prompt: [%s]', kwargs.get('messages'))
            # Some exceptions (e.g. APIConnectionError, APITimeoutError) expose
            # neither `status_code` nor `message`, so fall back to None.
            raise AICurationError(
                dev_message=getattr(exc, 'message', None),
                status_code=getattr(exc, 'status_code', None),
            ) from exc

    return _wrapped


@api_error_handler
# Retry transient upstream failures (connection/timeout/5xx/rate-limit) and
# malformed model output up to 3 times with exponential backoff before the
# api_error_handler converts the final failure into an AICurationError.
@backoff.on_exception(
    backoff.expo,
    (APIConnectionError, APITimeoutError, InternalServerError, RateLimitError, InvalidJSONResponseError),
    max_tries=3,
)
def chat_completions(
    messages,
    response_format='json',
    response_type=list,
    model="gpt-4",
    temperature=0.3,
    max_tokens=500,
):
    """
    Get a response from the chat.completions endpoint

    Args:
        messages (list): List of messages to send to the chat.completions endpoint
        response_format (str): Format of the response. Can be 'json' or 'text'
        response_type (any): Expected type of the response. For now we only expect `list`
        model (str): Model to use for the completion
        temperature (number): Make model output more focused and deterministic
        max_tokens (int): Maximum number of tokens that can be generated in the chat completion

    Returns:
        list: The response from the chat.completions endpoint

    Throws:
        AICurationError: Raise an exception with the below attributes
            - message (str): A user-friendly message
            - dev_message (str): The actual error message returned by the API
            - status_code (int): The actual error code returned by the API
    """
    LOGGER.info('[AI_CURATION] [CHAT_COMPLETIONS] Prompt: [%s]', messages)
    response = client.chat.completions.create(
        messages=messages,
        model=model,
        temperature=temperature,
        max_tokens=max_tokens,
    )
    LOGGER.info('[AI_CURATION] [CHAT_COMPLETIONS] Response: [%s]', response)
    # Only the first choice is inspected; the API returns a single choice here.
    response_content = response.choices[0].message.content

    if response_format == 'json':
        try:
            json_response = simplejson.loads(response_content)
            if isinstance(json_response, response_type):
                return json_response
            # Parsed fine but wrong python type (e.g. dict instead of list):
            # raise InvalidJSONResponseError so backoff retries the completion.
            LOGGER.error(
                '[AI_CURATION] JSON response received but response type is incorrect: Prompt: [%s], Response: [%s]',
                messages,
                response
            )
            raise InvalidJSONResponseError('Invalid response type received from chatgpt')
        except simplejson.errors.JSONDecodeError as ex:
            LOGGER.error(
                '[AI_CURATION] Invalid JSON response received from chatgpt: Prompt: [%s], Response: [%s]',
                messages,
                response
            )
            raise InvalidJSONResponseError('Invalid JSON response received from chatgpt') from ex

    # Any non-'json' response_format returns the raw string content untouched.
    return response_content

Check warning on line 112 in enterprise_catalog/apps/ai_curation/openai_client.py

View check run for this annotation

Codecov / codecov/patch

enterprise_catalog/apps/ai_curation/openai_client.py#L112

Added line #L112 was not covered by tests
232 changes: 231 additions & 1 deletion enterprise_catalog/apps/ai_curation/tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,28 @@
"""
Tests for the views of the ai_curation app.
Tests for ai_curation app utils.
"""
import json
import logging
from unittest import mock
from unittest.mock import MagicMock, patch

import httpx
from django.conf import settings
from django.test import TestCase
from openai import APIConnectionError

from enterprise_catalog.apps.ai_curation.errors import AICurationError
from enterprise_catalog.apps.ai_curation.utils import (
chat_completions,
fetch_catalog_metadata_from_algolia,
get_filtered_subjects,
get_keywords_to_prose,
get_query_keywords,
)


# Keyword arguments chat_completions always forwards to the OpenAI API;
# reused by the assertions below to avoid repeating them per test.
CHAT_COMPLETIONS_API_KEYWARGS = {
    'model': 'gpt-4',
    'temperature': 0.3,
    'max_tokens': 500,
}


Expand Down Expand Up @@ -50,3 +66,217 @@ def test_fetch_catalog_metadata_from_algolia(self, mock_algolia_client):
self.assertEqual(sorted(subjects), ["Business & Management", "Computer Science",
"Data Analysis & Statistics", "Economics & Finance",
"Electronics", "Engineering", "Philosophy & Ethics"])


class TestChatCompletionUtils(TestCase):
    """
    Tests for the chat_completions client wrapper and the utils that call it.

    The OpenAI SDK is never hit: `client.chat.completions.create` is patched
    in every test, and the relevant module LOGGER is patched to assert on
    the exact log calls each code path emits.
    """

    @patch('enterprise_catalog.apps.ai_curation.utils.LOGGER')
    @patch('enterprise_catalog.apps.ai_curation.openai_client.client.chat.completions.create')
    def test_get_filtered_subjects(self, mock_create, mock_logger):
        """
        Test that get_filtered_subjects returns the correct filtered subjects
        """
        # Simulate a well-formed JSON list response from the model.
        mock_create.return_value = MagicMock(
            choices=[MagicMock(message=MagicMock(content=json.dumps(['subject1', 'subject2'])))]
        )
        subjects = ['subject1', 'subject2', 'subject3', 'subject4']
        query = 'test query'
        expected_content = settings.AI_CURATION_FILTER_SUBJECTS_PROMPT.format(query=query, subjects=subjects)

        result = get_filtered_subjects(query, subjects)

        mock_create.assert_called_once_with(
            messages=[{'role': 'system', 'content': expected_content}], **CHAT_COMPLETIONS_API_KEYWARGS
        )
        mock_logger.info.assert_has_calls(
            [
                mock.call(
                    '[AI_CURATION] Filtering subjects. Prompt: [%s]',
                    [{'role': 'system', 'content': expected_content}]
                ),
                mock.call('[AI_CURATION] Filtering subjects. Response: [%s]', ['subject1', 'subject2'])
            ]
        )
        assert result == ['subject1', 'subject2']

    @patch('enterprise_catalog.apps.ai_curation.openai_client.LOGGER')
    @patch('enterprise_catalog.apps.ai_curation.openai_client.client.chat.completions.create')
    def test_invalid_json(self, mock_create, mock_logger):
        """
        Test that correct exception is raised if chat.completions.create send an invalid json
        """
        # Non-JSON content triggers InvalidJSONResponseError, which backoff
        # retries; after 3 tries it surfaces as AICurationError.
        mock_create.return_value = MagicMock(choices=[MagicMock(message=MagicMock(content='non json response'))])

        messages = [
            {
                'role': 'system',
                'content': 'I am a prompt'
            }
        ]
        with self.assertRaises(AICurationError):
            chat_completions(messages)

        # max_tries=3 on the backoff decorator → exactly 3 attempts.
        assert mock_create.call_count == 3
        assert mock_logger.error.called
        # One error log per failed attempt, all with the same prompt.
        mock_logger.error.assert_has_calls([
            mock.call(
                '[AI_CURATION] Invalid JSON response received from chatgpt: Prompt: [%s], Response: [%s]',
                [{'role': 'system', 'content': 'I am a prompt'}],
                mock.ANY
            ),
            mock.call(
                '[AI_CURATION] Invalid JSON response received from chatgpt: Prompt: [%s], Response: [%s]',
                [{'role': 'system', 'content': 'I am a prompt'}],
                mock.ANY
            ),
            mock.call(
                '[AI_CURATION] Invalid JSON response received from chatgpt: Prompt: [%s], Response: [%s]',
                [{'role': 'system', 'content': 'I am a prompt'}],
                mock.ANY
            )
        ])

    @patch('enterprise_catalog.apps.ai_curation.openai_client.LOGGER')
    @patch('enterprise_catalog.apps.ai_curation.openai_client.client.chat.completions.create')
    def test_valid_json_with_wrong_type(self, mock_create, mock_logger):
        """
        Test that correct exception is raised if chat.completions.create send a valid json but wrong type
        """
        # Valid JSON, but a dict instead of the expected list type.
        mock_create.return_value = MagicMock(choices=[MagicMock(message=MagicMock(content='{"a": 1}'))])

        messages = [
            {
                'role': 'system',
                'content': 'I am a prompt'
            }
        ]
        with self.assertRaises(AICurationError):
            chat_completions(messages)

        # Wrong-type responses are retried the same way as invalid JSON.
        assert mock_create.call_count == 3
        assert mock_logger.error.called
        mock_logger.error.assert_has_calls([
            mock.call(
                '[AI_CURATION] JSON response received but response type is incorrect: Prompt: [%s], Response: [%s]',
                [{'role': 'system', 'content': 'I am a prompt'}],
                mock.ANY
            ),
            mock.call(
                '[AI_CURATION] JSON response received but response type is incorrect: Prompt: [%s], Response: [%s]',
                [{'role': 'system', 'content': 'I am a prompt'}],
                mock.ANY
            ),
            mock.call(
                '[AI_CURATION] JSON response received but response type is incorrect: Prompt: [%s], Response: [%s]',
                [{'role': 'system', 'content': 'I am a prompt'}],
                mock.ANY
            )
        ])

    @patch('enterprise_catalog.apps.ai_curation.utils.LOGGER')
    @patch('enterprise_catalog.apps.ai_curation.openai_client.client.chat.completions.create')
    def test_get_query_keywords(self, mock_create, mock_logger):
        """
        Test that get_query_keywords returns the correct keywords
        """
        mock_create.return_value = MagicMock(
            choices=[MagicMock(message=MagicMock(content=json.dumps(['keyword1', 'keyword2'])))]
        )
        query = 'test query'
        expected_content = settings.AI_CURATION_QUERY_TO_KEYWORDS_PROMPT.format(query=query)

        result = get_query_keywords(query)

        mock_create.assert_called_once_with(
            messages=[{'role': 'system', 'content': expected_content}], **CHAT_COMPLETIONS_API_KEYWARGS
        )
        mock_logger.info.assert_has_calls(
            [
                mock.call(
                    '[AI_CURATION] Generating keywords. Prompt: [%s]',
                    [{'role': 'system', 'content': expected_content}]
                ),
                mock.call('[AI_CURATION] Generating keywords. Response: [%s]', ['keyword1', 'keyword2'])
            ]
        )
        assert result == ['keyword1', 'keyword2']

    @patch('enterprise_catalog.apps.ai_curation.utils.LOGGER')
    @patch('enterprise_catalog.apps.ai_curation.openai_client.client.chat.completions.create')
    @patch('enterprise_catalog.apps.ai_curation.utils.get_query_keywords')
    def test_get_keywords_to_prose(self, mock_get_query_keywords, mock_create, mock_logger):
        """
        Test that get_keywords_to_prose returns the correct prose
        """
        # get_keywords_to_prose first derives keywords from the query, so
        # that helper is patched to isolate the prose-generation call.
        mock_get_query_keywords.return_value = ['keyword1', 'keyword2']
        mock_create.return_value = MagicMock(
            choices=[MagicMock(message=MagicMock(content=json.dumps(['I am a prose'])))]
        )
        query = 'test query'
        keywords = ['keyword1', 'keyword2']
        expected_content = settings.AI_CURATION_KEYWORDS_TO_PROSE_PROMPT.format(query=query, keywords=keywords)

        result = get_keywords_to_prose(query)

        mock_create.assert_called_once_with(
            messages=[{'role': 'system', 'content': expected_content}], **CHAT_COMPLETIONS_API_KEYWARGS
        )
        mock_logger.info.assert_has_calls(
            [
                mock.call(
                    '[AI_CURATION] Generating prose from keywords. Prompt: [%s]',
                    [{'role': 'system', 'content': expected_content}]
                ),
                mock.call('[AI_CURATION] Generating prose from keywords. Response: [%s]', ['I am a prose'])
            ]
        )
        # The single-element JSON list is unwrapped to its string.
        assert result == 'I am a prose'

    @patch('enterprise_catalog.apps.ai_curation.openai_client.LOGGER')
    @patch('enterprise_catalog.apps.ai_curation.openai_client.client.chat.completions.create')
    def test_chat_completions_retries(self, mock_create, mock_logger):
        """
        Test that retries work as expected for chat_completions
        """
        # Every attempt raises a connection error so backoff exhausts its tries.
        mock_create.side_effect = APIConnectionError(request=httpx.Request("GET", "https://api.example.com"))
        messages = [
            {
                'role': 'system',
                'content': 'I am a prompt'
            }
        ]
        with self.assertRaises(AICurationError):
            # Capture the backoff library's own logger to verify its
            # "backing off" / "giving up" messages.
            backoff_logger = logging.getLogger('backoff')
            with mock.patch.multiple(backoff_logger, info=mock.DEFAULT, error=mock.DEFAULT) as mock_backoff_logger:
                chat_completions(messages=messages)

        assert mock_create.call_count == 3
        # 3 tries → 2 back-off waits logged at info level.
        assert mock_backoff_logger['info'].call_count == 2
        mock_backoff_logger['info'].assert_has_calls(
            [
                mock.call(
                    'Backing off %s(...) for %.1fs (%s)',
                    'chat_completions',
                    mock.ANY,
                    'openai.APIConnectionError: Connection error.'
                ),
                mock.call(
                    'Backing off %s(...) for %.1fs (%s)',
                    'chat_completions',
                    mock.ANY,
                    'openai.APIConnectionError: Connection error.'
                )
            ]
        )
        # A single "giving up" error log after the final failed attempt.
        assert mock_backoff_logger['error'].call_count == 1
        mock_backoff_logger['error'].assert_has_calls(
            [
                mock.call(
                    'Giving up %s(...) after %d tries (%s)',
                    'chat_completions',
                    3,
                    'openai.APIConnectionError: Connection error.'
                )
            ]
        )
        # api_error_handler logs the original failure before re-raising.
        assert mock_logger.exception.called
        mock_logger.exception.assert_has_calls([mock.call('[AI_CURATION] API Error: Prompt: [%s]', messages)])
9 changes: 9 additions & 0 deletions enterprise_catalog/apps/ai_curation/tests/test_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,3 +55,12 @@ def test_post(self, mock_trigger_ai_curations):
self.assertEqual(response.data['status'], AICurationStatus.PENDING)

mock_trigger_ai_curations.delay.assert_called_once()

    def test_post_with_query(self):
        """
        Verify that the api returns error if query length is greater than 300 characters
        """
        # 301 characters exceeds the serializer's max_length=300 on `query`.
        data = {'query': 'a' * 301, 'catalog_id': str(uuid4())}
        response = self.client.post(self.url, data)
        self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
        # DRF's stock CharField max_length validation message.
        self.assertEqual(response.json(), {'query': ['Ensure this field has no more than 300 characters.']})
Loading
Loading