Skip to content

Commit

Permalink
Fix issues on healthcheck; add unitest for session timeout
Browse files Browse the repository at this point in the history
  • Loading branch information
rockychen-dpaw committed Oct 17, 2024
1 parent 279fb2c commit 273a606
Show file tree
Hide file tree
Showing 23 changed files with 1,114 additions and 497 deletions.
16 changes: 12 additions & 4 deletions authome/admin/clusteradmin.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,10 +177,7 @@ class SystemUserAccessTokenAdmin(SyncObjectChangeMixin,admin.SystemUserAccessTok
def _sync_change(self,objids):
return cache.usertokens_changed(objids,True)

class Auth2ClusterAdmin(admin.ExtraToolsMixin,admin.DeleteMixin,admin.DatetimeMixin,admin.CatchModelExceptionMixin,djangoadmin.ModelAdmin):
list_display = ('clusterid','_running_status','default','endpoint','_last_heartbeat','_usergroup_status','_usergroupauthorization_status','_userflow_status','_idp_status')
readonly_fields = ('clusterid','_running_status','default','endpoint','_last_heartbeat','_usergroup_status','_usergroup_lastrefreshed','_usergroupauthorization_status','_usergroupauthorization_lastrefreshed','_userflow_status','_userflow_lastrefreshed','_idp_status','_idp_lastrefreshed','modified','registered')
fields = readonly_fields
class BaseAuth2ClusterAdmin(admin.ExtraToolsMixin,admin.DeleteMixin,admin.DatetimeMixin,admin.CatchModelExceptionMixin,djangoadmin.ModelAdmin):
ordering = ('clusterid',)
extra_tools = [("cluster status",'cluster_status',"clusterstatus")]

Expand Down Expand Up @@ -271,3 +268,14 @@ def has_add_permission(self, request, obj=None):

def has_delete_permission(self, request, obj=None):
return True

if settings.AUTH2_MONITORING_DIR:
class Auth2ClusterAdmin(BaseAuth2ClusterAdmin):
list_display = ('clusterid','_running_status','default','endpoint','_usergroup_status','_usergroupauthorization_status','_userflow_status','_idp_status')
readonly_fields = ('clusterid','_running_status','default','endpoint','_usergroup_status','_usergroup_lastrefreshed','_usergroupauthorization_status','_usergroupauthorization_lastrefreshed','_userflow_status','_userflow_lastrefreshed','_idp_status','_idp_lastrefreshed','modified','registered')
fields = readonly_fields
else:
class Auth2ClusterAdmin(BaseAuth2ClusterAdmin):
list_display = ('clusterid','_running_status','default','endpoint','_last_heartbeat','_usergroup_status','_usergroupauthorization_status','_userflow_status','_idp_status')
readonly_fields = ('clusterid','_running_status','default','endpoint','_last_heartbeat','_usergroup_status','_usergroup_lastrefreshed','_usergroupauthorization_status','_usergroupauthorization_lastrefreshed','_userflow_status','_userflow_lastrefreshed','_idp_status','_idp_lastrefreshed','modified','registered')
fields = readonly_fields
2 changes: 1 addition & 1 deletion authome/backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def get_logout_url(cls):
def logout_url(self):
return self.LOGOUT_URL.format(base_url=self.base_url)

error_re = re.compile("^\s*(?P<code>[A-Z0-9]+)\s*:")
error_re = re.compile("^\\s*(?P<code>[A-Z0-9]+)\\s*:")
def process_error(self, data):
try:
super().process_error(data)
Expand Down
180 changes: 126 additions & 54 deletions authome/basetest.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,76 +28,86 @@ def get(self,path, authorization=None,domain=None,url=None):

return super().get(path,headers=headers)

class BaseAuthTestCase(TestCase):
client_class = Auth2Client
home_url = reverse('home')
auth_url = reverse('auth')
auth_optional_url = reverse('auth_optional')
auth_basic_url = reverse('auth_basic')
auth_basic_optional_url = reverse('auth_basic_optional')
test_usergroups = None
class BaseTestCase(TestCase):
test_usergroupauthorization = None
test_userauthorization = None
test_users = None


def create_client(self):
return Auth2Client()

def setUp(self):
test_usergroups = None
@classmethod
def setUpClass(cls):
super(BaseTestCase,cls).setUpClass()
print("*********************************************************")
if settings.RELEASE:
print("Running in release mode")
print("Running unitest({}) in release mode".format(cls.__name__))
else:
print("Running in dev mode")
settings.AUTH_CACHE_SIZE=2000
settings.BASIC_AUTH_CACHE_SIZE=1000
settings.AUTH_BASIC_CACHE_EXPIRETIME=timedelta(seconds=3600)
settings.AUTH_CACHE_EXPIRETIME=timedelta(seconds=3600)
settings.STAFF_AUTH_CACHE_EXPIRETIME=settings.AUTH_CACHE_EXPIRETIME
settings.AUTH_CACHE_CLEAN_HOURS = [0]
settings.AUTHORIZATION_CACHE_CHECK_HOURS = [0,12]
print("Running unitest({}) in dev mode".format(cls.__name__))

settings.CHECK_AUTH_BASIC_PER_REQUEST = False
User.objects.filter(email__endswith="@gunfire.com").delete()
User.objects.filter(email__endswith="@gunfire.com.au").delete()
User.objects.filter(email__endswith="@hacker.com").delete()
UserToken.objects.all().delete()
UserGroup.objects.all().exclude(users=["*"],excluded_users__isnull=True).delete()
UserAuthorization.objects.all().delete()
def delete_testdata(self):
#delete user group authorization
if self.test_usergroupauthorization:
for groupname,domain,paths,excluded_paths in self.test_usergroupauthorization :
UserGroupAuthorization.objects.filter(usergroup=UserGroup.objects.get(name=groupname),domain=domain,paths=paths,excluded_paths=excluded_paths).delete()

#delete user authorization
if self.test_userauthorization:
for user,domain,paths,excluded_paths in self.test_userauthorization:
UserAuthorization.objects.filter(user=user,domain=domain,paths=paths,excluded_paths=excluded_paths).delete()

cache.refresh_authorization_cache(True)
if not UserGroup.objects.filter(users=["*"], excluded_users__isnull=True).exists():
public_group = UserGroup(name="Public User",groupid="PUBLIC",users=["*"])
public_group.clean()
public_group.save()
if not CustomizableUserflow.objects.filter(domain="*").exists():
default_flow = CustomizableUserflow(
domain='*',
default='default',
mfa_set="default_mfa_set",
mfa_reset="default_mfa_reset",
password_reset="default_password_reset",
verifyemail_from="[email protected]",
verifyemail_subject="test"
)
default_flow.clean()
default_flow.save()
#delete users
if self.test_users:
for user in self.test_users.values():
user.delete()

cache.clean_auth_cache(True)
cache.refresh_authorization_cache(True)

#delete UserGroup objects
if self.test_usergroups:
del_usergroups = []
created_usergroups = [(UserGroup.public_group(),self.test_usergroups)]
while created_usergroups:
parent_obj,subgroup_datas = created_usergroups.pop()
for subgroup in subgroup_datas:
if len(subgroup) == 4:
name,users,excluded_users,subgroups = subgroup
session_timeout = None
else:
name,users,excluded_users,session_timeout,subgroups = subgroup

obj = UserGroup.objects.filter(name=name,parent_group=parent_obj).first()
if obj:
del_usergroups.insert(0,obj)
if subgroups:
created_usergroups.append((obj,subgroups))

def basic_auth(self, username, password):
return 'Basic {}'.format(base64.b64encode('{}:{}'.format(username, password).encode('utf-8')).decode('utf-8'))
for usergroup in del_usergroups:
usergroup.delete()

def populate_testdata(self):
#popuate UserGroup objects
if self.test_usergroups:
uncreated_usergroups = [(UserGroup.public_group(),self.test_usergroups)]
while uncreated_usergroups:
parent_obj,subgroup_datas = uncreated_usergroups.pop()
for name,users,excluded_users,subgroups in subgroup_datas:
obj = UserGroup(name=name,groupid=name,users=users,excluded_users=excluded_users,parent_group=parent_obj)
for subgroup in subgroup_datas:
if len(subgroup) == 4:
name,users,excluded_users,subgroups = subgroup
session_timeout = None
else:
name,users,excluded_users,session_timeout,subgroups = subgroup

if users and isinstance(users,str):
users = [users]
if excluded_users and isinstance(excluded_users,str):
excluded_users = [excluded_users]

obj = UserGroup.objects.filter(name=name).first()
if obj:
obj.groupid=name
obj.users=users
obj.excluded_users=excluded_users
obj.parent_group=parent_obj
obj.session_timeout=session_timeout
else:
obj = UserGroup(name=name,groupid=name,users=users,excluded_users=excluded_users,parent_group=parent_obj,session_timeout=session_timeout)
obj.clean()
print("save usergroup={}".format(obj))
obj.save()
Expand All @@ -107,7 +117,11 @@ def populate_testdata(self):
users = OrderedDict()
if self.test_users:
for test_user in self.test_users:
obj = User(username=test_user[0],email=test_user[1])
obj = User.objects.filter(username=test_user[0]).first()
if obj:
obj.email = test_user[1]
else:
obj = User(username=test_user[0],email=test_user[1])
obj.clean()
obj.save()
users[test_user[0]] = obj
Expand Down Expand Up @@ -135,6 +149,64 @@ def populate_testdata(self):

cache.refresh_authorization_cache(True)

class BaseAuthTestCase(BaseTestCase):
client_class = Auth2Client
home_url = reverse('home')
auth_url = reverse('auth')
auth_optional_url = reverse('auth_optional')
auth_basic_url = reverse('auth_basic')
auth_basic_optional_url = reverse('auth_basic_optional')


def create_client(self):
return Auth2Client()

def setUp(self):
settings.AUTH_CACHE_SIZE=2000
settings.BASIC_AUTH_CACHE_SIZE=1000
settings.AUTH_BASIC_CACHE_EXPIRETIME=timedelta(seconds=3600)
settings.AUTH_CACHE_EXPIRETIME=timedelta(seconds=3600)
settings.STAFF_AUTH_CACHE_EXPIRETIME=settings.AUTH_CACHE_EXPIRETIME
settings.AUTH_CACHE_CLEAN_HOURS = [0]
settings.AUTHORIZATION_CACHE_CHECK_HOURS = [0,12]

settings.CHECK_AUTH_BASIC_PER_REQUEST = False
User.objects.filter(email__endswith="@gunfire.com").delete()
User.objects.filter(email__endswith="@gunfire.com.au").delete()
User.objects.filter(email__endswith="@hacker.com").delete()
UserToken.objects.all().delete()
UserGroup.objects.all().exclude(users=["*"],excluded_users__isnull=True).delete()
UserAuthorization.objects.all().delete()

cache.refresh_authorization_cache(True)
public_group = UserGroup.objects.filter(users=["*"], excluded_users__isnull=True).first()
if public_group:
public_group.session_timeout = 900
public_group.save()
else:
public_group = UserGroup(name="Public User",groupid="PUBLIC",users=["*"],session_timeout=900)
public_group.clean()
public_group.save()
if not CustomizableUserflow.objects.filter(domain="*").exists():
default_flow = CustomizableUserflow(
domain='*',
default='default',
mfa_set="default_mfa_set",
mfa_reset="default_mfa_reset",
password_reset="default_password_reset",
verifyemail_from="[email protected]",
verifyemail_subject="test"
)
default_flow.clean()
default_flow.save()

cache.clean_auth_cache(True)
cache.refresh_authorization_cache(True)


def basic_auth(self, username, password):
return 'Basic {}'.format(base64.b64encode('{}:{}'.format(username, password).encode('utf-8')).decode('utf-8'))

class BaseAuthCacheTestCase(BaseAuthTestCase):
def setUp(self):
super().setUp()
Expand Down
13 changes: 5 additions & 8 deletions authome/redis.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,14 +61,11 @@ def _cache(self):
else:
return self._class(self._servers, **self._options)

def make_and_validate_key(self, key, version=None):
return key

def ttl(self, key):
return self._cache.ttl(key)
return self._cache.ttl(self.make_and_validate_key(key))

def expire(self, key,timeout):
return self._cache.expire(key, timeout)
return self._cache.expire(self.make_and_validate_key(key), timeout)

def _parse_server(self,server=None):
"""
Expand Down Expand Up @@ -283,16 +280,16 @@ def ping(self):
working = True
if isinstance(redisclients,list):
for redisclient in redisclients:
starttime = timezone.local()
starttime = timezone.localtime()
status = self.ping_redis(redisclient)
pingstatus[redisclient[0]] = {"ping":status[0],"pingtime":Processtime((timezone.localtime() - starttime).total_seconds())}
if not status[0]:
working = False
if status[1]:
pingstatus[redisclient[0]]["error"] = status[1]
else:
starttime = timezone.local()
status = self.ping_redis(redisclients,pingtimes)
starttime = timezone.localtime()
status = self.ping_redis(redisclients)
pingstatus[redisclients[0]] = {"ping":status[0],"pingtime":Processtime((timezone.localtime() - starttime).total_seconds())}
if not status[0]:
working = False
Expand Down
89 changes: 89 additions & 0 deletions authome/response.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
import itertools
import mimetypes
import io
import os

from django.http import HttpResponse,StreamingHttpResponse
from django.utils.http import content_disposition_header

def filereaderfactory(filein,block_size):
def _filereader():
return filein.read(block_size)

return _filereader

class MultiFileSegmentsResponse(StreamingHttpResponse):
"""
A streaming HTTP response class optimized for a file which is consisted with multi file segments..
"""

block_size = 4096

def __init__(self, *args, as_attachment=False, filename="", **kwargs):
self.as_attachment = as_attachment
self.filename = filename
self._no_explicit_content_type = (
"content_type" not in kwargs or kwargs["content_type"] is None
)
super().__init__(*args, **kwargs)

def _set_streaming_content(self, files):
streams = []
self.files = files
for file in files:
filein = open(file,'rb')
self._resource_closers.append(filein.close)
streams.append(filein)

self.set_headers(streams)
super()._set_streaming_content(itertools.chain(*[iter(filereaderfactory(filein,self.block_size), b"") for filein in streams]))

def set_headers(self, streams):
"""
Set some common response headers (Content-Length, Content-Type, and
Content-Disposition) based on the `filelike` response content.
"""
content_length = 0
for i in range(len(streams)):
filelike = streams[i]
filename = self.files[i]
seekable = hasattr(filelike, "seek") and (
not hasattr(filelike, "seekable") or filelike.seekable()
)
if hasattr(filelike, "tell"):
if seekable:
initial_position = filelike.tell()
filelike.seek(0, io.SEEK_END)
content_length += filelike.tell() - initial_position
filelike.seek(initial_position)
elif hasattr(filelike, "getbuffer"):
content_length += filelike.getbuffer().nbytes - filelike.tell()
elif os.path.exists(filename):
content_length += os.path.getsize(filename) - filelike.tell()
elif seekable:
length += sum(iter(lambda: len(filelike.read(self.block_size)), 0))
content_length += length
filelike.seek(-1 * length, io.SEEK_END)
self.headers["Content-Length"] = (content_length)
filename = os.path.basename(self.filename or filename)
if self._no_explicit_content_type:
if filename:
content_type, encoding = mimetypes.guess_type(filename)
# Encoding isn't set to prevent browsers from automatically
# uncompressing files.
content_type = {
"bzip2": "application/x-bzip",
"gzip": "application/gzip",
"xz": "application/x-xz",
}.get(encoding, content_type)
self.headers["Content-Type"] = (
content_type or "application/octet-stream"
)
else:
self.headers["Content-Type"] = "application/octet-stream"

if content_disposition := content_disposition_header(
self.as_attachment, filename
):
self.headers["Content-Disposition"] = content_disposition

Loading

0 comments on commit 273a606

Please sign in to comment.