Re-implement ObscoreExporter to use general query (DM-47980)
This allows the dimension records to be returned by the same query,
avoiding additional per-record queries.
andy-slac committed Dec 19, 2024
1 parent feec90d commit d45cf21
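In practice the change means one round trip instead of several: the old code fetched DatasetRefs with query.datasets() and then filled region and dimension-record caches via separate queryDimensionRecords() calls, while the new code asks a single general query for everything at once. A minimal sketch of the new pattern, using only the query methods that appear in the diff below; the repo path "repo", dataset type "raw", and collection "runs/ingest" are placeholder names:

    # Sketch only: "repo", "raw", and "runs/ingest" are hypothetical names.
    from lsst.daf.butler import Butler

    butler = Butler("repo")
    dataset_type = butler.get_dataset_type("raw")
    with butler.query() as query:
        q = query.join_dataset_search("raw", collections="runs/ingest")
        result = q.general(
            dataset_type.dimensions,
            "visit_detector_region.region",  # region comes back with each row
            dataset_fields={"raw": ...},
            find_first=True,
        ).with_dimension_records()
        for data_id, (ref,), row in result.iter_tuples(dataset_type):
            region = row["visit_detector_region.region"]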
Showing 1 changed file with 64 additions and 75 deletions: python/lsst/dax/obscore/obscore_exporter.py
@@ -27,15 +27,15 @@
 import io
 from collections.abc import Iterator
 from functools import cache
-from typing import Any, cast
+from typing import Any
 
 import astropy.io.votable
 import astropy.table
 import felis.datamodel
 import pyarrow
 import sqlalchemy
 import yaml
-from lsst.daf.butler import Butler, DataCoordinate, Dimension, Registry, ddl
+from lsst.daf.butler import Butler, DataCoordinate, ddl
 from lsst.daf.butler.formatters.parquet import arrow_to_numpy
 from lsst.daf.butler.registry.obscore import (
     ExposureRegionFactory,
@@ -198,75 +198,39 @@ def close(self) -> None:
 
 
 class _ExposureRegionFactory(ExposureRegionFactory):
-    """Find exposure region from a matching visit dimensions records."""
-
-    def __init__(self, registry: Registry):
-        self.registry = registry
-        self.universe = registry.dimensions
-
-        # Maps instrument and visit ID to a region
-        self._visit_regions: dict[str, dict[int, Region]] = {}
-        # Maps instrument+visit+detector to a region
-        self._visit_detector_regions: dict[str, dict[tuple[int, int], Region]] = {}
-        # Maps instrument and exposure ID to a visit ID
-        self._exposure_to_visit: dict[str, dict[int, int]] = {}
-
-    def exposure_region(self, dataId: DataCoordinate) -> Region | None:
-        # Docstring is inherited from a base class.
-        registry = self.registry
-        instrument = cast(str, dataId["instrument"])
-
-        exposure_to_visit = self._exposure_to_visit.get(instrument)
-        if exposure_to_visit is None:
-            self._exposure_to_visit[instrument] = exposure_to_visit = {}
-            # Read complete relation between visits and exposures. There could
-            # be multiple visits defined per exposure, but they are supposed to
-            # have the same region, so we take one of them at random.
-            records = registry.queryDimensionRecords("visit_definition", instrument=instrument)
-            for record in records:
-                exposure_to_visit[record.exposure] = record.visit
-            _LOG.debug("read %d exposure-to-visit records", len(exposure_to_visit))
-
-        # map exposure to a visit
-        exposure = cast(int, dataId["exposure"])
-        visit = exposure_to_visit.get(exposure)
-        if visit is None:
-            return None
-
-        universe = self.universe
-        detector_dimension = cast(Dimension, universe["detector"])
-        if str(detector_dimension) in dataId:
-            visit_detector_regions = self._visit_detector_regions.get(instrument)
-            if visit_detector_regions is None:
-                self._visit_detector_regions[instrument] = visit_detector_regions = {}
-                # Read all visits, there is a chance we need most of them
-                # anyways, and trying to filter by dataset type and collection
-                # makes it much slower.
-                records = registry.queryDimensionRecords("visit_detector_region", instrument=instrument)
-                for record in records:
-                    visit_detector_regions[(record.visit, record.detector)] = record.region
-                _LOG.debug("read %d visit-detector regions", len(visit_detector_regions))
-
-            detector = cast(int, dataId["detector"])
-            return visit_detector_regions.get((visit, detector))
-        else:
-            visit_regions = self._visit_regions.get(instrument)
-            if visit_regions is None:
-                self._visit_regions[instrument] = visit_regions = {}
-                # Read all visits, there is a chance we need most of them
-                # anyways, and trying to filter by dataset type and collection
-                # makes it much slower.
-                records = registry.queryDimensionRecords("visit", instrument=instrument)
-                for record in records:
-                    visit_regions[record.id] = record.region
-                _LOG.debug("read %d visit regions", len(visit_regions))
-
-            return visit_regions.get(visit)
+    """Exposure region factory that returns an existing region; the region
+    is specified via the `set` method, which should be called before the
+    record factory runs.
+    """
+
+    def __init__(self) -> None:
+        self._data_id: DataCoordinate | None = None
+        self._region: Region | None = None
+
+    def set(self, data_id: DataCoordinate, region: Region | None) -> None:
+        """Set region for the specified DataId.
+
+        Parameters
+        ----------
+        data_id : `~lsst.daf.butler.DataCoordinate`
+            Data ID that will be matched against the parameter of
+            `exposure_region`.
+        region : `Region` or `None`
+            Corresponding region, if known.
+        """
+        self._data_id = data_id
+        self._region = region
+
+    def reset(self) -> None:
+        """Reset DataId and region to default values."""
+        self._data_id = None
+        self._region = None
+
+    def exposure_region(self, dataId: DataCoordinate) -> Region | None:
+        # Docstring inherited.
+        if dataId == self._data_id:
+            return self._region
+        return None
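The rewritten factory is intentionally minimal: it holds a single (data ID, region) pair, and exposure_region() returns the stored region only when called with that same data ID. The exporter must therefore call set() for every row before invoking the record factory. A hedged sketch of the intended call sequence; result, region_key, and record_factory stand in for the exporter state shown further down:

    # Illustrative only: the loop mirrors _make_record_batches below.
    factory = _ExposureRegionFactory()
    for data_id, (ref,), row in result.iter_tuples(dataset_type):
        factory.set(data_id, row[region_key] if region_key else None)
        record = record_factory(ref)  # may call factory.exposure_region(data_id)
    factory.reset()  # clear state once the batch is done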


@@ -290,10 +254,10 @@ def __init__(self, butler: Butler, config: ExporterConfig):
 
         self.schema = self._make_schema(schema.table_spec)
 
-        exposure_region_factory = _ExposureRegionFactory(self.butler.registry)
+        self._exposure_region_factory = _ExposureRegionFactory()
         universe = self.butler.dimensions
         self.record_factory = RecordFactory(
-            config, schema, universe, spatial_plugins, exposure_region_factory
+            config, schema, universe, spatial_plugins, self._exposure_region_factory
         )
 
     def to_parquet(self, output: str) -> None:
@@ -494,27 +458,52 @@ def _make_record_batches(
             # Want an empty default to match everything.
             where_clauses = [WhereBind(where="")]
 
+        # Region can come from either visit or visit_detector_region. If we
+        # are looking at exposure then visit will be joined by the query
+        # system.
+        dataset_type = self.butler.get_dataset_type(dataset_type_name)
+        region_key: str | None = None
+        if "exposure" in dataset_type.dimensions or "visit" in dataset_type.dimensions:
+            if "detector" in dataset_type.dimensions:
+                region_key = "visit_detector_region.region"
+            else:
+                region_key = "visit.region"
+
         with self.butler.query() as query:
             for where_clause in where_clauses:
+                where_query = query
+
                 if where_clause.extra_dims:
-                    query = query.join_dimensions(where_clause.extra_dims)
+                    where_query = where_query.join_dimensions(where_clause.extra_dims)
 
                 if where_clause.where:
                     _LOG.verbose("Processing query with constraint %s", where_clause)
-                    query = query.where(where_clause.where, bind=where_clause.bind)
+                    where_query = where_query.where(where_clause.where, bind=where_clause.bind)
 
-                refs = query.datasets(dataset_type_name, collections=collections, find_first=True)
+                where_query = where_query.join_dataset_search(dataset_type_name, collections=collections)
+
+                region_args = [region_key] if region_key else []
+                result = where_query.general(
+                    dataset_type.dimensions,
+                    *region_args,
+                    dataset_fields={dataset_type_name: ...},
+                    find_first=True,
+                )
+
+                # We need dimension records.
+                result = result.with_dimension_records()
 
                 if limit is not None:
-                    refs = refs.limit(limit)
+                    result = result.limit(limit)
 
-                # need dimension records
                 count = 0
-                for ref in refs.with_dimension_records():
-                    dataId = ref.dataId
-                    _LOG.debug("New record, dataId=%s", dataId.mapping)
+                for dataId, (ref,), raw_row in result.iter_tuples(dataset_type):
+                    region = raw_row[region_key] if region_key else None
+                    _LOG.debug("New record, dataId=%s region=%s", dataId.mapping, region)
                     # _LOG.debug("New record, records=%s", dataId.records)
 
+                    self._exposure_region_factory.set(dataId, region)
                     record = self.record_factory(ref)
                     if record is None:
                         continue
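The region-column selection added above amounts to a small decision rule: datasets with a detector dimension read the per-detector region, visit- or exposure-level datasets without a detector read the whole-visit region, and anything without visit or exposure dimensions carries no region at all. Restated as a standalone helper with a hypothetical name, same logic as the hunk above:

    # Hypothetical helper; mirrors the branching added in this commit.
    def pick_region_key(dimensions) -> str | None:
        if "exposure" in dimensions or "visit" in dimensions:
            if "detector" in dimensions:
                return "visit_detector_region.region"
            return "visit.region"
        return None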
