Skip to content

Commit

Permalink
refactor(kserve): make the torchserve configuration loading more resi…
Browse files Browse the repository at this point in the history
…liant (#2995)

The many defaults defined cannot currently be used as if the matching
keys are not present in the properties file, it's loading will fail.

This patch fixes that and also ignores lines starting with # as they
should be.
  • Loading branch information
sgaist authored Mar 6, 2024
1 parent 2e26323 commit 14e8d6f
Showing 1 changed file with 21 additions and 31 deletions.
52 changes: 21 additions & 31 deletions kubernetes/kserve/kserve_wrapper/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,52 +28,42 @@ def parse_config():
model_store: the path in which the .mar file resides
"""
separator = "="
keys = {}
ts_configuration = {}
config_path = os.environ.get("CONFIG_PATH", DEFAULT_CONFIG_PATH)

logging.info(f"Wrapper: loading configuration from {config_path}")

with open(config_path) as f:
for line in f:
if separator in line:
# Find the name and value by splitting the string
name, value = line.split(separator, 1)

# Assign key value pair to dict
# strip() removes white space from the ends of strings
keys[name.strip()] = value.strip()

keys["model_snapshot"] = json.loads(keys["model_snapshot"])
inference_address, management_address, grpc_inference_port, model_store = (
keys["inference_address"],
keys["management_address"],
keys["grpc_inference_port"],
keys["model_store"],
if not line.startswith("#"):
if separator in line:
name, value = line.split(separator, 1)
ts_configuration[name.strip()] = value.strip()

ts_configuration["model_snapshot"] = json.loads(
ts_configuration.get("model_snapshot", "{}")
)

models = keys["model_snapshot"]["models"]
model_names = []
inference_address = ts_configuration.get(
"inference_address", DEFAULT_INFERENCE_ADDRESS
)
management_address = ts_configuration.get(
"management_address", DEFAULT_MANAGEMENT_ADDRESS
)
grpc_inference_port = ts_configuration.get(
"grpc_inference_port", DEFAULT_GRPC_INFERENCE_PORT
)
model_store = ts_configuration.get("model_store", DEFAULT_MODEL_STORE)

# Get all the model_names
for model, value in models.items():
model_names.append(model)
model_names = ts_configuration["model_snapshot"].get("models", {}).keys()

if not inference_address:
inference_address = DEFAULT_INFERENCE_ADDRESS
if not model_names:
model_names = [DEFAULT_MODEL_NAME]
if not inference_address:
inference_address = DEFAULT_INFERENCE_ADDRESS
if not management_address:
management_address = DEFAULT_MANAGEMENT_ADDRESS

inf_splits = inference_address.split(":")
if not grpc_inference_port:
grpc_inference_address = inf_splits[1] + ":" + DEFAULT_GRPC_INFERENCE_PORT
else:
grpc_inference_address = inf_splits[1] + ":" + grpc_inference_port
grpc_inference_address = inf_splits[1] + ":" + grpc_inference_port
grpc_inference_address = grpc_inference_address.replace("/", "")
if not model_store:
model_store = DEFAULT_MODEL_STORE

logging.info(
"Wrapper : Model names %s, inference address %s, management address %s, grpc_inference_address, %s, model store %s",
Expand Down

0 comments on commit 14e8d6f

Please sign in to comment.