Skip to content

Commit

Permalink
Token Authorization fixes (#3192)
Browse files Browse the repository at this point in the history
* test

* token authorization integration

* env false regression cpu ci

* testing ci

* testing newman

* fix newman tests

* testing pytest

* testing cmd arg

* pytest fixes

* fixing tests

* doc update

* spell check

* fixing priority between config file and cmd

* test fixes

* removing unneeded files

* Delete unneeded files

* review fixes

* removing comments

* adding doc clarification and new test

* changes to docs

* adding new tests

* fixing merge conflict

* format fix

* format fixes

* addressing comments

* fixing merge conflict

* fixing merge conflict

* fixing merge conflict

* fix merge conflict

* doc update

* fixing format

* fix to benchmarks

---------

Co-authored-by: Ubuntu <[email protected]>
  • Loading branch information
udaij12 and Ubuntu authored Jun 13, 2024
1 parent 9336ad2 commit f918cd1
Show file tree
Hide file tree
Showing 9 changed files with 79 additions and 4,397 deletions.
4 changes: 2 additions & 2 deletions benchmarks/utils/system_under_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ def start(self):
click.secho("*Starting local Torchserve instance...", fg="green")

ts_cmd = (
f"torchserve --start --model-store {self.execution_params['tmp_dir']}/model_store --model-api-enabled --disable-token"
f"torchserve --start --model-store {self.execution_params['tmp_dir']}/model_store --model-api-enabled --disable-token "
f"--workflow-store {self.execution_params['tmp_dir']}/wf_store "
f"--ts-config {self.execution_params['tmp_dir']}/benchmark/conf/{self.execution_params['config_properties_name']} "
f" > {self.execution_params['tmp_dir']}/benchmark/logs/model_metrics.log"
Expand Down Expand Up @@ -195,7 +195,7 @@ def start(self):
f"docker run {self.execution_params['docker_runtime']} {backend_profiling} --name ts --user root -p "
f"127.0.0.1:{inference_port}:{inference_port} -p 127.0.0.1:{management_port}:{management_port} "
f"-v {self.execution_params['tmp_dir']}:/tmp {enable_gpu} -itd {docker_image} "
f'"torchserve --start --model-store /home/model-server/model-store --model-api-enabled --disable-token'
f'"torchserve --start --model-store /home/model-server/model-store --model-api-enabled --disable-token '
f"\--workflow-store /home/model-server/wf-store "
f"--ts-config /tmp/benchmark/conf/{self.execution_params['config_properties_name']} > "
f'/tmp/benchmark/logs/model_metrics.log"'
Expand Down
10 changes: 6 additions & 4 deletions docs/token_authorization_api.md
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
# TorchServe token authorization API

Torchserve now supports token authorization by default.
TorchServe now enforces token authorization by default


## How to set and disable Token Authorization
* Global environment variable: use `TS_DISABLE_TOKEN_AUTHORIZATION` and set to `true` to disable and `false` to enable token authorization. Note that `enable_envvars_config=true` must be set in config.properties for global environment variables to be used
* Command line: Command line can only be used to disable token authorization by adding the `--disable-token` flag.
* Config properties file: use `disable_token_authorization` and set to `true` to disable and `false` to enable token authorization.

Priority between env variables, cmd, and config file follows the following [TorchServer standard](https://github.com/pytorch/serve/blob/c74a29e8144bc12b84196775076b0e8cf3c5a6fc/docs/configuration.md#advanced-configuration)
Priority between env variables, cmd, and config file follows the following [TorchServer standard](https://github.com/pytorch/serve/blob/master/docs/configuration.md)

* Example 1:
* Config file: `disable_token_authorization=false`

Expand Down Expand Up @@ -48,7 +50,7 @@ Priority between env variables, cmd, and config file follows the following [Torc
2. Inference key: Used for inference APIs. Example:
`curl http://127.0.0.1:8080/predictions/densenet161 -T examples/image_classifier/kitten.jpg -H "Authorization: Bearer FINhR1fj"`
3. API key: Used for the token authorization API. Check section 4 for API use.
4. The plugin also includes an API in order to generate a new key to replace either the management or inference key.
4. API in order to generate a new key to replace either the management or inference key.
1. Management Example:
`curl localhost:8081/token?type=management -H "Authorization: Bearer m4M-5IBY"` will replace the current management key in the key_file with a new one and will update the expiration time.
2. Inference example:
Expand All @@ -61,4 +63,4 @@ Priority between env variables, cmd, and config file follows the following [Torc
## Notes
1. DO NOT MODIFY THE KEY FILE. Modifying the key file might impact reading and writing to the file thus preventing new keys from properly being displayed in the file.
2. Time to expiration is set to default at 60 minutes but can be changed in the config.properties by adding `token_expiration_min`. Ex:`token_expiration_min=30`
3. 3 tokens allow the owner with the most flexibility in use and enables them to adapt the tokens to their use. Owners of the server can provide users with the inference token if users should only be able to run inferences against models that have already been loaded. The owner can also provide owners with the management key if owners want users to add and remove models.
3. Three tokens allow the owner with the most flexibility in use and enables them to adapt the tokens to their use. Owners of the server can provide users with the inference token if users should only be able to run inferences against models that have already been loaded. The owner can also provide owners with the management key if owners want users to add and remove models.
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
import org.pytorch.serve.archive.model.ModelNotFoundException;
import org.pytorch.serve.grpcimpl.GRPCInterceptor;
import org.pytorch.serve.grpcimpl.GRPCServiceFactory;
import org.pytorch.serve.http.TokenAuthorizationHandler;
import org.pytorch.serve.http.messages.RegisterModelRequest;
import org.pytorch.serve.metrics.MetricCache;
import org.pytorch.serve.metrics.MetricManager;
Expand Down Expand Up @@ -86,7 +87,7 @@ public static void main(String[] args) {
ConfigManager.Arguments arguments = new ConfigManager.Arguments(cmd);
ConfigManager.init(arguments);
ConfigManager configManager = ConfigManager.getInstance();
configManager.setupToken();
TokenAuthorizationHandler.setupToken();
PluginsManager.getInstance().initialize();
MetricCache.init();
InternalLoggerFactory.setDefaultFactory(Slf4JLoggerFactory.INSTANCE);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,8 @@ public class TokenAuthorizationHandler extends HttpRequestHandlerChain {
private static final Logger logger = LoggerFactory.getLogger(TokenAuthorizationHandler.class);
private static TokenType tokenType;
private static Boolean tokenEnabled = false;
private static Token tokenClass;
private static Token token;
private static Object tokenObject;
private static Double timeToExpirationMinutes = 60.0;

/** Creates a new {@code InferenceRequestHandler} instance. */
public TokenAuthorizationHandler(TokenType type) {
Expand All @@ -62,11 +61,12 @@ public void handleRequest(
if (req.toString().contains("/token")) {
try {
checkTokenAuthorization(req, "token");
String resp = tokenClass.updateKeyFile(req);
String queryResponse = parseQuery(req);
String resp = token.updateKeyFile(queryResponse);
NettyUtils.sendJsonResponse(ctx, resp);
return;
} catch (Exception e) {
logger.error("TOKEN CLASS UPDATED UNSUCCESSFULLY");
logger.error("Key file updated unsuccessfully");
throw new InvalidKeyException(
"Token Authentication failed. Token either incorrect, expired, or not provided correctly");
}
Expand All @@ -76,48 +76,60 @@ public void handleRequest(
} else if (tokenType == TokenType.INFERENCE) {
checkTokenAuthorization(req, "inference");
}
} else {
if (tokenType == TokenType.MANAGEMENT && req.toString().contains("/token")) {
throw new ResourceNotFoundException();
}
}
chain.handleRequest(ctx, req, decoder, segments);
}

public static void setupTokenClass() {
try {
tokenClass = new Token();
Double time = ConfigManager.getInstance().getTimeToExpiration();
String home = ConfigManager.getInstance().getModelServerHome();
tokenClass.setFilePath(home);
if (time != 0.0) {
timeToExpirationMinutes = time;
}
tokenClass.setTime(timeToExpirationMinutes);
if (tokenClass.generateKeyFile("token")) {
logger.info("Token Authorization Enabled");
public static void setupToken() {
if (!ConfigManager.getInstance().getDisableTokenAuthorization()) {
try {
token = new Token();
if (token.generateKeyFile("token")) {
logger.info("Token Authorization Enabled");
}
} catch (IOException e) {
e.printStackTrace();
logger.error("Token Authorization setup unsuccessfully");
throw new IllegalStateException("Token Authorization setup unsuccessfully", e);
}
} catch (Exception e) {
e.printStackTrace();
logger.error("TOKEN CLASS IMPORTED UNSUCCESSFULLY");
throw new IllegalStateException("Unable to import token class", e);
tokenEnabled = true;
}
tokenEnabled = true;
}

private void checkTokenAuthorization(FullHttpRequest req, String type) throws ModelException {
String tokenBearer = req.headers().get("Authorization");
if (tokenBearer == null) {
throw new InvalidKeyException(
"Token Authentication failed. Token either incorrect, expired, or not provided correctly");
}
String[] arrOfStr = tokenBearer.split(" ", 2);
if (arrOfStr.length == 1) {
throw new InvalidKeyException(
"Token Authentication failed. Token either incorrect, expired, or not provided correctly");
}
String currToken = arrOfStr[1];

try {
boolean result = tokenClass.checkTokenAuthorization(req, type);
if (!result) {
throw new InvalidKeyException(
"Token Authentication failed. Token either incorrect, expired, or not provided correctly");
}
} catch (Exception e) {
boolean result = token.checkTokenAuthorization(currToken, type);
if (!result) {
throw new InvalidKeyException(
"Token Authentication failed. Token either incorrect, expired, or not provided correctly");
}
}

// parses query and either returns management/inference or a wrong type error
private String parseQuery(FullHttpRequest req) {
QueryStringDecoder decoder = new QueryStringDecoder(req.uri());
Map<String, List<String>> parameters = decoder.parameters();
List<String> values = parameters.get("type");
if (values != null && !values.isEmpty()) {
if ("management".equals(values.get(0)) || "inference".equals(values.get(0))) {
return values.get(0);
} else {
return "WRONG TYPE";
}
}
return "NO TYPE PROVIDED";
}
}

class Token {
Expand All @@ -126,14 +138,12 @@ class Token {
private static String inferenceKey;
private static Instant managementExpirationTimeMinutes;
private static Instant inferenceExpirationTimeMinutes;
private static Double timeToExpirationMinutes;
private SecureRandom secureRandom = new SecureRandom();
private Base64.Encoder baseEncoder = Base64.getUrlEncoder();
private String fileName = "key_file.json";
private String filePath = "";
private String filePath = ConfigManager.getInstance().getModelServerHome();

public String updateKeyFile(FullHttpRequest req) throws IOException {
String queryResponse = parseQuery(req);
public String updateKeyFile(String queryResponse) throws IOException {
String test = "";
if ("management".equals(queryResponse)) {
generateKeyFile("management");
Expand All @@ -145,36 +155,17 @@ public String updateKeyFile(FullHttpRequest req) throws IOException {
return test;
}

// parses query and either returns management/inference or a wrong type error
public String parseQuery(FullHttpRequest req) {
QueryStringDecoder decoder = new QueryStringDecoder(req.uri());
Map<String, List<String>> parameters = decoder.parameters();
List<String> values = parameters.get("type");
if (values != null && !values.isEmpty()) {
if ("management".equals(values.get(0)) || "inference".equals(values.get(0))) {
return values.get(0);
} else {
return "WRONG TYPE";
}
}
return "NO TYPE PROVIDED";
}

public String generateKey() {
byte[] randomBytes = new byte[6];
secureRandom.nextBytes(randomBytes);
return baseEncoder.encodeToString(randomBytes);
}

public Instant generateTokenExpiration() {
long secondsToAdd = (long) (timeToExpirationMinutes * 60);
long secondsToAdd = (long) (ConfigManager.getInstance().getTimeToExpiration() * 60);
return Instant.now().plusSeconds(secondsToAdd);
}

public void setFilePath(String path) {
filePath = path;
}

// generates a key file with new keys depending on the parameter provided
public boolean generateKeyFile(String type) throws IOException {
String userDirectory = filePath + "/" + fileName;
Expand Down Expand Up @@ -248,7 +239,7 @@ public boolean setFilePermissions() {
}

// checks the token provided in the http with the saved keys depening on parameters
public boolean checkTokenAuthorization(FullHttpRequest req, String type) {
public boolean checkTokenAuthorization(String token, String type) {
String key;
Instant expiration;
switch (type) {
Expand All @@ -265,16 +256,6 @@ public boolean checkTokenAuthorization(FullHttpRequest req, String type) {
expiration = inferenceExpirationTimeMinutes;
}

String tokenBearer = req.headers().get("Authorization");
if (tokenBearer == null) {
return false;
}
String[] arrOfStr = tokenBearer.split(" ", 2);
if (arrOfStr.length == 1) {
return false;
}
String token = arrOfStr[1];

if (token.equals(key)) {
if (expiration != null && isTokenExpired(expiration)) {
return false;
Expand All @@ -288,28 +269,4 @@ public boolean checkTokenAuthorization(FullHttpRequest req, String type) {
public boolean isTokenExpired(Instant expirationTime) {
return !(Instant.now().isBefore(expirationTime));
}

public String getManagementKey() {
return managementKey;
}

public String getInferenceKey() {
return inferenceKey;
}

public String getKey() {
return apiKey;
}

public Instant getInferenceExpirationTime() {
return inferenceExpirationTimeMinutes;
}

public Instant getManagementExpirationTime() {
return managementExpirationTimeMinutes;
}

public void setTime(Double time) {
timeToExpirationMinutes = time;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@
import org.apache.commons.cli.Options;
import org.apache.commons.io.IOUtils;
import org.pytorch.serve.archive.model.Manifest;
import org.pytorch.serve.http.TokenAuthorizationHandler;
import org.pytorch.serve.metrics.MetricBuilder;
import org.pytorch.serve.servingsdk.snapshot.SnapshotSerializer;
import org.pytorch.serve.snapshot.SnapshotSerializerFactory;
Expand Down Expand Up @@ -450,14 +449,6 @@ public boolean isOpenInferenceProtocol() {
return Boolean.parseBoolean(prop.getProperty(TS_OPEN_INFERENCE_PROTOCOL, "false"));
}

public boolean setupToken() {
boolean disable_token_authorization = getDisableTokenAuthorization();
if (!disable_token_authorization) {
TokenAuthorizationHandler.setupTokenClass();
}
return true;
}

public boolean isGRPCSSLEnabled() {
return Boolean.parseBoolean(getProperty(TS_ENABLE_GRPC_SSL, "false"));
}
Expand Down Expand Up @@ -1001,7 +992,7 @@ public Double getTimeToExpiration() {
logger.error("Token expiration not a valid integer");
}
}
return 0.0;
return 60.0;
}

public String getTsHeaderKeySequenceId() {
Expand Down
Loading

0 comments on commit f918cd1

Please sign in to comment.