-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add GPU scheduling app with models, forms, views, and URLs
- Loading branch information
1 parent
b415308
commit 21739ca
Showing
10 changed files
with
256 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
from django.contrib import admin | ||
|
||
# Register your models here. | ||
|
||
from .models import GPUBooking |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
from django import forms | ||
|
||
from .models import GPUBooking | ||
from django.conf import settings | ||
from MAIA.dashboard_utils import get_groups_in_keycloak, get_pending_projects | ||
|
||
|
||
class GPUBookingForm(forms.ModelForm): | ||
def __init__(self, *args, **kwargs): | ||
super(GPUBookingForm, self).__init__(*args, **kwargs) | ||
maia_groups = get_groups_in_keycloak(settings= settings) | ||
pending_projects = get_pending_projects(settings=settings) | ||
|
||
for pending_project in pending_projects: | ||
maia_groups[pending_project] = pending_project + " (Pending)" | ||
|
||
self.fields['namespace'] = forms.ChoiceField( | ||
choices=[(maia_group,maia_group) for maia_group in maia_groups.values()], | ||
widget=forms.Select(attrs={ | ||
'class': "form-select text-center fw-bold", | ||
'style': 'max-width: auto;', | ||
} | ||
) | ||
) | ||
self.fields['user_email'] = forms.EmailField( | ||
widget=forms.EmailInput( | ||
attrs={ | ||
"placeholder": "Your Email.", | ||
"class": "form-control" | ||
} | ||
)) | ||
|
||
self.fields['gpu'] = forms.ChoiceField( | ||
choices=[(gpu,gpu) for gpu in settings.GPU_LIST], | ||
widget=forms.Select(attrs={ | ||
'class': "form-select text-center fw-bold", | ||
'style': 'max-width: auto;', | ||
}) | ||
) | ||
|
||
self.fields['start_date'] = forms.DateField(widget=forms.TextInput(attrs={'class': 'form-control', 'type': 'date'})) | ||
|
||
self.fields['end_date'] = forms.DateField(widget=forms.TextInput(attrs={'class': 'form-control', 'type': 'date'})) | ||
|
||
def clean(self): | ||
cleaned_data = super().clean() | ||
start_date = cleaned_data.get("start_date") | ||
end_date = cleaned_data.get("end_date") | ||
user_email = cleaned_data.get("user_email") | ||
|
||
if start_date and end_date: | ||
if end_date <= start_date: | ||
self.add_error('end_date', "End date must be after start date.") | ||
elif (end_date - start_date).days > 60: | ||
self.add_error('end_date', "The maximum booking duration is 60 days.") | ||
else: | ||
# Check existing bookings for the same user | ||
existing_bookings = GPUBooking.objects.filter(user_email=user_email) | ||
total_days = 0 | ||
for booking in existing_bookings: | ||
total_days += (booking.end_date - booking.start_date).days | ||
|
||
if total_days > 60: | ||
self.add_error('end_date', "The total booking duration for this user exceeds 60 days.") | ||
|
||
return cleaned_data | ||
|
||
|
||
|
||
class Meta: | ||
model = GPUBooking | ||
fields = ('namespace','gpu', 'user_email','start_date', 'end_date') |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
from django.db import models | ||
|
||
class GPUBooking(models.Model): | ||
class Meta: | ||
app_label = 'gpu_scheduler' | ||
user_email = models.CharField(max_length=255, ) | ||
start_date = models.DateTimeField() | ||
end_date = models.DateTimeField() | ||
gpu = models.CharField(max_length=255) | ||
namespace = models.CharField(max_length=255) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
from django.test import TestCase | ||
|
||
# Create your tests here. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
# -*- encoding: utf-8 -*- | ||
""" | ||
Copyright (c) 2019 - present AppSeed.us | ||
""" | ||
|
||
from django.urls import path | ||
|
||
|
||
from django.urls import path | ||
from .views import GPUSchedulabilityAPIView, book_gpu, gpu_booking_info # Import your API view | ||
|
||
urlpatterns = [ | ||
path('gpu-schedulability/', GPUSchedulabilityAPIView.as_view(), name='gpu_schedulability'), | ||
path('', book_gpu, name='gpu_booking_form'), | ||
path('my-bookings/', gpu_booking_info, name='book_gpu'), | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,127 @@ | ||
from django.shortcuts import render | ||
from django.shortcuts import render, redirect | ||
from django.contrib.auth.decorators import login_required | ||
from django.http import JsonResponse | ||
from django.views.decorators.csrf import csrf_exempt | ||
from django.utils.decorators import method_decorator | ||
from rest_framework.views import APIView | ||
from rest_framework.response import Response | ||
from rest_framework.permissions import AllowAny | ||
from .models import GPUBooking | ||
from django.conf import settings | ||
from datetime import datetime, timezone | ||
from .forms import GPUBookingForm | ||
from MAIA.dashboard_utils import get_namespaces | ||
|
||
|
||
@method_decorator(csrf_exempt, name='dispatch') # 🚀 This disables CSRF for this API | ||
class GPUSchedulabilityAPIView(APIView): | ||
permission_classes = [AllowAny] # 🚀 Allow requests without authentication or CSRF | ||
|
||
def post(self, request, *args, **kwargs): | ||
try: | ||
user_email = request.data.get("user_email") | ||
if not user_email: | ||
return Response({"error": "Missing user_email"}, status=400) | ||
|
||
secret_token = request.data.get("token") | ||
if not secret_token or secret_token != settings.SECRET_KEY: | ||
return Response({"error": "Invalid or missing secret token"}, status=403) | ||
|
||
if "booking" in request.data: | ||
booking_data = request.data["booking"] | ||
|
||
# Calculate the total number of days for existing bookings | ||
existing_bookings = GPUBooking.objects.filter(user_email=user_email) | ||
total_days = sum( | ||
(booking.end_date - booking.start_date).days | ||
for booking in existing_bookings | ||
) | ||
|
||
# Calculate the number of days for the new booking | ||
ending_time = datetime.strptime(booking_data["ending_time"], "%Y-%m-%d %H:%M:%S") | ||
starting_time = datetime.strptime(booking_data["starting_time"], "%Y-%m-%d %H:%M:%S") | ||
new_booking_days = (ending_time - starting_time).days | ||
|
||
# Verify that the sum of existing bookings and the new booking does not exceed 60 days | ||
if total_days + new_booking_days > 60: | ||
return Response({"error": "Total booking days exceed the limit of 60 days"}, status=400) | ||
|
||
# Create the new booking | ||
GPUBooking.objects.create( | ||
user_email=user_email, | ||
start_date=booking_data["starting_time"], | ||
end_date=booking_data["ending_time"] | ||
) | ||
return Response({"message": "Booking created successfully"}) | ||
|
||
try: | ||
user_statuses = GPUBooking.objects.filter(user_email=user_email) | ||
|
||
is_schedulable = False | ||
|
||
current_time = datetime.now(timezone.utc) | ||
is_schedulable = any( | ||
status.start_date <= current_time and status.end_date >= current_time | ||
for status in user_statuses | ||
) | ||
if is_schedulable: | ||
status = next(status for status in user_statuses if status.start_date <= current_time and status.end_date >= current_time) | ||
return Response({"schedulable": is_schedulable, "until": status.end_date}) | ||
else: | ||
return Response({"schedulable": is_schedulable, "until": None}) | ||
except GPUBooking.DoesNotExist: | ||
return Response({"schedulable": False, "until": None}) # Default to not schedulable if not found | ||
|
||
except Exception as e: | ||
return JsonResponse({"error": str(e)}, status=400) | ||
|
||
|
||
@login_required(login_url="/maia/login/") | ||
def book_gpu(request): | ||
msg = None | ||
success = False | ||
|
||
email = request.user.email | ||
|
||
id_token = request.session.get('oidc_id_token') | ||
groups = request.user.groups.all() | ||
|
||
namespaces = [] | ||
|
||
if request.user.is_superuser: | ||
namespaces = get_namespaces(id_token, api_urls=settings.API_URL, private_clusters=settings.PRIVATE_CLUSTERS) | ||
|
||
if namespaces is None or len(namespaces) == 0: | ||
namespaces = [] | ||
for group in groups: | ||
if str(group) != "MAIA:users": | ||
namespaces.append(str(group).split(":")[-1].lower().replace("_", "-")) | ||
|
||
initial_data = {'user_email': email, 'namespace': namespaces[0] if namespaces else None} | ||
|
||
if request.method == "POST": | ||
form = GPUBookingForm(request.POST, request.FILES) | ||
if form.is_valid(): | ||
form.save() | ||
msg = 'Request for GPU Booking submitted successfully.' | ||
success = True | ||
return redirect("/maia/gpu-booking/my-bookings/") | ||
else: | ||
print(form.errors) | ||
msg = 'Form is not valid' | ||
else: | ||
form = GPUBookingForm(request.POST or None, request.FILES or None, initial=initial_data) | ||
form.fields['namespace'].choices = [(ns, ns) for ns in namespaces] | ||
|
||
return render(request, "accounts/gpu_booking.html", {"dashboard_version": settings.DASHBOARD_VERSION, "form": form, "msg": msg, "success": success}) | ||
|
||
|
||
@login_required(login_url="/maia/login/") | ||
def gpu_booking_info(request): | ||
bookings = GPUBooking.objects.filter(user_email=request.user.email) | ||
|
||
total_days = 0 | ||
for booking in bookings: | ||
total_days += (booking.end_date - booking.start_date).days | ||
return render(request, "accounts/gpu_booking_info.html", {"dashboard_version": settings.DASHBOARD_VERSION, "bookings": bookings, "total_days": total_days}) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters