Skip to content

Commit

Permalink
Support various java gateways
Browse files Browse the repository at this point in the history
  • Loading branch information
scorebot committed Sep 14, 2024
1 parent 81c9244 commit 425ad62
Show file tree
Hide file tree
Showing 16 changed files with 290 additions and 102 deletions.
6 changes: 4 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@ _PyPMML_ is a Python PMML scoring library, it really is the Python API for [PMML

## Dependencies
- [Py4J](https://www.py4j.org/)

- or
- [JPype](https://www.jpype.org/)

## Installation

```bash
Expand Down Expand Up @@ -126,7 +128,7 @@ See the [PyPMML-Spark](https://github.com/autodeployai/pypmml-spark) project. _P
See the [AI-Serving](https://github.com/autodeployai/ai-serving) project. _AI-Serving_ is serving AI/ML models in the open standard formats PMML and ONNX with both HTTP (REST API) and gRPC endpoints.

## Deploy and Manage AI/ML models at scale
See the [DaaS](https://www.autodeployai.com/) system that deploys AI/ML models in production at scale on Kubernetes.
See the [DaaS](https://www.autodeploy.ai/) system that deploys AI/ML models in production at scale on Kubernetes.

## Support
If you have any questions about the _PyPMML_ library, please open issues on this repository.
Expand Down
7 changes: 4 additions & 3 deletions pypmml/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#
# Copyright (c) 2017-2019 AutoDeploy AI
# Copyright (c) 2017-2024 AutoDeployAI
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -19,7 +19,8 @@
"""

from pypmml.model import Model
from pypmml.base import PmmlError
from pypmml.jvm import PMMLError
from pypmml.base import PMMLContext
from pypmml.version import __version__

__all__ = ['Model', 'PmmlError']
__all__ = ['Model', 'PMMLError', 'PMMLContext']
115 changes: 50 additions & 65 deletions pypmml/base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#
# Copyright (c) 2017-2019 AutoDeploy AI
# Copyright (c) 2017-2024 AutoDeployAI
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -14,51 +14,32 @@
# limitations under the License.
#

import os
from os import path
from threading import RLock

from py4j.java_collections import JavaArray
from py4j.java_gateway import JavaObject, JavaGateway, launch_gateway, GatewayParameters


def _java2py(r):
if isinstance(r, JavaArray):
return [_java2py(x) for x in r]
elif isinstance(r, JavaObject):
cls_name = r.getClass().getName()
if cls_name == 'scala.Some':
return r.get()
elif cls_name == 'scala.None$':
return None
elif cls_name == 'scala.Enumeration$Val':
return r.toString()
return r


def call_java_func(func, *args):
""" Call Java Function """
return _java2py(func(*args))

from .jvm import JVMGateway

class PMMLContext(object):
_gateway = None
_jvm = None
_gateway: JVMGateway = None
_active_pmml_context = None
_lock = RLock()

def __init__(self, gateway=None):
PMMLContext._ensure_initialized(self, gateway=gateway)
def __init__(self, gateway_instance=None, gateway="py4j", java_opts=None, java_path=None):
PMMLContext._ensure_initialized(
self,
gateway_instance=gateway_instance,
gateway=gateway,
java_opts=java_opts,
java_path=java_path)

@classmethod
def _ensure_initialized(cls, instance, gateway=None):
def _ensure_initialized(cls, instance, gateway_instance=None, gateway="py4j", java_opts=None, java_path=None):
"""
Checks whether a Gateway of py4j is initialized or not.
Checks whether a Gateway of JVM is initialized or not.
"""
with PMMLContext._lock:
if not PMMLContext._gateway:
PMMLContext._gateway = gateway or cls.launch_gateway()
PMMLContext._jvm = PMMLContext._gateway.jvm
PMMLContext._gateway = gateway_instance or cls.launch_gateway(
gateway=gateway, java_opts=java_opts, java_path=java_path
)

if instance:
if PMMLContext._active_pmml_context and PMMLContext._active_pmml_context != instance:
Expand All @@ -68,42 +49,38 @@ def _ensure_initialized(cls, instance, gateway=None):
PMMLContext._active_pmml_context = instance

@classmethod
def getOrCreate(cls):
def getOrCreate(cls, gateway="py4j", java_opts=None, java_path=None) -> 'PMMLContext':
"""
Get or instantiate a PMMLContext and register it as a singleton object.
:param java_opts: an array of extra options to pass to Java (the classpath
should be specified using the `classpath` parameter, not `java_opts`.)
:param java_path: If None, JVM will use $JAVA_HOME/bin/java if $JAVA_HOME
is defined, otherwise it will use "java".
:param gateway: JVM gateway engine, support one of ["py4j", "jpype"]
"""
with PMMLContext._lock:
if PMMLContext._active_pmml_context is None:
PMMLContext()
PMMLContext(gateway=gateway, java_opts=java_opts, java_path=java_path)
return PMMLContext._active_pmml_context

@classmethod
def launch_gateway(cls, javaopts=[], java_path=None):
def launch_gateway(cls, gateway="py4j", java_opts=None, java_path=None) -> 'JVMGateway':
"""Launch a `Gateway` in a new Java process.
:param javaopts: an array of extra options to pass to Java (the classpath
should be specified using the `classpath` parameter, not `javaopts`.)
:param java_path: If None, Py4J will use $JAVA_HOME/bin/java if $JAVA_HOME
:param gateway: JVM gateway engine, support one of ["py4j", "jpype"]
:param java_opts: an array of extra options to pass to Java (the classpath
should be specified using the `classpath` parameter, not `java_opts`.)
:param java_path: If None, JVM will use $JAVA_HOME/bin/java if $JAVA_HOME
is defined, otherwise it will use "java".
:return: An object of `Gateway`
"""
jars_dir = os.environ["PYPMML_JARS_DIR"] if "PYPMML_JARS_DIR" in os.environ else \
path.join(path.dirname(path.abspath(__file__)), 'jars')
launch_classpath = path.join(jars_dir, "*")

if not javaopts:
java_opts = os.environ.get("JAVA_OPTS")
if java_opts:
javaopts = java_opts.split()

# Fix IllegalAccessError: cannot access class jdk.internal.math.FloatingDecimal
javaopts.append("--add-exports")
javaopts.append("java.base/jdk.internal.math=ALL-UNNAMED")

_port = launch_gateway(classpath=launch_classpath, javaopts=javaopts, java_path=java_path, die_on_exit=True)
gateway = JavaGateway(
gateway_parameters=GatewayParameters(port=_port,
auto_convert=True))
return gateway
if isinstance(gateway, str) and gateway.lower() == "jpype":
from .jvm import JPypeGateway
jvm_gateway = JPypeGateway()
else:
from .jvm import Py4jGateway
jvm_gateway = Py4jGateway()
jvm_gateway.launch_gateway(java_opts=java_opts, java_path=java_path)
return jvm_gateway

@classmethod
def shutdown(cls):
Expand All @@ -114,13 +91,21 @@ def shutdown(cls):
if PMMLContext._gateway:
PMMLContext._gateway.shutdown()
PMMLContext._gateway = None
PMMLContext._jvm = None
PMMLContext._active_pmml_context = None

def call_java_static_func(self, class_name, func_name, *args):
return self._gateway.call_java_static_func(class_name, func_name, *args)

def call_java_func(self, func, *args):
return self._gateway.call_java_func(func, *args)

class PmmlError(Exception):
"""Base exception of PyPMML"""
def detach(self, java_model):
if self._gateway:
self._gateway.detach(java_model)

@classmethod
def gateway(cls):
return cls._gateway.name() if cls._gateway is not None else None

class JavaModelWrapper(object):
"""
Expand All @@ -131,11 +116,11 @@ def __init__(self, java_model):
self._java_model = java_model

def __del__(self):
if self._pc._gateway:
self._pc._gateway.detach(self._java_model)
if self._pc:
self._pc.detach(self._java_model)

def call(self, name, *a):
return call_java_func(getattr(self._java_model, name), *a)
def call(self, name, *args):
return self._pc.call_java_func(getattr(self._java_model, name), *args)

def __str__(self):
return self.call('toString')
Expand Down
2 changes: 1 addition & 1 deletion pypmml/elements.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#
# Copyright (c) 2017-2019 AutoDeploy AI
# Copyright (c) 2017-2024 AutoDeployAI
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
Binary file removed pypmml/jars/pmml4s_2.13-1.0.2.jar
Binary file not shown.
Binary file added pypmml/jars/pmml4s_2.13-1.0.3.jar
Binary file not shown.
Binary file not shown.
Binary file removed pypmml/jars/scala-parser-combinators_2.13-1.1.2.jar
Binary file not shown.
Binary file added pypmml/jars/scala-reflect-2.13.14.jar
Binary file not shown.
Binary file removed pypmml/jars/scala-reflect-2.13.8.jar
Binary file not shown.
Loading

0 comments on commit 425ad62

Please sign in to comment.