diff --git a/viewer/squonk_job_file_transfer.py b/viewer/squonk_job_file_transfer.py index da97976d..62da460a 100644 --- a/viewer/squonk_job_file_transfer.py +++ b/viewer/squonk_job_file_transfer.py @@ -4,7 +4,7 @@ import os import urllib.parse from pathlib import Path -from typing import Dict, List, Tuple +from typing import Any, Dict, List, Optional, Tuple from celery.utils.log import get_task_logger from django.conf import settings @@ -100,7 +100,7 @@ def process_file_transfer(auth_token, job_transfer_id): def validate_file_transfer_files( request, -) -> Tuple[Dict[str, str], List[Path], List[Path]]: +) -> Tuple[Optional[Dict[str, str]], Optional[List[Path]], Optional[List[Path]]]: """Check the request and return a list of proteins and/or computed molecule file path references (paths relative to the media directory). @@ -121,9 +121,9 @@ def validate_file_transfer_files( list of validated computed molecules (ComputedMolecule) """ - logger.info('+ Validating file transfer files...') + target_id = request.data['target'] + logger.info('+ Validating file transfer files ()...') - error: Dict[str, str] = {} protein_files: List[Path] = [] compound_files: List[Path] = [] @@ -133,21 +133,24 @@ def validate_file_transfer_files( p.strip() for p in request.data['proteins'].split(',') ] for protein_path_and_file in protein_paths_and_files: - # It's a filename if protein_path_and_file.endswith('_apo-desolv.pdb'): - if SiteObservation.objects.filter( - apo_desolv_file=protein_path_and_file - ).first(): + if not ( + s_ob := SiteObservation.objects.filter( + apo_desolv_file=protein_path_and_file + ).first() + ): + return tfr_validation_error( + f'Unknown Protein: {protein_path_and_file}', + status.HTTP_404_NOT_FOUND, + ) + + if s_ob.experiment.experiment_upload.target.id == target_id: protein_files.append(Path(protein_path_and_file)) else: - error['message'] = f'Unknown Protein: {protein_path_and_file}' - error['status'] = status.HTTP_404_NOT_FOUND - return error, protein_files, compound_files - - if not protein_files: - error['message'] = 'API expects a list of comma-separated protein codes' - error['status'] = status.HTTP_404_NOT_FOUND - return error, protein_files, compound_files + return tfr_validation_error( + f'Protein does not belong to Target: {protein_path_and_file}', + status.HTTP_400_BAD_REQUEST, + ) logger.info( "+ Validated proteins (SiteObservations) [%d]", @@ -159,36 +162,43 @@ def validate_file_transfer_files( p.strip() for p in request.data['compounds'].split(',') ] for compound_path_and_file in compound_paths_and_files: - if SiteObservation.objects.filter( + if not SiteObservation.objects.filter( ligand_mol=compound_path_and_file ).first(): + return tfr_validation_error( + f'Unknown Compound: {compound_path_and_file}', + status.HTTP_404_NOT_FOUND, + ) + + if s_ob.experiment.experiment_upload.target.id == target_id: compound_files.append(Path(compound_path_and_file)) else: - error['message'] = f'Unknown Compound: {compound_path_and_file}' - error['status'] = status.HTTP_404_NOT_FOUND - return error, protein_files, compound_files - - if not compound_files: - error['message'] = 'API expects a list of comma-separated compound names' - error['status'] = status.HTTP_400_BAD_REQUEST - return error, protein_files, compound_files + return tfr_validation_error( + f'Compound does not belong to Target: {compound_path_and_file}', + status.HTTP_400_BAD_REQUEST, + ) logger.info( "+ Validated compounds (SiteObservations) [%d]", len(compound_files), ) - if protein_files or compound_files: - logger.info( - "- Validated file transfer files (%d, %d)", - len(protein_files), - len(compound_files), + if not protein_files and not compound_files: + return tfr_validation_error( + 'A valid set of protein codes and/or a list of valid compound names must be provided', + status.HTTP_400_BAD_REQUEST, ) - return error, protein_files, compound_files - error['message'] = ( - 'A valid set of protein codes and/or a list of valid' - ' compound names must be provided' + logger.info( + "- Validated file transfer files (%d, %d)", + len(protein_files), + len(compound_files), ) - error['status'] = status.HTTP_400_BAD_REQUEST - return error, protein_files, compound_files + return None, protein_files, compound_files + + +def tfr_validation_error( + error: str, status_code: int +) -> Tuple[Dict[str, Any], None, None]: + """Returns the error and HTTP status code as a tuple for a response.""" + return {'message': error, 'status': status_code}, None, None