Skip to content

Commit

Permalink
feat: integrate system prompt
Browse files Browse the repository at this point in the history
  • Loading branch information
alangsto committed Jan 3, 2024
1 parent 1ed55c7 commit a23aa4e
Show file tree
Hide file tree
Showing 19 changed files with 196 additions and 178 deletions.
5 changes: 3 additions & 2 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,10 @@ Change Log

.. There should always be an "Unreleased" section for changes pending release.
Unreleased
**********
2.0.0 - 2024-01-03
******************
* Add content cache
* Integrate system prompt setting

1.5.0 - 2023-10-18
******************
Expand Down
2 changes: 1 addition & 1 deletion learning_assistant/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,6 @@
Plugin for a learning assistant backend, intended for use within edx-platform.
"""

__version__ = '1.5.0'
__version__ = '2.0.0'

default_app_config = 'learning_assistant.apps.LearningAssistantConfig' # pylint: disable=invalid-name
15 changes: 15 additions & 0 deletions learning_assistant/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from django.conf import settings
from django.core.cache import cache
from edx_django_utils.cache import get_cache_key
from jinja2 import BaseLoader, Environment

from learning_assistant.constants import ACCEPTED_CATEGORY_TYPES, CATEGORY_TYPE_MAP
from learning_assistant.models import CoursePrompt
Expand Down Expand Up @@ -112,3 +113,17 @@ def get_block_content(request, user_id, course_id, unit_usage_key):
cache.set(cache_key, cache_data, getattr(settings, 'LEARNING_ASSISTANT_CACHE_TIMEOUT', 360))

return cache_data['content_length'], cache_data['content_items']


def render_prompt_template(request, user_id, course_id, unit_usage_key):
    """
    Return a rendered prompt template, specified by the LEARNING_ASSISTANT_PROMPT_TEMPLATE setting.

    Arguments:
        request: the current request, forwarded unchanged to ``get_block_content``.
        user_id: id of the requesting user, forwarded to ``get_block_content``.
        course_id: id of the course, forwarded to ``get_block_content``.
        unit_usage_key: usage key of the unit whose content should be made available to the
            template. When falsy, no content is fetched and ``unit_content`` is an empty string.

    Returns:
        The template string from settings (empty string if the setting is absent), rendered
        by Jinja2 with a single context variable, ``unit_content``.
    """
    unit_content = ''
    if unit_usage_key:
        # get_block_content returns a (content_length, content_items) pair; only the items
        # belong in the template, so unpack rather than passing the whole tuple through.
        _content_length, unit_content = get_block_content(request, user_id, course_id, unit_usage_key)

    template_string = getattr(settings, 'LEARNING_ASSISTANT_PROMPT_TEMPLATE', '')
    template = Environment(loader=BaseLoader).from_string(template_string)
    data = template.render(unit_content=unit_content)
    return data
7 changes: 7 additions & 0 deletions learning_assistant/platform_imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,3 +45,10 @@ def block_get_children(block):
# pylint: disable=import-error, import-outside-toplevel
from openedx.core.lib.graph_traversals import get_children
return get_children(block)


def get_cache_course_run_data(course_run_id, fields):
    """
    Return course run related data for the given course run id.

    Delegates to the platform's ``get_course_run_data`` helper, passing ``course_run_id``
    and ``fields`` through unchanged. The import is deferred to call time so this module
    can be imported without the platform package being available.
    """
    # pylint: disable=import-error, import-outside-toplevel
    from openedx.core.djangoapps.catalog.utils import get_course_run_data

    course_run_data = get_course_run_data(course_run_id, fields)
    return course_run_data
52 changes: 45 additions & 7 deletions learning_assistant/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""
Utils file for learning-assistant.
"""
import copy
import json
import logging

Expand All @@ -9,6 +10,8 @@
from requests.exceptions import ConnectTimeout
from rest_framework import status as http_status

from learning_assistant.platform_imports import get_cache_course_run_data

log = logging.getLogger(__name__)


Expand All @@ -22,21 +25,27 @@ def _estimated_message_tokens(message):
return int((len(message) - message.count(' ')) / chars_per_token) + json_padding


def get_reduced_message_list(system_list, message_list):
def get_reduced_message_list(prompt_template, message_list):
"""
If messages are larger than allotted token amount, return a smaller list of messages.
"""
total_system_tokens = sum(_estimated_message_tokens(system_message['content']) for system_message in system_list)
total_system_tokens = (
_estimated_message_tokens(prompt_template)
+ _estimated_message_tokens('.' * 40) # average number of characters per course name is 40
+ _estimated_message_tokens('.' * 116) # average number of characters for skill names is 116
)

max_tokens = getattr(settings, 'CHAT_COMPLETION_MAX_TOKENS', 16385)
response_tokens = getattr(settings, 'CHAT_COMPLETION_RESPONSE_TOKENS', 1000)
remaining_tokens = max_tokens - response_tokens - total_system_tokens

new_message_list = []
# use copy of list, as it is modified as part of the reduction
message_list_copy = copy.deepcopy(message_list)
total_message_tokens = 0

while total_message_tokens < remaining_tokens and len(message_list) != 0:
new_message = message_list.pop()
while total_message_tokens < remaining_tokens and len(message_list_copy) != 0:
new_message = message_list_copy.pop()
total_message_tokens += _estimated_message_tokens(new_message['content'])
if total_message_tokens >= remaining_tokens:
break
Expand All @@ -47,7 +56,37 @@ def get_reduced_message_list(system_list, message_list):
return new_message_list


def get_chat_response(system_list, message_list):
def get_course_key(courserun_id):
    """
    Return the course key associated with the given courserun id (str).

    The lookup goes through the discovery course run cache, because only the discovery
    service stores the relation between courseruns and courses.
    """
    # Request only the 'course' field and hand back its value directly.
    return get_cache_course_run_data(courserun_id, ['course'])['course']


def create_request_body(prompt_template, message_list, courserun_id):
    """
    Build the request body that is sent to the chat endpoint.

    The body has two top-level keys: ``context``, which carries the prompt template along
    with the course key and field names under ``render``, and ``message_list``, which holds
    the (possibly reduced) list of messages.
    """
    # Resolve the course key first, then reduce the messages (same order as the dict literal would).
    render_options = {
        'doc_id': get_course_key(courserun_id),
        'fields': ['skill_names', 'title'],
    }
    context = {
        'content': prompt_template,
        'render': render_options,
    }
    reduced_messages = get_reduced_message_list(prompt_template, message_list)

    return {
        'context': context,
        'message_list': reduced_messages,
    }


def get_chat_response(prompt_template, message_list, courserun_id):
"""
Pass message list to chat endpoint, as defined by the CHAT_COMPLETION_API setting.
"""
Expand All @@ -58,8 +97,7 @@ def get_chat_response(system_list, message_list):
connect_timeout = getattr(settings, 'CHAT_COMPLETION_API_CONNECT_TIMEOUT', 1)
read_timeout = getattr(settings, 'CHAT_COMPLETION_API_READ_TIMEOUT', 15)

reduced_messages = get_reduced_message_list(system_list, message_list)
body = {'message_list': system_list + reduced_messages}
body = create_request_body(prompt_template, message_list, courserun_id)

try:
response = requests.post(
Expand Down
17 changes: 6 additions & 11 deletions learning_assistant/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
# If the waffle flag is false, the endpoint will force an early return.
learning_assistant_is_active = False

from learning_assistant.api import get_setup_messages
from learning_assistant.api import render_prompt_template
from learning_assistant.serializers import MessageSerializer
from learning_assistant.utils import get_chat_response

Expand Down Expand Up @@ -66,12 +66,7 @@ def post(self, request, course_id):
data={'detail': 'Must be staff or have valid enrollment.'}
)

prompt_messages = get_setup_messages(course_id)
if not prompt_messages:
return Response(
status=http_status.HTTP_404_NOT_FOUND,
data={'detail': 'Learning assistant not enabled for course.'}
)
unit_usage_key = request.POST.get('unit_usage_key', None)

message_list = request.data
serializer = MessageSerializer(data=message_list, many=True)
Expand All @@ -84,16 +79,16 @@ def post(self, request, course_id):
data={'detail': 'Invalid data', 'errors': serializer.errors}
)

# append system message to beginning of message list
message_setup = prompt_messages

log.info(
'Attempting to retrieve chat response for user_id=%(user_id)s in course_id=%(course_id)s',
{
'user_id': request.user.id,
'course_id': course_id
}
)
status_code, message = get_chat_response(message_setup, message_list)

prompt_template = render_prompt_template(request, request.user.id, course_id, unit_usage_key)

status_code, message = get_chat_response(prompt_template, message_list, course_id)

return Response(status=status_code, data=message)
1 change: 1 addition & 0 deletions requirements/base.in
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,4 @@ djangorestframework
edx-drf-extensions
edx-rest-api-client
edx-opaque-keys
jinja2
10 changes: 6 additions & 4 deletions requirements/base.txt
Original file line number Diff line number Diff line change
Expand Up @@ -58,11 +58,15 @@ edx-rest-api-client==5.6.1
# via -r requirements/base.in
idna==3.6
# via requests
jinja2==3.1.2
# via -r requirements/base.in
markupsafe==2.1.3
# via jinja2
newrelic==9.3.0
# via edx-django-utils
pbr==6.0.0
# via stevedore
psutil==5.9.6
psutil==5.9.7
# via edx-django-utils
pycparser==2.21
# via cffi
Expand Down Expand Up @@ -96,8 +100,6 @@ stevedore==5.1.0
# edx-django-utils
# edx-opaque-keys
typing-extensions==4.9.0
# via
# asgiref
# edx-opaque-keys
# via edx-opaque-keys
urllib3==2.1.0
# via requests
6 changes: 1 addition & 5 deletions requirements/ci.txt
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ codecov==2.1.13
# via -r requirements/ci.in
colorama==0.4.6
# via tox
coverage==7.3.2
coverage==7.4.0
# via codecov
distlib==0.3.8
# via virtualenv
Expand All @@ -40,10 +40,6 @@ pyproject-api==1.6.1
# via tox
requests==2.31.0
# via codecov
tomli==2.0.1
# via
# pyproject-api
# tox
tox==4.11.4
# via -r requirements/ci.in
urllib3==2.1.0
Expand Down
40 changes: 6 additions & 34 deletions requirements/dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ colorama==0.4.6
# via
# -r requirements/ci.txt
# tox
coverage[toml]==7.3.2
coverage[toml]==7.4.0
# via
# -r requirements/ci.txt
# -r requirements/quality.txt
Expand All @@ -77,7 +77,7 @@ cryptography==41.0.7
# pyjwt
ddt==1.7.0
# via -r requirements/quality.txt
diff-cover==8.0.1
diff-cover==8.0.2
# via -r requirements/dev.in
dill==0.3.7
# via
Expand Down Expand Up @@ -136,10 +136,6 @@ edx-opaque-keys==2.5.1
# edx-drf-extensions
edx-rest-api-client==5.6.1
# via -r requirements/quality.txt
exceptiongroup==1.2.0
# via
# -r requirements/quality.txt
# pytest
filelock==3.13.1
# via
# -r requirements/ci.txt
Expand All @@ -150,15 +146,11 @@ idna==3.6
# -r requirements/ci.txt
# -r requirements/quality.txt
# requests
importlib-metadata==7.0.0
# via
# -r requirements/pip-tools.txt
# build
iniconfig==2.0.0
# via
# -r requirements/quality.txt
# pytest
isort==5.13.1
isort==5.13.2
# via
# -r requirements/quality.txt
# pylint
Expand All @@ -167,7 +159,7 @@ jinja2==3.1.2
# -r requirements/quality.txt
# code-annotations
# diff-cover
lxml==4.9.3
lxml==5.0.0
# via edx-i18n-tools
markupsafe==2.1.3
# via
Expand Down Expand Up @@ -214,7 +206,7 @@ pluggy==1.3.0
# tox
polib==1.2.0
# via edx-i18n-tools
psutil==5.9.6
psutil==5.9.7
# via
# -r requirements/quality.txt
# edx-django-utils
Expand Down Expand Up @@ -271,7 +263,7 @@ pyproject-hooks==1.0.0
# via
# -r requirements/pip-tools.txt
# build
pytest==7.4.3
pytest==7.4.4
# via
# -r requirements/quality.txt
# pytest-cov
Expand Down Expand Up @@ -336,19 +328,6 @@ text-unidecode==1.3
# via
# -r requirements/quality.txt
# python-slugify
tomli==2.0.1
# via
# -r requirements/ci.txt
# -r requirements/pip-tools.txt
# -r requirements/quality.txt
# build
# coverage
# pip-tools
# pylint
# pyproject-api
# pyproject-hooks
# pytest
# tox
tomlkit==0.12.3
# via
# -r requirements/quality.txt
Expand All @@ -358,10 +337,7 @@ tox==4.11.4
typing-extensions==4.9.0
# via
# -r requirements/quality.txt
# asgiref
# astroid
# edx-opaque-keys
# pylint
urllib3==2.1.0
# via
# -r requirements/ci.txt
Expand All @@ -376,10 +352,6 @@ wheel==0.42.0
# via
# -r requirements/pip-tools.txt
# pip-tools
zipp==3.17.0
# via
# -r requirements/pip-tools.txt
# importlib-metadata

# The following packages are considered to be unsafe in a requirements file:
# pip
Expand Down
Loading

0 comments on commit a23aa4e

Please sign in to comment.