Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: integrate system prompt #47

Merged
merged 2 commits into from
Jan 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,15 @@ Change Log

Unreleased
**********

2.0.1 - 2021-01-08
******************
* Gate content integration with waffle flag

2.0.0 - 2024-01-03
******************
* Add content cache
* Integrate system prompt setting

1.5.0 - 2023-10-18
******************
Expand Down
2 changes: 1 addition & 1 deletion learning_assistant/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,6 @@
Plugin for a learning assistant backend, intended for use within edx-platform.
"""

__version__ = '1.5.0'
__version__ = '2.0.1'

default_app_config = 'learning_assistant.apps.LearningAssistantConfig' # pylint: disable=invalid-name
23 changes: 23 additions & 0 deletions learning_assistant/api.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
"""
Library for the learning_assistant app.
"""
import logging

from django.conf import settings
from django.core.cache import cache
from edx_django_utils.cache import get_cache_key
from jinja2 import BaseLoader, Environment
from opaque_keys.edx.keys import CourseKey

from learning_assistant.constants import ACCEPTED_CATEGORY_TYPES, CATEGORY_TYPE_MAP
from learning_assistant.models import CoursePrompt
Expand All @@ -15,6 +19,9 @@
traverse_block_pre_order,
)
from learning_assistant.text_utils import html_to_text
from learning_assistant.toggles import course_content_enabled

log = logging.getLogger(__name__)


def get_deserialized_prompt_content_by_course_id(course_id):
Expand Down Expand Up @@ -112,3 +119,19 @@ def get_block_content(request, user_id, course_id, unit_usage_key):
cache.set(cache_key, cache_data, getattr(settings, 'LEARNING_ASSISTANT_CACHE_TIMEOUT', 360))

return cache_data['content_length'], cache_data['content_items']


def render_prompt_template(request, user_id, course_id, unit_usage_key):
"""
Return a rendered prompt template, specified by the LEARNING_ASSISTANT_PROMPT_TEMPLATE setting.
"""
unit_content = ''

course_run_key = CourseKey.from_string(course_id)
if unit_usage_key and course_content_enabled(course_run_key):
_, unit_content = get_block_content(request, user_id, course_id, unit_usage_key)

template_string = getattr(settings, 'LEARNING_ASSISTANT_PROMPT_TEMPLATE', '')
template = Environment(loader=BaseLoader).from_string(template_string)
data = template.render(unit_content=unit_content)
return data
14 changes: 13 additions & 1 deletion learning_assistant/platform_imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
def get_single_block(request, user_id, course_id, usage_key_string, course=None):
"""Load a single xblock."""
# pylint: disable=import-error, import-outside-toplevel
from lms.djangoapps.courseware.block_renderer import load_single_xblock
from lms.djangoapps.courseware.block_render import load_single_xblock

Check warning on line 25 in learning_assistant/platform_imports.py

View check run for this annotation

Codecov / codecov/patch

learning_assistant/platform_imports.py#L25

Added line #L25 was not covered by tests
return load_single_xblock(request, user_id, course_id, usage_key_string, course)


Expand All @@ -45,3 +45,15 @@
# pylint: disable=import-error, import-outside-toplevel
from openedx.core.lib.graph_traversals import get_children
return get_children(block)


def get_cache_course_run_data(course_run_id, fields):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I want to leave some documentation about what we're doing here and why - that we need to get the course ID and not the course run ID, and that this is the easiest place to get it from the platform. I think this could be a potentially unclear choice to future readers. I feel like an ADR might be overkill, so what do you think about leaving a brief comment about why we're using the catalog cache?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, I see you left a comment in utils.py::get_course_key 👍 . I think it might make a little more sense as a comment closer to the actual code that uses the catalog utility method, but thank you for including that comment!

"""
Return course run related data given a course run id.

This function makes use of the discovery course run cache, which is necessary because
only the discovery service stores the relation between courseruns and courses.
"""
# pylint: disable=import-error, import-outside-toplevel
from openedx.core.djangoapps.catalog.utils import get_course_run_data
return get_course_run_data(course_run_id, fields)

Check warning on line 59 in learning_assistant/platform_imports.py

View check run for this annotation

Codecov / codecov/patch

learning_assistant/platform_imports.py#L58-L59

Added lines #L58 - L59 were not covered by tests
34 changes: 34 additions & 0 deletions learning_assistant/toggles.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
"""
Toggles for learning-assistant app.
"""

WAFFLE_NAMESPACE = 'learning_assistant'

# .. toggle_name: learning_assistant.enable_course_content
# .. toggle_implementation: CourseWaffleFlag
# .. toggle_default: False
# .. toggle_description: Waffle flag to enable the course content integration with the learning assistant
# .. toggle_use_cases: temporary
# .. toggle_creation_date: 2024-01-08
# .. toggle_target_removal_date: 2024-01-31
# .. toggle_tickets: COSMO-80
ENABLE_COURSE_CONTENT = 'enable_course_content'


def _is_learning_assistant_waffle_flag_enabled(flag_name, course_key):
"""
Import and return Waffle flag for enabling the summary hook.
"""
# pylint: disable=import-outside-toplevel
try:
from openedx.core.djangoapps.waffle_utils import CourseWaffleFlag
return CourseWaffleFlag(f'{WAFFLE_NAMESPACE}.{flag_name}', __name__).is_enabled(course_key)
except ImportError:
return False

Check warning on line 27 in learning_assistant/toggles.py

View check run for this annotation

Codecov / codecov/patch

learning_assistant/toggles.py#L23-L27

Added lines #L23 - L27 were not covered by tests


def course_content_enabled(course_key):
"""
Return whether the learning_assistant.enable_course_content WaffleFlag is on.
"""
return _is_learning_assistant_waffle_flag_enabled(ENABLE_COURSE_CONTENT, course_key)
53 changes: 46 additions & 7 deletions learning_assistant/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""
Utils file for learning-assistant.
"""
import copy
import json
import logging

Expand All @@ -9,6 +10,8 @@
from requests.exceptions import ConnectTimeout
from rest_framework import status as http_status

from learning_assistant.platform_imports import get_cache_course_run_data

log = logging.getLogger(__name__)


Expand All @@ -22,21 +25,31 @@ def _estimated_message_tokens(message):
return int((len(message) - message.count(' ')) / chars_per_token) + json_padding


def get_reduced_message_list(system_list, message_list):
def get_reduced_message_list(prompt_template, message_list):
"""
If messages are larger than allotted token amount, return a smaller list of messages.
"""
total_system_tokens = sum(_estimated_message_tokens(system_message['content']) for system_message in system_list)
# the total number of system tokens is a sum of estimated tokens that includes the prompt template, the
# course title, and the course skills. It is necessary to include estimations for the course title and
# course skills, as the chat endpoint the prompt is being passed to is responsible for filling in the values
# for both of those variables.
total_system_tokens = (
_estimated_message_tokens(prompt_template)
+ _estimated_message_tokens('.' * 40) # average number of characters per course name is 40
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you leave a comment about why we're estimating these and including them in the count, please?

+ _estimated_message_tokens('.' * 116) # average number of characters for skill names is 116
)

max_tokens = getattr(settings, 'CHAT_COMPLETION_MAX_TOKENS', 16385)
response_tokens = getattr(settings, 'CHAT_COMPLETION_RESPONSE_TOKENS', 1000)
remaining_tokens = max_tokens - response_tokens - total_system_tokens

new_message_list = []
# use copy of list, as it is modified as part of the reduction
message_list_copy = copy.deepcopy(message_list)
total_message_tokens = 0

while total_message_tokens < remaining_tokens and len(message_list) != 0:
new_message = message_list.pop()
while total_message_tokens < remaining_tokens and len(message_list_copy) != 0:
new_message = message_list_copy.pop()
total_message_tokens += _estimated_message_tokens(new_message['content'])
if total_message_tokens >= remaining_tokens:
break
Expand All @@ -47,7 +60,34 @@ def get_reduced_message_list(system_list, message_list):
return new_message_list


def get_chat_response(system_list, message_list):
def get_course_id(course_run_id):
"""
Given a course run id (str), return the associated course key.
"""
course_data = get_cache_course_run_data(course_run_id, ['course'])
course_key = course_data['course']
return course_key


def create_request_body(prompt_template, message_list, courserun_id):
"""
Form request body to be passed to the chat endpoint.
"""
response_body = {
'context': {
'content': prompt_template,
'render': {
'doc_id': get_course_id(courserun_id),
'fields': ['skillNames', 'title']
}
},
'message_list': get_reduced_message_list(prompt_template, message_list)
}

return response_body


def get_chat_response(prompt_template, message_list, courserun_id):
"""
Pass message list to chat endpoint, as defined by the CHAT_COMPLETION_API setting.
"""
Expand All @@ -58,8 +98,7 @@ def get_chat_response(system_list, message_list):
connect_timeout = getattr(settings, 'CHAT_COMPLETION_API_CONNECT_TIMEOUT', 1)
read_timeout = getattr(settings, 'CHAT_COMPLETION_API_READ_TIMEOUT', 15)

reduced_messages = get_reduced_message_list(system_list, message_list)
body = {'message_list': system_list + reduced_messages}
body = create_request_body(prompt_template, message_list, courserun_id)

try:
response = requests.post(
Expand Down
25 changes: 10 additions & 15 deletions learning_assistant/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
# If the waffle flag is false, the endpoint will force an early return.
learning_assistant_is_active = False

from learning_assistant.api import get_setup_messages
from learning_assistant.api import render_prompt_template
from learning_assistant.serializers import MessageSerializer
from learning_assistant.utils import get_chat_response

Expand All @@ -46,16 +46,16 @@ def post(self, request, course_id):
]
}
"""
course_key = CourseKey.from_string(course_id)
if not learning_assistant_is_active(course_key):
courserun_key = CourseKey.from_string(course_id)
if not learning_assistant_is_active(courserun_key):
return Response(
status=http_status.HTTP_403_FORBIDDEN,
data={'detail': 'Learning assistant not enabled for course.'}
)

# If user does not have an enrollment record, or is not staff, they should not have access
user_role = get_user_role(request.user, course_key)
enrollment_object = CourseEnrollment.get_enrollment(request.user, course_key)
user_role = get_user_role(request.user, courserun_key)
enrollment_object = CourseEnrollment.get_enrollment(request.user, courserun_key)
enrollment_mode = enrollment_object.mode if enrollment_object else None
if (
(enrollment_mode not in CourseMode.ALL_MODES)
Expand All @@ -66,12 +66,7 @@ def post(self, request, course_id):
data={'detail': 'Must be staff or have valid enrollment.'}
)

prompt_messages = get_setup_messages(course_id)
if not prompt_messages:
return Response(
status=http_status.HTTP_404_NOT_FOUND,
data={'detail': 'Learning assistant not enabled for course.'}
)
unit_id = request.query_params.get('unit_id')

message_list = request.data
serializer = MessageSerializer(data=message_list, many=True)
Expand All @@ -84,16 +79,16 @@ def post(self, request, course_id):
data={'detail': 'Invalid data', 'errors': serializer.errors}
)

# append system message to beginning of message list
message_setup = prompt_messages

log.info(
'Attempting to retrieve chat response for user_id=%(user_id)s in course_id=%(course_id)s',
{
'user_id': request.user.id,
'course_id': course_id
}
)
status_code, message = get_chat_response(message_setup, message_list)

prompt_template = render_prompt_template(request, request.user.id, course_id, unit_id)

status_code, message = get_chat_response(prompt_template, message_list, course_id)

return Response(status=status_code, data=message)
1 change: 1 addition & 0 deletions requirements/base.in
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,4 @@ djangorestframework
edx-drf-extensions
edx-rest-api-client
edx-opaque-keys
jinja2
8 changes: 5 additions & 3 deletions requirements/base.txt
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,10 @@ edx-rest-api-client==5.6.1
# via -r requirements/base.in
idna==3.6
# via requests
jinja2==3.1.2
# via -r requirements/base.in
markupsafe==2.1.3
# via jinja2
newrelic==9.3.0
# via edx-django-utils
pbr==6.0.0
Expand Down Expand Up @@ -96,8 +100,6 @@ stevedore==5.1.0
# edx-django-utils
# edx-opaque-keys
typing-extensions==4.9.0
# via
# asgiref
# edx-opaque-keys
# via edx-opaque-keys
urllib3==2.1.0
# via requests
4 changes: 0 additions & 4 deletions requirements/ci.txt
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,6 @@ pyproject-api==1.6.1
# via tox
requests==2.31.0
# via codecov
tomli==2.0.1
# via
# pyproject-api
# tox
tox==4.11.4
# via -r requirements/ci.in
urllib3==2.1.0
Expand Down
24 changes: 0 additions & 24 deletions requirements/dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -136,10 +136,6 @@ edx-opaque-keys==2.5.1
# edx-drf-extensions
edx-rest-api-client==5.6.1
# via -r requirements/quality.txt
exceptiongroup==1.2.0
# via
# -r requirements/quality.txt
# pytest
filelock==3.13.1
# via
# -r requirements/ci.txt
Expand Down Expand Up @@ -336,19 +332,6 @@ text-unidecode==1.3
# via
# -r requirements/quality.txt
# python-slugify
tomli==2.0.1
# via
# -r requirements/ci.txt
# -r requirements/pip-tools.txt
# -r requirements/quality.txt
# build
# coverage
# pip-tools
# pylint
# pyproject-api
# pyproject-hooks
# pytest
# tox
tomlkit==0.12.3
# via
# -r requirements/quality.txt
Expand All @@ -358,10 +341,7 @@ tox==4.11.4
typing-extensions==4.9.0
# via
# -r requirements/quality.txt
# asgiref
# astroid
# edx-opaque-keys
# pylint
urllib3==2.1.0
# via
# -r requirements/ci.txt
Expand All @@ -376,10 +356,6 @@ wheel==0.42.0
# via
# -r requirements/pip-tools.txt
# pip-tools
zipp==3.17.0
# via
# -r requirements/pip-tools.txt
# importlib-metadata

# The following packages are considered to be unsafe in a requirements file:
# pip
Expand Down
Loading
Loading