Skip to content

Commit

Permalink
Remove random collection sampling from test module
Browse files Browse the repository at this point in the history
  • Loading branch information
mfisher87 committed Aug 20, 2024
1 parent 5d708c8 commit c8ba3d4
Showing 1 changed file with 48 additions and 34 deletions.
82 changes: 48 additions & 34 deletions tests/integration/test_onprem_download.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,24 @@
import logging
import os
import random
import shutil
import unittest
from collections import TypedDict
from pathlib import Path
from typing import TypedDict

import earthaccess
import pytest
from earthaccess import Auth, DataCollections, DataGranules, Store
from earthaccess import Auth, DataGranules, Store

from .sample import get_sample_granules

logger = logging.getLogger(__name__)


class TestParam(TypedDict):
daac_name: str
provider_name: str

# How many of the top collections we will test, e.g. top 3 collections
top_n_collections: int
n_for_top_collections: int

# How many granules we will query
granules_count: int
Expand All @@ -33,26 +32,26 @@ class TestParam(TypedDict):

daacs_list: list[TestParam] = [
{
"short_name": "NSIDC",
"top_n_collections": 3,
"granules_count": 100,
"granules_sample_size": 2,
"granules_max_size_mb": 100,
},
{
"short_name": "GES_DISC",
"top_n_collections": 2,
"granules_count": 100,
"granules_sample_size": 2,
"granules_max_size_mb": 130,
},
{
"short_name": "LPDAAC",
"top_n_collections": 2,
"provider_name": "NSIDC_ECS",
"n_for_top_collections": 3,
"granules_count": 100,
"granules_sample_size": 2,
"granules_max_size_mb": 100,
},
# {
# "provider_name": "GES_DISC",
# "top_n_collections": 2,
# "granules_count": 100,
# "granules_sample_size": 2,
# "granules_max_size_mb": 130,
# },
# {
# "provider_name": "LPDAAC",
# "top_n_collections": 2,
# "granules_count": 100,
# "granules_sample_size": 2,
# "granules_max_size_mb": 100,
# },
]

assertions = unittest.TestCase("__init__")
Expand All @@ -70,7 +69,25 @@ class TestParam(TypedDict):
store = Store(auth)


def top_collections_for_provider(provider: str, *, n: int) -> list[str]:
    """Return the top ``n`` collections for this provider.

    Local cache is used as the source for this list. Run
    `./popular_collections/generate.py` to refresh it!

    Args:
        provider: Provider name; the list is read from
            ``popular_collections/<provider>.txt`` next to this module.
        n: Maximum number of entries to return (keyword-only).

    Returns:
        Up to ``n`` non-blank entries from the file, in file order, with
        surrounding whitespace removed.

    Raises:
        ValueError: If ``n`` is negative.
        FileNotFoundError: If no cached list exists for ``provider``.

    TODO: Skip / exclude collections that have a EULA; filter them out in this
    function or use a pytest skip/xfail mark?
    """
    if n < 0:
        # A negative slice bound would silently drop entries from the end of
        # the list rather than limit the result, so reject it outright.
        raise ValueError(f"n must be non-negative, got {n}")

    popular_collections_dir = Path(__file__).parent / "popular_collections"
    popular_collections_file = popular_collections_dir / f"{provider}.txt"

    # Explicit encoding keeps the read independent of the platform locale.
    with open(popular_collections_file, encoding="utf-8") as f:
        lines = f.read().splitlines()

    # Blank lines or padding would otherwise be handed to callers as entries.
    popular_collections = [line.strip() for line in lines if line.strip()]

    return popular_collections[:n]


def supported_collection(data_links):
"""What is the purpose of this?"""
for url in data_links:
if "podaac-tools.jpl.nasa.gov/drive" in url:
return False
Expand All @@ -80,23 +97,20 @@ def supported_collection(data_links):
@pytest.mark.parametrize("daac", daacs_list)
def test_earthaccess_can_download_onprem_collection_granules(daac):
"""Tests that we can download on-premises collections using HTTPS links."""
daac_shortname = daac["short_name"]
collections_count = daac["collections_count"]
collections_sample_size = daac["collections_sample_size"]
provider = daac["provider_name"]
n_for_top_collections = daac["n_for_top_collections"]

granules_count = daac["granules_count"]
granules_sample_size = daac["granules_sample_size"]
granules_max_size = daac["granules_max_size_mb"]

collection_query = DataCollections().data_center(daac_shortname).cloud_hosted(False)
hits = collection_query.hits()
logger.info(f"On-premises collections for {daac_shortname}: {hits}")
collections = collection_query.get(collections_count)
assertions.assertGreater(len(collections), collections_sample_size)
# We sample n cloud hosted collections from the results
random_collections = random.sample(collections, collections_sample_size)
logger.info(f"Sampled {len(random_collections)} collections")
for collection in random_collections:
concept_id = collection.concept_id()
top_collections = top_collections_for_provider(
provider,
n=n_for_top_collections,
)
logger.info(f"On-premises collections for {provider}: {len(top_collections)}")

for concept_id in top_collections:
granule_query = DataGranules().concept_id(concept_id)
total_granules = granule_query.hits()
granules = granule_query.get(granules_count)
Expand Down

0 comments on commit c8ba3d4

Please sign in to comment.