From 587c5db846c780b29dcfb9e19e670b18d1586db2 Mon Sep 17 00:00:00 2001 From: Doug Roper Date: Sun, 28 Oct 2018 22:14:11 -0500 Subject: [PATCH] WSRequest: Normalize URL --- .../libs/ws/ahc/StandaloneAhcWSClient.scala | 126 +++++++++++++++--- .../libs/ws/ahc/StandaloneAhcWSRequest.scala | 7 +- .../api/libs/ws/ahc/AhcWSRequestSpec.scala | 44 +++++- .../play/api/libs/ws/StandaloneWSClient.scala | 6 +- 4 files changed, 162 insertions(+), 21 deletions(-) diff --git a/play-ahc-ws-standalone/src/main/scala/play/api/libs/ws/ahc/StandaloneAhcWSClient.scala b/play-ahc-ws-standalone/src/main/scala/play/api/libs/ws/ahc/StandaloneAhcWSClient.scala index 995ef162..f6d19af7 100644 --- a/play-ahc-ws-standalone/src/main/scala/play/api/libs/ws/ahc/StandaloneAhcWSClient.scala +++ b/play-ahc-ws-standalone/src/main/scala/play/api/libs/ws/ahc/StandaloneAhcWSClient.scala @@ -3,21 +3,25 @@ */ package play.api.libs.ws.ahc -import javax.inject.Inject +import java.net.URLDecoder +import java.util.Collections import akka.stream.Materializer import akka.stream.scaladsl.Source import akka.util.ByteString import com.typesafe.sslconfig.ssl.SystemConfiguration import com.typesafe.sslconfig.ssl.debug.DebugConfiguration +import javax.inject.Inject import play.api.libs.ws.ahc.cache._ import play.api.libs.ws.{ EmptyBody, StandaloneWSClient, StandaloneWSRequest } import play.shaded.ahc.org.asynchttpclient.uri.Uri +import play.shaded.ahc.org.asynchttpclient.util.{ StringBuilderPool, UriEncoder } import play.shaded.ahc.org.asynchttpclient.{ Response => AHCResponse, _ } import scala.collection.immutable.TreeMap import scala.compat.java8.FunctionConverters import scala.concurrent.{ Await, Future, Promise } +import scala.util.control.NonFatal /** * A WS client backed by an AsyncHttpClient. @@ -39,8 +43,7 @@ class StandaloneAhcWSClient @Inject() (asyncHttpClient: AsyncHttpClient)(implici } def url(url: String): StandaloneWSRequest = { - validate(url) - StandaloneAhcWSRequest( + val req = StandaloneAhcWSRequest( client = this, url = url, method = "GET", @@ -56,6 +59,8 @@ class StandaloneAhcWSClient @Inject() (asyncHttpClient: AsyncHttpClient)(implici proxyServer = None, disableUrlEncoding = None ) + + StandaloneAhcWSClient.normalize(req) } private[ahc] def execute(request: Request): Future[StandaloneAhcWSResponse] = { @@ -75,18 +80,6 @@ class StandaloneAhcWSClient @Inject() (asyncHttpClient: AsyncHttpClient)(implici result.future } - private def validate(url: String): Unit = { - // Recover from https://github.com/AsyncHttpClient/async-http-client/issues/1149 - try { - Uri.create(url) - } catch { - case iae: IllegalArgumentException => - throw new IllegalArgumentException(s"Invalid URL $url", iae) - case npe: NullPointerException => - throw new IllegalArgumentException(s"Invalid URL $url", npe) - } - } - private[ahc] def executeStream(request: Request): Future[StreamedResponse] = { val promise = Promise[StreamedResponse]() @@ -116,12 +109,12 @@ class StandaloneAhcWSClient @Inject() (asyncHttpClient: AsyncHttpClient)(implici Await.result(result, StandaloneAhcWSClient.blockingTimeout) } - } object StandaloneAhcWSClient { import scala.concurrent.duration._ + val blockingTimeout = 50.milliseconds val elementLimit = 13 // 13 8192k blocks is roughly 100k private val logger = org.slf4j.LoggerFactory.getLogger(this.getClass) @@ -163,5 +156,106 @@ object StandaloneAhcWSClient { new SystemConfiguration(loggerFactory).configure(config.wsClientConfig.ssl) wsClient } + + /** + * Ensures: + * 1. [[StandaloneWSRequest.url]] path is encoded. + * 2. Any query params present in the URL are moved to [[StandaloneWSRequest.queryString]]. + */ + @throws[IllegalArgumentException]("if the url is unrepairable") + private[ahc] def normalize(req: StandaloneWSRequest): StandaloneWSRequest = { + try { + // Recover from https://github.com/AsyncHttpClient/async-http-client/issues/1149 + Uri.create(req.url) + if (req.uri.getQuery == null) { + // happy path + req + } else { + // valid, but move query params into the Map + repair(req) + } + } catch { + case NonFatal(_) => + // URI parsing error + repair(req) + } + } + + @throws[IllegalArgumentException]("if the url is unrepairable") + private def repair(req: StandaloneWSRequest): StandaloneWSRequest = { + try { + val encodedAhcUri: Uri = toUri(req) + val javaUri = encodedAhcUri.toJavaNetURI + setUri(req, encodedAhcUri.withNewQuery(null).toUrl, Option(javaUri.getQuery)) + } catch { + case NonFatal(t) => + throw new IllegalArgumentException(s"Invalid URL ${req.url}", t) + } + } + + /** + * Builds an AHC [[Uri]] with all parts URL encoded. + * Combines both [[StandaloneWSRequest.url]] and [[StandaloneWSRequest.queryString]]. + */ + private def toUri(req: StandaloneWSRequest): Uri = { + val combinedUri: Uri = { + val uri = Uri.create(req.url) + + val params = req.queryString + if (params.nonEmpty) { + appendParamsToUri(uri, params) + } else { + uri + } + } + + // FIXING.encode() encodes ONLY unencoded parts, leaving encoded parts untouched. + UriEncoder.FIXING.encode(combinedUri, Collections.emptyList()) + } + + /** + * Replace the [[StandaloneWSRequest.url]] and [[StandaloneWSRequest.queryString]] + * with the values of [[uri]], discarding originals. + */ + private def setUri( + req: StandaloneWSRequest, + urlNoQueryParams: String, + encodedQueryString: Option[String]): StandaloneWSRequest = { + val queryParams: List[(String, String)] = for { + queryString <- encodedQueryString.toList + // https://stackoverflow.com/a/13592567 for all of this. + pair <- queryString.split('&') + idx = pair.indexOf('=') + key = if (idx > 0) pair.substring(0, idx) else pair + value = if (idx > 0) URLDecoder.decode(pair.substring(idx + 1)) else "" + } yield key -> value + + req + .withUrl(urlNoQueryParams) + .withQueryStringParameters(queryParams: _*) + } + + private def appendParamsToUri(uri: Uri, params: Map[String, Seq[String]]): Uri = { + val sb = StringBuilderPool.DEFAULT.stringBuilder + // Reminder: ahc.Uri does not start with '?' (unlike java.net.URI) + if (uri.getQuery != null) { + sb.append(uri.getQuery) + } + + for { + (key, values) <- params + value <- values + } { + if (sb.length > 0) { + sb.append('&') + } + sb.append(key) + if (value.nonEmpty) { + sb.append('=').append(value) + } + } + + uri.withNewQuery(sb.toString) + } } diff --git a/play-ahc-ws-standalone/src/main/scala/play/api/libs/ws/ahc/StandaloneAhcWSRequest.scala b/play-ahc-ws-standalone/src/main/scala/play/api/libs/ws/ahc/StandaloneAhcWSRequest.scala index e0985c8f..42514f88 100644 --- a/play-ahc-ws-standalone/src/main/scala/play/api/libs/ws/ahc/StandaloneAhcWSRequest.scala +++ b/play-ahc-ws-standalone/src/main/scala/play/api/libs/ws/ahc/StandaloneAhcWSRequest.scala @@ -9,7 +9,7 @@ import java.nio.charset.{ Charset, StandardCharsets } import akka.stream.Materializer import akka.stream.scaladsl.Sink -import play.api.libs.ws.{ StandaloneWSRequest, _ } +import play.api.libs.ws._ import play.shaded.ahc.io.netty.buffer.Unpooled import play.shaded.ahc.io.netty.handler.codec.http.HttpHeaders import play.shaded.ahc.org.asynchttpclient.Realm.AuthScheme @@ -184,7 +184,10 @@ case class StandaloneAhcWSRequest( withMethod(method).execute() } - override def withUrl(url: String): Self = copy(url = url) + override def withUrl(url: String): Self = { + val unsafe = copy(url = url) + StandaloneAhcWSClient.normalize(unsafe) + } override def withMethod(method: String): Self = copy(method = method) diff --git a/play-ahc-ws-standalone/src/test/scala/play/api/libs/ws/ahc/AhcWSRequestSpec.scala b/play-ahc-ws-standalone/src/test/scala/play/api/libs/ws/ahc/AhcWSRequestSpec.scala index 98c69d31..2485f0e7 100644 --- a/play-ahc-ws-standalone/src/test/scala/play/api/libs/ws/ahc/AhcWSRequestSpec.scala +++ b/play-ahc-ws-standalone/src/test/scala/play/api/libs/ws/ahc/AhcWSRequestSpec.scala @@ -14,8 +14,7 @@ import play.api.libs.oauth.{ ConsumerKey, OAuthCalculator, RequestToken } import play.api.libs.ws._ import play.shaded.ahc.io.netty.handler.codec.http.HttpHeaders import play.shaded.ahc.org.asynchttpclient.Realm.AuthScheme -import play.shaded.ahc.io.netty.handler.codec.http.cookie.{ Cookie => AHCCookie } -import play.shaded.ahc.org.asynchttpclient.{ Param, RequestBuilderBase, SignatureCalculator, Request => AHCRequest } +import play.shaded.ahc.org.asynchttpclient.{ Param, SignatureCalculator, Request => AHCRequest } import scala.collection.JavaConverters._ import scala.concurrent.duration._ @@ -108,6 +107,47 @@ class AhcWSRequestSpec extends Specification with Mockito with AfterAll with Def } + "with unencoded values" in { + + /** + * All tests in this block produce the same request. + */ + def verify(input: StandaloneWSClient => StandaloneWSRequest) = { + withClient { client => + val request = input(client) + + val uri = request.uri + uri.getPath === "/|" + uri.getQuery.split('&').toSeq must contain(exactly("!=", "#=$", "^=*", "^=(")) + + request.url === "http://www.example.com/%7C" + request.queryString must contain("!" -> Seq("")) + request.queryString must contain("#" -> Seq("$")) + request.queryString.get("^") must beSome.which(_ must contain(exactly("*", "("))) + } + } + + "path=plain qp=plain" in verify { client => + client + .url("http://www.example.com/|?!") + .addQueryStringParameters("#" -> "$") + .addQueryStringParameters("^" -> "*", "^" -> "(") + } + + "path=enc qp=plain" in verify { client => + client + .url("http://www.example.com/|?%21") + .addQueryStringParameters("#" -> "$") + .addQueryStringParameters("^" -> "*", "^" -> "(") + } + + "path=enc qp=enc" in verify { client => + client + .url("http://www.example.com/%7C?%21") + .addQueryStringParameters("#" -> "$") + .addQueryStringParameters("^" -> "*", "^" -> "(") + } + } } "For Cookies" in { diff --git a/play-ws-standalone/src/main/scala/play/api/libs/ws/StandaloneWSClient.scala b/play-ws-standalone/src/main/scala/play/api/libs/ws/StandaloneWSClient.scala index cc420d52..2639ba93 100644 --- a/play-ws-standalone/src/main/scala/play/api/libs/ws/StandaloneWSClient.scala +++ b/play-ws-standalone/src/main/scala/play/api/libs/ws/StandaloneWSClient.scala @@ -21,10 +21,14 @@ trait StandaloneWSClient extends Closeable { /** * Generates a request. Throws IllegalArgumentException if the URL is invalid. * + * Query params may be present in the url either encoded or unencoded, + * which will be available in the [[StandaloneWSRequest.queryString]] + * of the returned [[StandaloneWSRequest]]. + * * @param url The base URL to make HTTP requests to. * @return a request */ - @throws[IllegalArgumentException] + @throws[IllegalArgumentException]("if the URL is invalid") def url(url: String): StandaloneWSRequest /**