Skip to content

Commit

Permalink
Merge branch 'main' of github.com:Small-Bodies-Node/sbsearch
Browse files Browse the repository at this point in the history
  • Loading branch information
mkelley committed Jan 16, 2025
2 parents 185b36f + a3c6d8b commit 715c17f
Show file tree
Hide file tree
Showing 11 changed files with 115 additions and 47 deletions.
2 changes: 1 addition & 1 deletion build_s2.sh
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#!/bin/bash

set -e
pushd
pushd .

[[ -z $S2PREFIX ]] && echo "Requires env variable S2PREFIX set to desired installation prefix" && exit 1
[[ -z "$PYTHON_ROOT" ]] && PYTHON_ROOT=`python3 -c "import sys; print(sys.exec_prefix)"`
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ dependencies = [
"astropy>=4.3",
"astroquery>=0.4.5",
"sbpy>0.3.0",
"sqlalchemy>=1.3,<1.4",
"sqlalchemy>=2.0",
"cython>=0.30",
"extension-helpers",
]
Expand Down
2 changes: 1 addition & 1 deletion sbsearch/ephemeris.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,7 @@ def _ephemeris(
quantities=cls._QUANTITIES,
cache=cache,
)
except ValueError as exc:
except ValueError as exc: # noqa F841
# Dual-listed objects should be queried without CAP/NOFRAG. If this
# was a comet query and the error is "unknown target", try again
# without them.
Expand Down
20 changes: 18 additions & 2 deletions sbsearch/logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,8 @@
import time
import logging
from enum import Enum
from typing import Callable, Optional, Union
from typing import Callable, Optional
import numpy as np
from astropy.time import Time


class ElapsedFormatter(logging.Formatter):
Expand Down Expand Up @@ -58,6 +57,23 @@ def setup_logger(
return logger


def setup_search_logger(prefix: str = "SBSearch") -> logging.Logger:
name = f"{prefix} (search)"
logger: logging.Logger = logging.getLogger(name)

# reset handlers, in case already defined
close_logger(name)

formatter = logging.Formatter("%(message)s")

console: logging.StreamHandler = logging.StreamHandler(sys.stderr)
console.setFormatter(formatter)
logger.addHandler(console)
logger.setLevel(logging.INFO)

return logger


def close_logger(name: str = "SBSearch") -> None:
"""Close SBSearch loggers."""
logger: logging.Logger = logging.getLogger(name)
Expand Down
7 changes: 6 additions & 1 deletion sbsearch/model/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from typing import Union, List
import numpy as np
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import declarative_base
from sqlalchemy import (
Column,
BigInteger,
Expand Down Expand Up @@ -171,6 +171,11 @@ class Observation(Base):
seeing: float = Column(Float(32), doc="point source FWHM, arcsec")
airmass: float = Column(Float(32), doc="observation airmass")
maglimit: float = Column(Float(32), doc="detection limit, mag")
mjd_added: float = Column(
Float(32),
index=True,
doc="time observation was added, modified Julian date, UTC",
)

# Common methods.
def test_edges(self) -> None:
Expand Down
42 changes: 25 additions & 17 deletions sbsearch/sbsdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,14 @@
import sqlalchemy as sa
from sqlalchemy.orm import Session
from sqlalchemy.engine import Engine
from sqlalchemy.sql import text

from . import model

__all__ = ['SBSDatabase']
__all__ = ["SBSDatabase"]


SBSD = TypeVar('SBSD', bound='SBSDatabase')
SBSD = TypeVar("SBSD", bound="SBSDatabase")


class SBSDatabase:
Expand All @@ -34,8 +35,9 @@ class SBSDatabase:
"""

def __init__(self, url_or_session: Union[str, Session], *args,
logger_name: str='SBSearch'):
def __init__(
self, url_or_session: Union[str, Session], *args, logger_name: str = "SBSearch"
):
self.session: Session
self.sessionmaker: Union[Session, None]
self.engine: Engine
Expand Down Expand Up @@ -75,36 +77,43 @@ def verify(self):
for name in model.Base.metadata.tables.keys():
if name not in metadata.tables.keys():
missing = True
self.logger.error('{} is missing from database'.format(name))
self.logger.error("{} is missing from database".format(name))

if missing:
self.create()
self.logger.info('Created database tables.')
self.session.execute(text("ANALYZE"))
self.logger.info("Created database tables.")

self.session.commit()

def create_spatial_index(self):
"""Create the spatial term index.
"""Create the spatial term index."""

Generally VACUUM ANALZYE after this.
"""
self.session.execute('''
self.session.execute(
text(
"""
CREATE INDEX IF NOT EXISTS ix_observation_spatial_terms
ON observation
USING GIN (spatial_terms);
''')
"""
)
)
self.session.commit()
self.session.execute(text("ANALYZE observation"))

def drop_spatial_index(self):
"""Drop the spatial term index.
Use this before inserting many observations.
"""
self.session.execute('''
self.session.execute(
text(
"""
DROP INDEX IF EXISTS ix_observation_spatial_terms;
''')
"""
)
)
self.session.commit()

def create(self):
Expand All @@ -118,10 +127,9 @@ def test_db(cls: Type[SBSD], url: str) -> SBSD:
db: SBSD = cls(url)
db.create()

MovingTarget('1P', db).add()
MovingTarget("1P", db).add()
target: MovingTarget = MovingTarget(
'C/1995 O1', db,
secondary_designations=['Hale-Bopp']
"C/1995 O1", db, secondary_designations=["Hale-Bopp"]
)
target.add()

Expand Down
44 changes: 34 additions & 10 deletions sbsearch/sbsearch.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,13 @@

import numpy as np
from sqlalchemy.orm import Session, Query
import astropy.units as u
from astropy.time import Time
from astropy.coordinates import SkyCoord

from . import core
from .ephemeris import get_ephemeris_generator, EphemerisGenerator
from .sbsdb import SBSDatabase
from .model import Base, Ephemeris, Observation, Found
from .model import Base, Ephemeris, Observation, Found # noqa F203
from .spatial import ( # pylint: disable=E0611
SpatialIndexer,
polygon_intersects_line,
Expand All @@ -27,7 +27,7 @@
from .target import MovingTarget, FixedTarget
from .exceptions import DesignationError, UnknownSource
from .config import Config
from .logging import ProgressTriangle, setup_logger
from .logging import ProgressTriangle, setup_logger, setup_search_logger


# keep synced with libspatial.cpp, test_spatial.py
Expand All @@ -44,6 +44,11 @@ class IntersectionType(enum.Enum):
class SBSearch:
"""Small-body search tool.
Two loggers are used. The first, under ``logger_name``, publishes general
messages to the console and to a log file (``log``). The second,
``logger_name (search)`` publishes moving target search progress to the
console.
Parameters
----------
Expand Down Expand Up @@ -108,6 +113,7 @@ def __init__(
self.logger: Logger = setup_logger(
filename=log, name=logger_name, level=log_level
)
self.search_logger: Logger = setup_search_logger(logger_name)

def __enter__(self) -> SBSearchObject:
return self
Expand Down Expand Up @@ -169,7 +175,7 @@ def source(self, source: Union[str, Observation]) -> None:
e: Exception
try:
self._source = self.sources[source]
except KeyError as e:
except KeyError as e: # noqa F841
raise UnknownSource(source) from e
else:
if source == Observation:
Expand Down Expand Up @@ -843,6 +849,8 @@ def find_observations_intersecting_line_at_time(
) -> List[Observation]:
"""Find observations intersecting given line at given times.
Progress is published to ``self.search_logger``.
Parameters
----------
Expand All @@ -866,6 +874,7 @@ def find_observations_intersecting_line_at_time(
observations: list of Observation
"""

# normalize inputs for use with spatial submodule
_ra = np.array(ra, float)
_dec = np.array(dec, float)
Expand All @@ -882,11 +891,6 @@ def find_observations_intersecting_line_at_time(
if len(_a) != N or len(_b) != N:
raise ValueError("ra, dec, a, and b must have same length")

observations: List[Observation] = []
n_segment_queries: int = 0
n_matched_observations: int = 0
terms: List[str] # terms corresponding to each segment
segment: slice # slice corresponding to each segment
segments: Tuple[List[str], slice] = core.line_to_segment_query_terms(
self.indexer,
_ra,
Expand All @@ -903,6 +907,20 @@ def find_observations_intersecting_line_at_time(
"intersection at time.",
N,
)
coords: SkyCoord = SkyCoord(_ra, _dec, unit="rad")
arc_length: float = np.sum(coords[:-1].separation(coords[1:])).deg
self.search_logger.info(
"Searching %.1f deg over %.1f days",
arc_length,
np.ptp(mjd),
)

observations: List[Observation] = []
n_segment_queries: int = 0
n_matched_observations: int = 0

terms: List[str] # terms corresponding to each segment
segment: slice # slice corresponding to each segment
for terms, segment in segments:
n_segment_queries += 1
q: Query = self.db.session.query(Observation)
Expand Down Expand Up @@ -970,6 +988,12 @@ def find_observations_intersecting_line_at_time(
)
).all()

self.search_logger.info(
"%d observation%s found",
n_matched_observations if approximate else len(observations),
"" if len(observations) == 1 else "s",
)

if approximate:
self.logger.debug(
"Tested %d segments, matched %d observations.",
Expand All @@ -978,7 +1002,7 @@ def find_observations_intersecting_line_at_time(
)
else:
self.logger.debug(
"Tested %d segments, matched %d observations, " "%d intersections.",
"Tested %d segments, matched %d observations, %d intersections.",
n_segment_queries,
n_matched_observations,
len(observations),
Expand Down
4 changes: 2 additions & 2 deletions sbsearch/target.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Licensed with the 3-clause BSD license. See LICENSE for details.

from abc import ABC, abstractmethod
from typing import Any, List, Optional, Set, Tuple, Type, TypeVar, Union
from typing import List, Optional, Set, Tuple, Type, TypeVar, Union

from sqlalchemy import desc

Expand Down Expand Up @@ -191,7 +191,7 @@ def from_radec(
def coordinates(self) -> SkyCoord:
"""Coordinates as a `astropy.coordinates.SkyCoord` object."""
return self._coords

@property
def ra(self) -> Angle:
"""Right ascension as an `astropy.coordinates.Angle` object."""
Expand Down
9 changes: 5 additions & 4 deletions sbsearch/test/test_sbsdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import sqlalchemy as sa
from sqlalchemy.exc import ProgrammingError
from sqlalchemy.sql import text

from . import fixture_db, Postgresql
from ..sbsdb import SBSDatabase
Expand All @@ -17,14 +18,14 @@ def test_init_session(self):
session = sessionmaker()
db = SBSDatabase(session)
db.create()
db.session.execute('SELECT * FROM designation').fetchall()
db.session.execute(text("SELECT * FROM designation")).fetchall()

def test_verify(self, db):
db.verify()

def test_verify_missing_table(self, db):
db.session.execute('DROP TABLE designation')
db.session.execute(text("DROP TABLE designation"))
with pytest.raises(ProgrammingError):
db.session.execute('SELECT * FROM designation').fetchall()
db.session.execute(text("SELECT * FROM designation")).fetchall()
db.verify()
db.session.execute('SELECT * FROM designation').fetchall()
db.session.execute(text("SELECT * FROM designation")).fetchall()
16 changes: 14 additions & 2 deletions sbsearch/test/test_sbsearch.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import pytest
import warnings
from copy import copy

import numpy as np
from astropy.table import Table
Expand Down Expand Up @@ -464,7 +463,9 @@ def test_find_observations_intersecting_line_with_padding(self, sbs, observation
with pytest.raises(ValueError):
sbs.find_observations_intersecting_line(ra, dec, a=padding, b=[1])

def test_find_observations_intersecting_line_at_time(self, sbs, observations):
def test_find_observations_intersecting_line_at_time(
self, sbs, observations, caplog
):
sbs.add_observations(observations)
sbs.source = "example_survey"

Expand All @@ -478,6 +479,17 @@ def test_find_observations_intersecting_line_at_time(self, sbs, observations):
assert len(found) == 1
assert found[0] == observations[0]

# check the search_log
expected_messages = [
(
"SBSearch (search)",
20,
"Searching 0.9 deg over 0.2 days",
),
("SBSearch (search)", 20, "2 observations found"),
]
assert any([record in expected_messages for record in caplog.record_tuples])

def test_find_observations_intersecting_line_at_time_errors(self, sbs):
sbs.source = "example_survey"
with pytest.raises(ValueError):
Expand Down
Loading

0 comments on commit 715c17f

Please sign in to comment.