diff --git a/segments/helpers.py b/segments/helpers.py index a92a4b0..57c4d23 100644 --- a/segments/helpers.py +++ b/segments/helpers.py @@ -2,7 +2,7 @@ import redis from django.db import connections -from django.db.models.query_utils import InvalidQuery +from django.core.exceptions import FieldError from segments import app_settings @@ -247,7 +247,7 @@ def execute_raw_user_query(self, sql): Helper that returns an array containing a RawQuerySet of user ids and their total count. """ if sql is None or not isinstance(sql, str) or "select" not in sql.lower(): - raise InvalidQuery + raise FieldError with connections[app_settings.SEGMENTS_EXEC_CONNECTION].cursor() as cursor: # Fetch the raw queryset of ids and count them diff --git a/segments/models.py b/segments/models.py index 8d0eab7..fb3906a 100644 --- a/segments/models.py +++ b/segments/models.py @@ -1,7 +1,7 @@ import logging from django.db import models, DatabaseError, OperationalError -from django.db.models.query_utils import InvalidQuery +from django.core.exceptions import FieldError from django.conf import settings from django.db.models import signals from django.utils import timezone @@ -27,7 +27,7 @@ def _wrapper(*args, **kwargs): try: return fn(*args, **kwargs) - except InvalidQuery: + except FieldError: raise SegmentExecutionError( "SQL definition must include the primary key of the %s model" % settings.AUTH_USER_MODEL diff --git a/segments/tests/test_helpers.py b/segments/tests/test_helpers.py index cd6ad0a..9eefb7c 100644 --- a/segments/tests/test_helpers.py +++ b/segments/tests/test_helpers.py @@ -1,5 +1,5 @@ import fakeredis -from django.db.models.query_utils import InvalidQuery +from django.core.exceptions import FieldError from django.db.utils import OperationalError from django.test import TestCase @@ -96,7 +96,7 @@ def setUp(self): def test_invalid_raw_user_query_raises_exception(self): empty_queries = ["", None, 1, True, "any string that does not contain s.elect"] for query in empty_queries: - with self.assertRaises(InvalidQuery, msg=f'Passed query: "{query}"') as cm: + with self.assertRaises(FieldError, msg=f'Passed query: "{query}"') as cm: generator = self.helper.execute_raw_user_query(query) for _ in generator: pass