Skip to content

Commit

Permalink
feat: integrate system prompt
Browse files Browse the repository at this point in the history
  • Loading branch information
alangsto committed Jan 4, 2024
1 parent 1ed55c7 commit fb27c83
Show file tree
Hide file tree
Showing 19 changed files with 216 additions and 184 deletions.
5 changes: 3 additions & 2 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,10 @@ Change Log

.. There should always be an "Unreleased" section for changes pending release.
Unreleased
**********
2.0.0 - 2024-01-03
******************
* Add content cache
* Integrate system prompt setting

1.5.0 - 2023-10-18
******************
Expand Down
2 changes: 1 addition & 1 deletion learning_assistant/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,6 @@
Plugin for a learning assistant backend, intended for use within edx-platform.
"""

__version__ = '1.5.0'
__version__ = '2.0.0'

default_app_config = 'learning_assistant.apps.LearningAssistantConfig' # pylint: disable=invalid-name
19 changes: 19 additions & 0 deletions learning_assistant/api.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
"""
Library for the learning_assistant app.
"""
import logging

from django.conf import settings
from django.core.cache import cache
from edx_django_utils.cache import get_cache_key
from jinja2 import BaseLoader, Environment

from learning_assistant.constants import ACCEPTED_CATEGORY_TYPES, CATEGORY_TYPE_MAP
from learning_assistant.models import CoursePrompt
Expand All @@ -16,6 +19,8 @@
)
from learning_assistant.text_utils import html_to_text

log = logging.getLogger(__name__)


def get_deserialized_prompt_content_by_course_id(course_id):
"""
Expand Down Expand Up @@ -112,3 +117,17 @@ def get_block_content(request, user_id, course_id, unit_usage_key):
cache.set(cache_key, cache_data, getattr(settings, 'LEARNING_ASSISTANT_CACHE_TIMEOUT', 360))

return cache_data['content_length'], cache_data['content_items']


def render_prompt_template(request, user_id, course_id, unit_usage_key):
    """
    Return a rendered prompt template, specified by the LEARNING_ASSISTANT_PROMPT_TEMPLATE setting.

    If ``unit_usage_key`` is provided, the unit's extracted text content is exposed
    to the template as the ``unit_content`` variable; otherwise ``unit_content`` is
    the empty string. If the setting is absent, an empty template is rendered and
    the empty string is returned.
    """
    unit_content = ''
    if unit_usage_key:
        _, unit_content = get_block_content(request, user_id, course_id, unit_usage_key)

    template_string = getattr(settings, 'LEARNING_ASSISTANT_PROMPT_TEMPLATE', '')
    # Environment expects a loader *instance*, not the BaseLoader class itself.
    # from_string() never consults the loader, so passing the class happened to
    # work, but instantiate it for correct jinja2 API usage (and so
    # {% include %}/{% extends %} would fail cleanly rather than obscurely).
    template = Environment(loader=BaseLoader()).from_string(template_string)
    return template.render(unit_content=unit_content)
14 changes: 13 additions & 1 deletion learning_assistant/platform_imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def get_text_transcript(video_block):
def get_single_block(request, user_id, course_id, usage_key_string, course=None):
    """
    Load and return a single xblock for the given user and usage key.

    The edx-platform import is deferred to call time so this module can be
    imported outside of a running LMS (hence the pylint import disables).
    """
    # pylint: disable=import-error, import-outside-toplevel
    from lms.djangoapps.courseware.block_render import load_single_xblock

    return load_single_xblock(request, user_id, course_id, usage_key_string, course)


Expand All @@ -45,3 +45,15 @@ def block_get_children(block):
# pylint: disable=import-error, import-outside-toplevel
from openedx.core.lib.graph_traversals import get_children
return get_children(block)


def get_cache_course_run_data(course_run_id, fields):
    """
    Return course run related data given a course run id.

    This function makes use of the discovery course run cache, which is necessary
    because only the discovery service stores the relation between course runs
    and courses. The import is deferred so this module can be imported outside
    of edx-platform.
    """
    # pylint: disable=import-error, import-outside-toplevel
    from openedx.core.djangoapps.catalog.utils import get_course_run_data

    run_data = get_course_run_data(course_run_id, fields)
    return run_data

Check warning on line 59 in learning_assistant/platform_imports.py

View check run for this annotation

Codecov / codecov/patch

learning_assistant/platform_imports.py#L58-L59

Added lines #L58 - L59 were not covered by tests
53 changes: 46 additions & 7 deletions learning_assistant/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""
Utils file for learning-assistant.
"""
import copy
import json
import logging

Expand All @@ -9,6 +10,8 @@
from requests.exceptions import ConnectTimeout
from rest_framework import status as http_status

from learning_assistant.platform_imports import get_cache_course_run_data

log = logging.getLogger(__name__)


Expand All @@ -22,21 +25,31 @@ def _estimated_message_tokens(message):
return int((len(message) - message.count(' ')) / chars_per_token) + json_padding


def get_reduced_message_list(system_list, message_list):
def get_reduced_message_list(prompt_template, message_list):
"""
If messages are larger than allotted token amount, return a smaller list of messages.
"""
total_system_tokens = sum(_estimated_message_tokens(system_message['content']) for system_message in system_list)
# the total number of system tokens is a sum of estimated tokens that includes the prompt template, the
# course title, and the course skills. It is necessary to include estimations for the course title and
# course skills, as the chat endpoint the prompt is being passed to is responsible for filling in the values
# for both of those variables.
total_system_tokens = (
_estimated_message_tokens(prompt_template)
+ _estimated_message_tokens('.' * 40) # average number of characters per course name is 40
+ _estimated_message_tokens('.' * 116) # average number of characters for skill names is 116
)

max_tokens = getattr(settings, 'CHAT_COMPLETION_MAX_TOKENS', 16385)
response_tokens = getattr(settings, 'CHAT_COMPLETION_RESPONSE_TOKENS', 1000)
remaining_tokens = max_tokens - response_tokens - total_system_tokens

new_message_list = []
# use copy of list, as it is modified as part of the reduction
message_list_copy = copy.deepcopy(message_list)
total_message_tokens = 0

while total_message_tokens < remaining_tokens and len(message_list) != 0:
new_message = message_list.pop()
while total_message_tokens < remaining_tokens and len(message_list_copy) != 0:
new_message = message_list_copy.pop()
total_message_tokens += _estimated_message_tokens(new_message['content'])
if total_message_tokens >= remaining_tokens:
break
Expand All @@ -47,7 +60,34 @@ def get_reduced_message_list(system_list, message_list):
return new_message_list


def get_chat_response(system_list, message_list):
def get_course_id(course_run_id):
    """
    Given a course run id (str), return the associated course key.

    Looks up the 'course' field from the cached discovery course run data.
    """
    return get_cache_course_run_data(course_run_id, ['course'])['course']


def create_request_body(prompt_template, message_list, courserun_id):
    """
    Form request body to be passed to the chat endpoint.

    The 'context' section carries the prompt template plus the rendering
    instructions (which document to render against and which fields the chat
    endpoint should fill in); 'message_list' carries the conversation, trimmed
    to fit the token budget.
    """
    render_instructions = {
        'doc_id': get_course_id(courserun_id),
        'fields': ['skillNames', 'title']
    }
    return {
        'context': {
            'content': prompt_template,
            'render': render_instructions
        },
        'message_list': get_reduced_message_list(prompt_template, message_list)
    }


def get_chat_response(prompt_template, message_list, courserun_id):
"""
Pass message list to chat endpoint, as defined by the CHAT_COMPLETION_API setting.
"""
Expand All @@ -58,8 +98,7 @@ def get_chat_response(system_list, message_list):
connect_timeout = getattr(settings, 'CHAT_COMPLETION_API_CONNECT_TIMEOUT', 1)
read_timeout = getattr(settings, 'CHAT_COMPLETION_API_READ_TIMEOUT', 15)

reduced_messages = get_reduced_message_list(system_list, message_list)
body = {'message_list': system_list + reduced_messages}
body = create_request_body(prompt_template, message_list, courserun_id)

try:
response = requests.post(
Expand Down
25 changes: 10 additions & 15 deletions learning_assistant/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
# If the waffle flag is false, the endpoint will force an early return.
learning_assistant_is_active = False

from learning_assistant.api import get_setup_messages
from learning_assistant.api import render_prompt_template
from learning_assistant.serializers import MessageSerializer
from learning_assistant.utils import get_chat_response

Expand All @@ -46,16 +46,16 @@ def post(self, request, course_id):
]
}
"""
course_key = CourseKey.from_string(course_id)
if not learning_assistant_is_active(course_key):
courserun_key = CourseKey.from_string(course_id)
if not learning_assistant_is_active(courserun_key):
return Response(
status=http_status.HTTP_403_FORBIDDEN,
data={'detail': 'Learning assistant not enabled for course.'}
)

# If user does not have an enrollment record, or is not staff, they should not have access
user_role = get_user_role(request.user, course_key)
enrollment_object = CourseEnrollment.get_enrollment(request.user, course_key)
user_role = get_user_role(request.user, courserun_key)
enrollment_object = CourseEnrollment.get_enrollment(request.user, courserun_key)
enrollment_mode = enrollment_object.mode if enrollment_object else None
if (
(enrollment_mode not in CourseMode.ALL_MODES)
Expand All @@ -66,12 +66,7 @@ def post(self, request, course_id):
data={'detail': 'Must be staff or have valid enrollment.'}
)

prompt_messages = get_setup_messages(course_id)
if not prompt_messages:
return Response(
status=http_status.HTTP_404_NOT_FOUND,
data={'detail': 'Learning assistant not enabled for course.'}
)
unit_id = request.query_params.get('unit_id')

message_list = request.data
serializer = MessageSerializer(data=message_list, many=True)
Expand All @@ -84,16 +79,16 @@ def post(self, request, course_id):
data={'detail': 'Invalid data', 'errors': serializer.errors}
)

# append system message to beginning of message list
message_setup = prompt_messages

log.info(
'Attempting to retrieve chat response for user_id=%(user_id)s in course_id=%(course_id)s',
{
'user_id': request.user.id,
'course_id': course_id
}
)
status_code, message = get_chat_response(message_setup, message_list)

prompt_template = render_prompt_template(request, request.user.id, course_id, unit_id)

status_code, message = get_chat_response(prompt_template, message_list, course_id)

return Response(status=status_code, data=message)
1 change: 1 addition & 0 deletions requirements/base.in
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,4 @@ djangorestframework
edx-drf-extensions
edx-rest-api-client
edx-opaque-keys
jinja2
10 changes: 6 additions & 4 deletions requirements/base.txt
Original file line number Diff line number Diff line change
Expand Up @@ -58,11 +58,15 @@ edx-rest-api-client==5.6.1
# via -r requirements/base.in
idna==3.6
# via requests
jinja2==3.1.2
# via -r requirements/base.in
markupsafe==2.1.3
# via jinja2
newrelic==9.3.0
# via edx-django-utils
pbr==6.0.0
# via stevedore
psutil==5.9.6
psutil==5.9.7
# via edx-django-utils
pycparser==2.21
# via cffi
Expand Down Expand Up @@ -96,8 +100,6 @@ stevedore==5.1.0
# edx-django-utils
# edx-opaque-keys
typing-extensions==4.9.0
# via
# asgiref
# edx-opaque-keys
# via edx-opaque-keys
urllib3==2.1.0
# via requests
6 changes: 1 addition & 5 deletions requirements/ci.txt
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ codecov==2.1.13
# via -r requirements/ci.in
colorama==0.4.6
# via tox
coverage==7.3.2
coverage==7.4.0
# via codecov
distlib==0.3.8
# via virtualenv
Expand All @@ -40,10 +40,6 @@ pyproject-api==1.6.1
# via tox
requests==2.31.0
# via codecov
tomli==2.0.1
# via
# pyproject-api
# tox
tox==4.11.4
# via -r requirements/ci.in
urllib3==2.1.0
Expand Down
Loading

0 comments on commit fb27c83

Please sign in to comment.