Skip to content

Commit

Permalink
scrooge-generator: Java Api to Validate requests when servers receive…
Browse files Browse the repository at this point in the history
… them

Problem

For services defined in the IDL, we want to validate all
input parameters for all methods defined in the service.
If the parameter is of type struct, union, or exception type,
we want to leverage the validateInstanceValue API.
After we validated all parameters, we will throw a
ThriftValidationException for all validation violations
for the Java version

Solution

Modified Service.mustache template for Java generator
to throw the exception in ServiceIface definitions

JIRA Issues: CSL-11200

Differential Revision: https://phabricator.twitter.biz/D792676
  • Loading branch information
heligw authored and jenkins committed Dec 14, 2021
1 parent a981322 commit 6bc2a24
Show file tree
Hide file tree
Showing 7 changed files with 121 additions and 3 deletions.
Original file line number Diff line number Diff line change
@@ -1,4 +1,23 @@
package com.twitter.scrooge.thrift_validation
import scala.collection.JavaConverters._

object ThriftValidationException {

/**
* A java compatible exception which is used to communicate when a thrift_validation
* has failed with the respective [[ThriftValidationViolation]]
*
* @param endpoint thrift method the invalid request tries to reach
* @param requestClazz the type of request that was passed in
* @param validationViolations all violations collected while deserializing the thrift object
*/
def create(
endpoint: String,
requestClazz: Class[_],
validationViolations: java.util.Set[ThriftValidationViolation]
): ThriftValidationException =
ThriftValidationException(endpoint, requestClazz, validationViolations.asScala.toSet)
}

/**
* An exception which is used to communicate when a thrift_validation
Expand All @@ -12,7 +31,7 @@ final case class ThriftValidationException(
endpoint: String,
requestClazz: Class[_],
validationViolations: Set[ThriftValidationViolation])
extends Exception {
extends RuntimeException {

override def getMessage: String =
s" The validation for request ${requestClazz.getName} to endpoint $endpoint failed with messages: ${validationViolations
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -608,6 +608,13 @@ public Future<byte[]> apply(Throwable t) {
private final com.twitter.finagle.Service<doGreatThings_args, Response> methodService = new com.twitter.finagle.Service<doGreatThings_args, Response>() {
@Override
public Future<Response> apply(doGreatThings_args args) {
try {
Set<ThriftValidationViolation> requestViolations = Request.validateInstanceValue(args.request);
if (!requestViolations.isEmpty()) {
throw com.twitter.scrooge.thrift_validation.ThriftValidationException.create("doGreatThings", args.request.getClass(), requestViolations);
}
} catch(NullPointerException e) {
}
Future<Response> future = iface.doGreatThings(args.request);
return future;
}
Expand Down Expand Up @@ -697,6 +704,13 @@ public Future<byte[]> apply(Throwable t) {
private final com.twitter.finagle.Service<noExceptionCall_args, Response> methodService = new com.twitter.finagle.Service<noExceptionCall_args, Response>() {
@Override
public Future<Response> apply(noExceptionCall_args args) {
try {
Set<ThriftValidationViolation> requestViolations = Request.validateInstanceValue(args.request);
if (!requestViolations.isEmpty()) {
throw com.twitter.scrooge.thrift_validation.ThriftValidationException.create("noExceptionCall", args.request.getClass(), requestViolations);
}
} catch(NullPointerException e) {
}
Future<Response> future = iface.noExceptionCall(args.request);
return future;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -461,6 +461,13 @@ else if (t instanceof OverCapacityException) {
private final com.twitter.finagle.Service<moreCoolThings_args, Integer> methodService = new com.twitter.finagle.Service<moreCoolThings_args, Integer>() {
@Override
public Future<Integer> apply(moreCoolThings_args args) {
try {
Set<ThriftValidationViolation> requestViolations = Request.validateInstanceValue(args.request);
if (!requestViolations.isEmpty()) {
throw com.twitter.scrooge.thrift_validation.ThriftValidationException.create("moreCoolThings", args.request.getClass(), requestViolations);
}
} catch(NullPointerException e) {
}
Future<Integer> future = iface.moreCoolThings(args.request);
return future;
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,39 @@
package com.twitter.scrooge.java_generator

import apache_java_thrift._
import com.twitter.conversions.DurationOps.richDurationFromInt
import com.twitter.finagle.Address
import com.twitter.finagle.Name
import com.twitter.finagle.ThriftMux
import com.twitter.scrooge.testutil.Spec
import com.twitter.scrooge.thrift_validation.ThriftValidationViolation
import com.twitter.util.Await
import com.twitter.util.Awaitable
import com.twitter.util.Duration
import com.twitter.util.Future
import java.lang
import java.net.InetSocketAddress
import org.apache.thrift.TApplicationException
import scala.jdk.CollectionConverters._

class ValidationsJavaGeneratorSpec extends Spec {
def await[T](a: Awaitable[T], d: Duration = 5.seconds): T =
Await.result(a, d)

private class ValidationServiceImpl extends ValidationService.ServiceIface {
override def validate(
structRequest: ValidationStruct,
unionRequest: ValidationUnion,
exceptionRequest: ValidationException
): Future[lang.Boolean] = Future.value(true)

override def validateOption(
structRequest: ValidationStruct,
unionRequest: ValidationUnion,
exceptionRequest: ValidationException
): Future[lang.Boolean] = Future.value(true)
}

"Java validateInstanceValue" should {
"validate Struct" in {
val validationStruct =
Expand Down Expand Up @@ -87,6 +115,40 @@ class ValidationsJavaGeneratorSpec extends Spec {
val validationViolations = NonValidationStruct.validateInstanceValue(nonValidationStruct)
assertViolations(validationViolations.asScala.toSet, 0, Set.empty)
}

"validate struct, union and exception request" in {
val validationStruct = new ValidationStruct(
"email",
-1,
101,
0,
0,
Map("1" -> "1", "2" -> "2").asJava,
false,
"anything")
val impl = new ValidationServiceImpl()
val validationIntUnion = new ValidationUnion()
validationIntUnion.setIntField(-1)
val validationException = new ValidationException("")
val muxServer = ThriftMux.server.serveIface("localhost:*", impl)
val muxClient = ThriftMux.client.build[ValidationService.ServiceIface](
Name.bound(Address(muxServer.boundAddress.asInstanceOf[InetSocketAddress])),
"a_client")
intercept[TApplicationException] {
await(muxClient.validate(validationStruct, validationIntUnion, validationException))
}
}

"validate null request" in {
val impl = new ValidationServiceImpl()
val muxServer = ThriftMux.server.serveIface("localhost:*", impl)
val muxClient = ThriftMux.client.build[ValidationService.ServiceIface](
Name.bound(Address(muxServer.boundAddress.asInstanceOf[InetSocketAddress])),
"a_client")
//null values passed after code generation aren't checked so
// we catch NullPointerException in the mustache file
await(muxClient.validate(null, null, null))
}
}

private def assertViolations(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -610,6 +610,19 @@ public class {{name}} {
private final com.twitter.finagle.Service<{{name}}_args, {{{return_type.type_name_in_container}}}> methodService = new com.twitter.finagle.Service<{{name}}_args, {{{return_type.type_name_in_container}}}>() {
@Override
public Future<{{{return_type.type_name_in_container}}}> apply({{name}}_args args) {
{{#has_args}}
{{#fields}}
{{#field_type.is_struct}}
try {
Set<ThriftValidationViolation> {{field_name}}Violations = {{arg_type}}.validateInstanceValue({{field_arg}});
if (!{{field_name}}Violations.isEmpty()) {
throw com.twitter.scrooge.thrift_validation.ThriftValidationException.create("{{name}}", {{field_arg}}.getClass(), {{field_name}}Violations);
}
} catch(NullPointerException e) {
}
{{/field_type.is_struct}}
{{/fields}}
{{/has_args}}
Future<{{{return_type.type_name_in_container}}}> future = iface.{{name}}({{{argument_list_with_args}}});
return future;
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package com.twitter.scrooge.java_generator

import com.twitter.scrooge.ast.{Field, Identifier, Requiredness}
import com.twitter.scrooge.ast.Field
import com.twitter.scrooge.ast.Identifier
import com.twitter.scrooge.ast.Requiredness
import com.twitter.scrooge.backend.Generator
import com.google.common.base

Expand All @@ -15,6 +17,8 @@ class FieldController(f: Field, generator: ApacheJavaGenerator, ns: Option[Ident
val has_annotations: Boolean = f.fieldAnnotations.nonEmpty

val field_type: FieldTypeController = new FieldTypeController(f.fieldType, generator)
val field_arg: String = "args." + f.sid.name
val arg_type: String = generator.typeName(f.fieldType)

def getRequirement(field: Field): String = {
field.requiredness match {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ class FunctionController(
val argument_list_with_args: String = function.args map { a =>
"args." + a.sid.name
} mkString ", "

val has_args: Boolean = function.args.size > 0
val fields: Seq[FieldController] = function.args map { a =>
new FieldController(a, generator, ns)
Expand Down

0 comments on commit 6bc2a24

Please sign in to comment.