Skip to content

Commit

Permalink
Merge pull request #194 from geo-engine/ml_pipeline
Browse files Browse the repository at this point in the history
ml model pipeline (wip)
  • Loading branch information
jdroenner authored Sep 9, 2024
2 parents 546e2af + 3c14f27 commit c889f51
Show file tree
Hide file tree
Showing 5 changed files with 313 additions and 3 deletions.
207 changes: 207 additions & 0 deletions examples/ml_pipeline.ipynb

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions geoengine/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from .layers import Layer, LayerCollection, LayerListing, LayerCollectionListing, \
LayerId, LayerCollectionId, LayerProviderId, \
layer_collection, layer
from .ml import register_ml_model, MlModelConfig, SerializableModel
from .permissions import add_permission, remove_permission, add_role, remove_role, assign_role, revoke_role, \
ADMIN_ROLE_ID, REGISTERED_USER_ROLE_ID, ANONYMOUS_USER_ROLE_ID, Permission, Resource, UserId, RoleId
from .tasks import Task, TaskId
Expand Down
58 changes: 58 additions & 0 deletions geoengine/ml.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
'''
Util functions for machine learning
'''

from pathlib import Path
import tempfile
from typing import Protocol
from dataclasses import dataclass
from geoengine_openapi_client.models import MlModelMetadata, MlModel
import geoengine_openapi_client
from geoengine.auth import get_session
from geoengine.datasets import UploadId


# pylint: disable=invalid-name
class SerializableModel(Protocol):
    '''
    Structural type for model objects that can be serialized to bytes.

    Any object that offers a `SerializeToString` method satisfies this protocol.
    '''

    def SerializeToString(self) -> bytes:
        '''Return the serialized form of the model as bytes.'''


@dataclass
class MlModelConfig:
    '''
    Settings describing an ml model that is about to be registered.

    Only `name` and `metadata` are required, all other fields have defaults.
    '''
    # Name under which the model is registered.
    name: str
    # Metadata object that is attached to the registered model.
    metadata: MlModelMetadata
    # File name used for the serialized model file.
    file_name: str = "model.onnx"
    # Display name of the model.
    display_name: str = "My Ml Model"
    # Description text of the model.
    description: str = "My Ml Model Description"


def register_ml_model(onnx_model: SerializableModel,
                      model_config: MlModelConfig,
                      upload_timeout: int = 3600,
                      register_timeout: int = 60):
    '''
    Serialize a model, upload the resulting file and register it as an ml model.

    The serialized bytes are written to `model_config.file_name` inside a
    temporary directory. That file is uploaded and the id of the upload is used,
    together with the remaining `model_config` fields, to add an `MlModel`.

    :param onnx_model: object whose `SerializeToString` provides the model bytes
    :param model_config: name, metadata, file name, display name and description
    :param upload_timeout: forwarded as `_request_timeout` of the upload request
    :param register_timeout: forwarded as `_request_timeout` of the registration request
    '''

    session = get_session()
    configuration = session.configuration

    with geoengine_openapi_client.ApiClient(configuration) as api_client:
        # The temporary directory is only needed until the upload request is done.
        with tempfile.TemporaryDirectory() as temp_dir:
            model_path = Path(temp_dir) / model_config.file_name

            with model_path.open('wb') as model_file:
                model_file.write(onnx_model.SerializeToString())

            upload_response = geoengine_openapi_client.UploadsApi(api_client).upload_handler(
                [str(model_path)],
                _request_timeout=upload_timeout
            )

        upload_id = UploadId.from_response(upload_response)
        ml_api = geoengine_openapi_client.MLApi(api_client)

        ml_model = MlModel(
            name=model_config.name,
            upload=str(upload_id),
            metadata=model_config.metadata,
            display_name=model_config.display_name,
            description=model_config.description
        )
        ml_api.add_ml_model(ml_model, _request_timeout=register_timeout)
47 changes: 45 additions & 2 deletions geoengine/workflow_builder/operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -728,7 +728,7 @@ def __init__(self,
self.map_no_data = map_no_data

def name(self) -> str:
return 'Expression'
return 'BandwiseExpression'

def to_dict(self) -> Dict[str, Any]:
params = {
Expand Down Expand Up @@ -1157,7 +1157,7 @@ def __init__(self,
source: RasterOperator,
aggregate: BandNeighborhoodAggregateParams
):
'''Creates a new RasterStacker operator.'''
'''Creates a new BandNeighborhoodAggregate operator.'''
self.source = source
self.aggregate = aggregate

Expand Down Expand Up @@ -1250,3 +1250,46 @@ def to_dict(self) -> Dict[str, Any]:
"type": "average",
"windowSize": self.window_size
}


class Onnx(RasterOperator):
    '''Workflow operator of type `Onnx` with one raster source and a model parameter.'''

    source: RasterOperator
    model: str

    def __init__(self,
                 source: RasterOperator,
                 model: str
                 ):
        '''
        Creates a new Onnx operator.

        :param source: raster operator that is serialized as the `raster` source
        :param model: value that is serialized as the `model` parameter
        '''
        self.source = source
        self.model = model

    def name(self) -> str:
        '''Returns the operator type name.'''
        return 'Onnx'

    def to_dict(self) -> Dict[str, Any]:
        '''Returns the operator, including its source, as a nested dictionary.'''
        operator_type = self.name()
        params = {"model": self.model}
        sources = {"raster": self.source.to_dict()}

        return {"type": operator_type, "params": params, "sources": sources}

    @classmethod
    def from_operator_dict(cls, operator_dict: Dict[str, Any]) -> 'Onnx':
        '''
        Builds an Onnx operator from its dictionary form.

        Raises a `ValueError` if the `type` entry is not `Onnx`.
        '''
        operator_type = operator_dict["type"]
        if operator_type != "Onnx":
            raise ValueError("Invalid operator type")

        raster_source = RasterOperator.from_operator_dict(operator_dict["sources"]["raster"])

        return Onnx(source=raster_source, model=operator_dict["params"]["model"])
3 changes: 2 additions & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ package_dir =
packages = find:
python_requires = >=3.8
install_requires =
geoengine-openapi-client == 0.0.11
geoengine-openapi-client == 0.0.12
geopandas >=0.9,<0.15
matplotlib >=3.5,<3.8
numpy >=1.21,<2
Expand All @@ -35,6 +35,7 @@ install_requires =
xarray >=0.19,<2024.3
urllib3 >= 2.0, < 2.3
pydantic >= 1.10.5, < 2
skl2onnx >=1.17,<2

[options.extras_require]
dev =
Expand Down

0 comments on commit c889f51

Please sign in to comment.