Skip to content

Commit

Permalink
Merge pull request #197 from geo-engine/resultdescriptor-with-time
Browse files Browse the repository at this point in the history
Resultdescriptor-with-time
  • Loading branch information
ChristianBeilschmidt authored Oct 28, 2024
2 parents 66cb76b + d9627e9 commit 2f83bc6
Show file tree
Hide file tree
Showing 5 changed files with 48 additions and 42 deletions.
5 changes: 3 additions & 2 deletions geoengine/raster.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
'''Raster data types'''
from __future__ import annotations
import json
from typing import AsyncIterator, List, Literal, Optional, Tuple, Union, cast
import numpy as np
import pyarrow as pa
Expand Down Expand Up @@ -227,7 +226,9 @@ def from_ge_record_batch(record_batch: pa.RecordBatch) -> RasterTile2D:
# We know from the backend that there is only one array a.k.a. one column
arrow_array = record_batch.column(0)

time = gety.TimeInterval.from_response(json.loads(metadata[b'time']))
inner_time = geoengine_openapi_client.TimeInterval.from_json(metadata[b'time'])
assert inner_time is not None, "Failed to parse time"
time = gety.TimeInterval.from_response(inner_time)

band = int(metadata[b'band'])

Expand Down
60 changes: 31 additions & 29 deletions geoengine/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ def __init__(self,
raise InputException("Time inverval: Start must be <= End")

def is_instant(self) -> bool:
return self.end is None
return self.end is None or self.start == self.end

@property
def time_str(self) -> str:
Expand All @@ -184,36 +184,41 @@ def time_str(self) -> str:
return start_iso + '/' + end_iso

@staticmethod
def from_response(response: Any) -> TimeInterval:
def from_response(response: geoengine_openapi_client.models.TimeInterval) -> TimeInterval:
'''create a `TimeInterval` from an API response'''

if 'start' not in response:
if response.start is None:
raise TypeException('TimeInterval must have a start')

if isinstance(response['start'], int):
start = cast(int, response['start'])
end = cast(int, response['end']) if 'end' in response and response['end'] is not None else None
start = cast(int, response.start)
end = None
if response.end is not None:
end = cast(int, response.end)

return TimeInterval(
np.datetime64(start, 'ms'),
np.datetime64(end, 'ms') if end is not None else None,
)

start_str = cast(str, response['start'])
end_str = cast(str, response['end']) if 'end' in response and response['end'] is not None else None
if start == end:
end = None

return TimeInterval(
datetime.fromisoformat(start_str),
datetime.fromisoformat(end_str) if end_str is not None else None,
np.datetime64(start, 'ms'),
np.datetime64(end, 'ms') if end is not None else None,
)

def __repr__(self) -> str:
return f"TimeInterval(start={self.start}, end={self.end})"

def to_api_dict(self) -> geoengine_openapi_client.TimeInterval:
'''create a openapi `TimeInterval` from self'''
start = self.start.astype('datetime64[ms]').astype(int)
end = self.end.astype('datetime64[ms]').astype(int) if self.end is not None else None

# The openapi Timeinterval does not accept end: None. So we set it to start IF self is an instant.
end = end if end is not None else start

print(self, start, end)

return geoengine_openapi_client.TimeInterval(
start=int(self.start.astype('datetime64[ms]').astype(int)),
end=int(self.end.astype('datetime64[ms]').astype(int)) if self.end is not None else None,
start=int(start),
end=int(end)
)

@staticmethod
Expand Down Expand Up @@ -523,9 +528,8 @@ def from_response_vector(
columns = {name: VectorColumnInfo.from_response(info) for name, info in response.columns.items()}

time_bounds = None
# FIXME: datetime can not represent our min max range
# if 'time' in response and response['time'] is not None:
# time_bounds = TimeInterval.from_response(response['time'])
if response.time is not None:
time_bounds = TimeInterval.from_response(response.time)
spatial_bounds = None
if response.bbox is not None:
spatial_bounds = BoundingBox2D.from_response(response.bbox)
Expand Down Expand Up @@ -580,7 +584,7 @@ def to_api_dict(self) -> geoengine_openapi_client.TypedResultDescriptor:
data_type=self.data_type.to_api_enum(),
spatial_reference=self.spatial_reference,
columns={name: column_info.to_api_dict() for name, column_info in self.columns.items()},
time=self.time_bounds.time_str if self.time_bounds is not None else None,
time=self.time_bounds.to_api_dict() if self.time_bounds is not None else None,
bbox=self.spatial_bounds.to_api_dict() if self.spatial_bounds is not None else None,
resolution=self.spatial_resolution.to_api_dict() if self.spatial_resolution is not None else None,
))
Expand Down Expand Up @@ -687,7 +691,7 @@ def to_api_dict(self) -> geoengine_openapi_client.TypedResultDescriptor:
data_type=self.data_type,
bands=[band.to_api_dict() for band in self.__bands],
spatial_reference=self.spatial_reference,
time=self.time_bounds.time_str if self.time_bounds is not None else None,
time=self.time_bounds.to_api_dict() if self.time_bounds is not None else None,
bbox=self.spatial_bounds.to_api_dict() if self.spatial_bounds is not None else None,
resolution=self.spatial_resolution.to_api_dict() if self.spatial_resolution is not None else None
))
Expand All @@ -701,9 +705,8 @@ def from_response_raster(
bands = [RasterBandDescriptor.from_response(band) for band in response.bands]

time_bounds = None
# FIXME: datetime can not represent our min max range
# if 'time' in response and response['time'] is not None:
# time_bounds = TimeInterval.from_response(response['time'])
if response.time is not None:
time_bounds = TimeInterval.from_response(response.time)
spatial_bounds = None
if response.bbox is not None:
spatial_bounds = SpatialPartition2D.from_response(response.bbox)
Expand Down Expand Up @@ -784,9 +787,8 @@ def from_response_plot(response: geoengine_openapi_client.TypedPlotResultDescrip
spatial_ref = response.spatial_reference

time_bounds = None
# FIXME: datetime can not represent our min max range
# if 'time' in response and response['time'] is not None:
# time_bounds = TimeInterval.from_response(response['time'])
if response.time is not None:
time_bounds = TimeInterval.from_response(response.time)
spatial_bounds = None
if response.bbox is not None:
spatial_bounds = BoundingBox2D.from_response(response.bbox)
Expand Down Expand Up @@ -817,7 +819,7 @@ def to_api_dict(self) -> geoengine_openapi_client.TypedResultDescriptor:
type='plot',
spatial_reference=self.spatial_reference,
data_type='Plot',
time=self.time_bounds.time_str if self.time_bounds is not None else None,
time=self.time_bounds.to_api_dict() if self.time_bounds is not None else None,
bbox=self.spatial_bounds.to_api_dict() if self.spatial_bounds is not None else None
))

Expand Down
4 changes: 4 additions & 0 deletions geoengine/workflow_builder/operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -498,6 +498,9 @@ def offset_scale_dict(key_or_value: Optional[Union[float, str]]) -> Dict[str, An
if isinstance(key_or_value, float):
return {"type": "constant", "value": key_or_value}

if isinstance(key_or_value, int):
return {"type": "constant", "value": float(key_or_value)}

# TODO: incorporate `domain` field
return {"type": "metadataKey", "key": key_or_value}

Expand Down Expand Up @@ -800,6 +803,7 @@ def name(self) -> str:
return 'VectorExpression'

def to_dict(self) -> Dict[str, Any]:
output_column_dict = None
if isinstance(self.output_column, GeoVectorDataType):
output_column_dict = {
"type": "geometry",
Expand Down
11 changes: 5 additions & 6 deletions tests/test_raster.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ class RasterTests(unittest.TestCase):
test_data: ge.RasterTile2D

def setUp(self) -> None:
time = np.datetime64(datetime(2020, 1, 1, 0, 0, 0, 0), 'ms')
time = ge.TimeInterval(start=datetime(2020, 1, 1, 0, 0, 0, 0))
raster_data = rio.open("tests/responses/ndvi.tiff")
bounds = raster_data.bounds
no_data_value = 255 # raster_data.nodata is 0 which is propably wrong)
Expand All @@ -32,9 +32,7 @@ def setUp(self) -> None:
y_pixel_size=-22.5,
),
crs="EPSG:4326",
time=ge.TimeInterval(
start=time
),
time=time,
band=0,
)

Expand Down Expand Up @@ -96,7 +94,7 @@ def test_to_xarray(self) -> None:
self.assertEqual(origin_y, self.test_data.geo_transform.y_max + self.test_data.geo_transform.y_pixel_size / 2)

def test_from_ge_record_batch(self) -> None:
time = datetime(2020, 1, 1, 0, 0, 0, 0).isoformat()
time = np.datetime64(datetime(2020, 1, 1, 0, 0, 0, 0), 'ms').astype(np.int64)
raster_data = rio.open("tests/responses/ndvi.tiff")
bounds = raster_data.bounds
no_data_value = 255 # raster_data.nodata is 0 which is propably wrong)
Expand All @@ -118,7 +116,8 @@ def test_from_ge_record_batch(self) -> None:
"ySize": str(raster_data.shape[0]),
"spatialReference": "EPSG:4326",
"time": json.dumps({
"start": time
"start": int(time),
"end": int(time)
}),
"band": "0",
}
Expand Down
10 changes: 5 additions & 5 deletions tests/test_workflow_raster_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,9 @@ def __init__(self):

self.__tiles = []

for time in ["2014-01-01T00:00:00", "2014-01-02T00:00:00"]:
for time in [datetime(2014, 1, 1, 0, 0, 0), datetime(2014, 1, 2, 0, 0, 0)]:
for tiles in read_data():
self.__tiles.append(arrow_bytes(tiles, time, 0))
self.__tiles.append(arrow_bytes(tiles, ge.TimeInterval(start=time), 0))

async def __aenter__(self):
return self
Expand Down Expand Up @@ -62,7 +62,7 @@ def read_data() -> List[xr.DataArray]:
return parts


def arrow_bytes(data: xr.DataArray, time: str, band: int) -> bytes:
def arrow_bytes(data: xr.DataArray, time: ge.TimeInterval, band: int) -> bytes:
'''Convert a xarray.DataArray into an Arrow record batch within an IPC file'''

array = pa.array(data.to_numpy().reshape(-1))
Expand All @@ -80,8 +80,8 @@ def arrow_bytes(data: xr.DataArray, time: str, band: int) -> bytes:
"ySize": "4",
"spatialReference": "EPSG:4326",
"time": json.dumps({
"start": time,
"end": time,
"start": int(time.start.astype('datetime64[ms]').astype(int)),
"end": int(time.start.astype('datetime64[ms]').astype(int))
}),
"band": str(band),
})
Expand Down

0 comments on commit 2f83bc6

Please sign in to comment.