Skip to content

Commit

Permalink
WSRequest: Normalize URL
Browse files Browse the repository at this point in the history
  • Loading branch information
htmldoug committed Jun 3, 2019
1 parent d42f158 commit dbcc1c1
Show file tree
Hide file tree
Showing 6 changed files with 284 additions and 22 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
/*
* Copyright (C) 2009-2019 Lightbend Inc. <https://www.lightbend.com>
*/

package play.api.libs.ws.ahc

import java.util.concurrent.TimeUnit

import akka.stream.Materializer
import org.openjdk.jmh.annotations._
import org.openjdk.jmh.infra.Blackhole

/**
* ==Quick Run from sbt==
*
* > bench/jmh:run .*StandaloneAhcWSRequestBench
*
* ==Using Oracle Flight Recorder==
*
* To record a Flight Recorder file from a JMH run, run it using the jmh.extras.JFR profiler:
* > bench/jmh:run -prof jmh.extras.JFR .*StandaloneAhcWSRequestBench
*
* Compare your results before/after on your machine. Don't trust the ones in scaladoc.
*
* Sample benchmark results:
* {{{
* > bench/jmh:run .*StandaloneAhcWSRequestBench
* [info] Benchmark Mode Cnt Score Error Units
* [info] StandaloneAhcWSRequestBench.urlNoParams avgt 5 326.443 ± 3.712 ns/op
* [info] StandaloneAhcWSRequestBench.urlWithParams avgt 5 1562.871 ± 16.736 ns/op
* }}}
*
* @see https://github.com/ktoso/sbt-jmh
*/
@OutputTimeUnit(TimeUnit.NANOSECONDS)
@BenchmarkMode(Array(Mode.AverageTime))
@Fork(jvmArgsAppend = Array("-Xmx350m", "-XX:+HeapDumpOnOutOfMemoryError"), value = 1)
@State(Scope.Benchmark)
class StandaloneAhcWSRequestBench {

private implicit val materializer: Materializer = null // we're not actually going to execute anything.
private val wsClient = StandaloneAhcWSClient()

@Benchmark
def urlNoParams(bh: Blackhole): Unit = {
bh.consume(wsClient.url("https://www.example.com/foo/bar/a/b"))
}

@Benchmark
def urlWithParams(bh: Blackhole): Unit = {
bh.consume(wsClient.url("https://www.example.com?foo=bar& = "))
}

@TearDown
def teardown(): Unit = wsClient.close()
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@

package play.api.libs.ws.ahc

import java.net.URLDecoder
import java.util.Collections

import akka.Done
import javax.inject.Inject
import akka.stream.Materializer
Expand All @@ -22,13 +25,15 @@ import play.shaded.ahc.org.asynchttpclient.{ Response => AHCResponse }
import play.shaded.ahc.org.asynchttpclient._
import java.util.function.{ Function => JFunction }

import play.shaded.ahc.org.asynchttpclient.util.UriEncoder

import scala.collection.immutable.TreeMap
import scala.compat.java8.FunctionConverters._
import scala.concurrent.Await
import scala.concurrent.Future
import scala.concurrent.Promise
import scala.util.Failure
import scala.util.Success
import scala.util.control.NonFatal
import scala.util.{ Failure, Success, Try }

/**
* A WS client backed by an AsyncHttpClient.
Expand All @@ -53,8 +58,7 @@ class StandaloneAhcWSClient @Inject() (asyncHttpClient: AsyncHttpClient)(
}

def url(url: String): StandaloneWSRequest = {
validate(url)
StandaloneAhcWSRequest(
val req = StandaloneAhcWSRequest(
client = this,
url = url,
method = "GET",
Expand All @@ -70,6 +74,8 @@ class StandaloneAhcWSClient @Inject() (asyncHttpClient: AsyncHttpClient)(
proxyServer = None,
disableUrlEncoding = None
)

StandaloneAhcWSClient.normalize(req)
}

private[ahc] def execute(
Expand All @@ -91,18 +97,6 @@ class StandaloneAhcWSClient @Inject() (asyncHttpClient: AsyncHttpClient)(
result.future
}

private def validate(url: String): Unit = {
// Recover from https://github.com/AsyncHttpClient/async-http-client/issues/1149
try {
Uri.create(url)
} catch {
case iae: IllegalArgumentException =>
throw new IllegalArgumentException(s"Invalid URL $url", iae)
case npe: NullPointerException =>
throw new IllegalArgumentException(s"Invalid URL $url", npe)
}
}

private[ahc] def executeStream(request: Request): Future[StreamedResponse] = {
val streamStarted = Promise[StreamedResponse]()
val streamCompletion = Promise[Done]()
Expand Down Expand Up @@ -224,4 +218,132 @@ object StandaloneAhcWSClient {
new SystemConfiguration(loggerFactory).configure(config.wsClientConfig.ssl)
wsClient
}

/**
* Ensures:
*
* 1. [[StandaloneWSRequest.url]] path is encoded, e.g.
* ws.url("http://example.com/foo bar") ->
* ws.url("http://example.com/foo%20bar")
*
* 2. Any query params present in the URL are moved to [[StandaloneWSRequest.queryString]], e.g.
* ws.url("http://example.com/?foo=bar") ->
* ws.url("http://example.com/").withQueryString("foo" -> "bar")
*/
@throws[IllegalArgumentException]("if the url is unrepairable")
private[ahc] def normalize(req: StandaloneAhcWSRequest): StandaloneWSRequest = {
import play.shaded.ahc.org.asynchttpclient.util.MiscUtils.isEmpty
if (req.url.indexOf('?') != -1) {
// Query params in the path. Move them to the queryParams: Map.
repair(req)
} else {
Try(req.uri) match {
case Success(uri) =>

/*
* [[Uri.create()]] throws if the host or scheme is missing.
* We can do those checks against the the [[java.net.URI]]
* to avoid incurring the cost of re-parsing the URL string.
*
* @see https://github.com/AsyncHttpClient/async-http-client/issues/1149
*/
if (isEmpty(uri.getScheme)) {
throw new IllegalArgumentException(req.url + " could not be parsed into a proper Uri, missing scheme")
}
if (isEmpty(uri.getHost)) {
throw new IllegalArgumentException(req.url + " could not be parsed into a proper Uri, missing host")
}

req
case Failure(_) =>
// URI parsing error. Sometimes recoverable by UriEncoder.FIXING
repair(req)
}
}
}

/**
* Encodes the URI to [[Uri]] and runs it through the same [[UriEncoder.FIXING]]
* that async-http-client uses before executing it.
*/
@throws[IllegalArgumentException]("if the url is unrepairable")
private def repair(req: StandaloneAhcWSRequest): StandaloneWSRequest = {
try {
val encodedAhcUri: Uri = toUri(req)
val javaUri = encodedAhcUri.toJavaNetURI
setUri(req, encodedAhcUri.withNewQuery(null).toUrl, Option(javaUri.getRawQuery))
} catch {
case NonFatal(t) =>
throw new IllegalArgumentException(s"Invalid URL ${req.url}", t)
}
}

/**
* Builds an AHC [[Uri]] with all parts URL encoded by [[UriEncoder.FIXING]].
* Combines query params from both [[StandaloneWSRequest.url]] and [[StandaloneWSRequest.queryString]].
*/
private def toUri(req: StandaloneWSRequest): Uri = {
val combinedUri: Uri = {
val uri = Uri.create(req.url)

val paramsMap = req.queryString
if (paramsMap.nonEmpty) {
val query: String = combineQuery(uri.getQuery, paramsMap)
uri.withNewQuery(query)
} else {
uri
}
}

// FIXING.encode() encodes ONLY unencoded parts, leaving encoded parts untouched.
UriEncoder.FIXING.encode(combinedUri, Collections.emptyList())
}

private def combineQuery(query: String, params: Map[String, Seq[String]]): String = {
val sb = new StringBuilder
// Reminder: ahc.Uri.query does include '?' (unlike java.net.URI)
if (query != null) {
sb.append(query)
}

for {
(key, values) <- params
value <- values
} {
if (sb.nonEmpty) {
sb.append('&')
}
sb.append(key)
if (value.nonEmpty) {
sb.append('=').append(value)
}
}

sb.toString
}

/**
* Replace the [[StandaloneWSRequest.url]] and [[StandaloneWSRequest.queryString]]
* with the values of [[uri]], discarding originals.
*/
private def setUri(
req: StandaloneAhcWSRequest,
urlNoQueryParams: String,
encodedQueryString: Option[String]): StandaloneWSRequest = {
val queryParams: List[(String, String)] = for {
queryString <- encodedQueryString.toList
// https://stackoverflow.com/a/13592567 for all of this.
pair <- queryString.split('&')
idx = pair.indexOf('=')
key = URLDecoder.decode(if (idx > 0) pair.substring(0, idx) else pair, "UTF-8")
value = if (idx > 0) URLDecoder.decode(pair.substring(idx + 1), "UTF-8") else ""
} yield key -> value

req
// Intentionally using copy(url) instead of withUrl(url) to avoid
// withUrl() -> normalize() -> withUrl() -> normalize()
// just in case we missed a case.
.copy(url = urlNoQueryParams)(req.materializer)
.withQueryStringParameters(queryParams: _*)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import java.nio.charset.{ Charset, StandardCharsets }

import akka.stream.Materializer
import akka.stream.scaladsl.Sink
import play.api.libs.ws.{ StandaloneWSRequest, _ }
import play.api.libs.ws._
import play.shaded.ahc.io.netty.buffer.Unpooled
import play.shaded.ahc.io.netty.handler.codec.http.HttpHeaders
import play.shaded.ahc.org.asynchttpclient.Realm.AuthScheme
Expand Down Expand Up @@ -42,7 +42,7 @@ case class StandaloneAhcWSRequest(
proxyServer: Option[WSProxyServer] = None,
disableUrlEncoding: Option[Boolean] = None,
private val filters: Seq[WSRequestFilter] = Nil
)(implicit materializer: Materializer) extends StandaloneWSRequest with AhcUtilities with WSCookieConverter {
)(implicit private[ahc] val materializer: Materializer) extends StandaloneWSRequest with AhcUtilities with WSCookieConverter {
override type Self = StandaloneWSRequest
override type Response = StandaloneWSResponse

Expand Down Expand Up @@ -207,7 +207,10 @@ case class StandaloneAhcWSRequest(
withMethod(method).execute()
}

override def withUrl(url: String): Self = copy(url = url)
override def withUrl(url: String): Self = {
val unsafe = copy(url = url)
StandaloneAhcWSClient.normalize(unsafe)
}

override def withMethod(method: String): Self = copy(method = method)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

package play.api.libs.ws.ahc

import java.net.URLEncoder

import akka.actor.ActorSystem
import akka.stream.ActorMaterializer
import akka.util.ByteString
Expand All @@ -16,6 +18,7 @@ import play.api.libs.ws._
import play.shaded.ahc.io.netty.handler.codec.http.HttpHeaderNames
import play.shaded.ahc.org.asynchttpclient.Realm.AuthScheme
import play.shaded.ahc.org.asynchttpclient.{ Param, SignatureCalculator, Request => AHCRequest }

import scala.collection.JavaConverters._
import scala.concurrent.duration._

Expand Down Expand Up @@ -103,9 +106,84 @@ class AhcWSRequestSpec extends Specification with Mockito with AfterAll with Def
paramsList.exists(p => (p.getName == "foo") && (p.getValue == "foo2")) must beTrue
paramsList.count(p => p.getName == "foo") must beEqualTo(2)
}
}

"with unencoded values" in {

/**
* All tests in this block produce the same request.
*/
def verify(input: StandaloneWSClient => StandaloneWSRequest) = {
withClient { client =>
val request = input(client)

val uri = request.uri
uri.getPath === "/|"
uri.getQuery.split('&').toSeq must contain(exactly("!=", "#=$", "^=*", "^=("))

request.url === "http://www.example.com/%7C"
request.queryString must contain("!" -> Seq(""))
request.queryString must contain("#" -> Seq("$"))
request.queryString.get("^") must beSome.which(_ must contain(exactly("*", "(")))
}
}

"path=plain qp=plain" in verify { client =>
client
.url("http://www.example.com/|?!")
.addQueryStringParameters("#" -> "$")
.addQueryStringParameters("^" -> "*", "^" -> "(")
}

"path=enc qp=plain" in verify { client =>
client
.url("http://www.example.com/|?%21")
.addQueryStringParameters("#" -> "$")
.addQueryStringParameters("^" -> "*", "^" -> "(")
}

"path=enc qp=enc" in verify { client =>
client
.url("http://www.example.com/%7C?%21")
.addQueryStringParameters("#" -> "$")
.addQueryStringParameters("^" -> "*", "^" -> "(")
}

"withUrl" in verify { client =>
client
.url("http://www.example.com/%7C")
.addQueryStringParameters("#" -> "$")
.addQueryStringParameters("^" -> "*", "^" -> "(")
.withUrl("http://www.example.com/|?!")
}
}

"with encoded query params" in {
def testEncoded(unencoded: String) = withClient { client =>
val encoded = URLEncoder.encode(unencoded, "UTF-8")
val request = client.url(s"http://www.example.com/?$encoded=$encoded")
request.url === "http://www.example.com/"
request.queryString === Map(unencoded -> Seq(unencoded))
}

"=" in testEncoded("=")
"?" in testEncoded("?")
"/" in testEncoded("/")
"+" in testEncoded("+")
" " in testEncoded(" ")
}

"with urls that normalize can't fix" in {
"kitty" in withClient { client =>
client.url(">^..^<") must throwA[IllegalArgumentException]
}

"withUrl kitty" in withClient { client =>
val validRequest = client.url("https://www.example.com")
validRequest.withUrl(">^..^<") must throwA[IllegalArgumentException]
}

}
}

"For Cookies" in {
Expand Down Expand Up @@ -571,5 +649,4 @@ class AhcWSRequestSpec extends Specification with Mockito with AfterAll with Def
.buildRequest()
req.getHeaders.getAll(HttpHeaderNames.CONTENT_TYPE.toString()).asScala must_== Seq("text/plain; charset=US-ASCII")
}

}
Loading

0 comments on commit dbcc1c1

Please sign in to comment.