diff --git a/src/waitress/task.py b/src/waitress/task.py index d88ad019..f24fbe00 100644 --- a/src/waitress/task.py +++ b/src/waitress/task.py @@ -180,6 +180,18 @@ def has_body(self): or self.status.startswith("304") ) + def set_close_on_finish(self) -> None: + # if headers have not been written yet, tell the remote + # client we are closing the connection + if not self.wrote_header: + connection_close_header = None + for headername, headerval in self.response_headers: + if headername.capitalize() == "Connection": + connection_close_header = headerval.lower() + if connection_close_header is None: + self.response_headers.append(("Connection", "close")) + self.close_on_finish = True + def build_response_header(self): version = self.version # Figure out whether the connection should be closed. @@ -188,7 +200,6 @@ def build_response_header(self): content_length_header = None date_header = None server_header = None - connection_close_header = None for headername, headerval in self.response_headers: headername = "-".join([x.capitalize() for x in headername.split("-")]) @@ -205,47 +216,43 @@ def build_response_header(self): if headername == "Server": server_header = headerval - if headername == "Connection": - connection_close_header = headerval.lower() # replace with properly capitalized version response_headers.append((headername, headerval)) + # Overwrite the response headers we have with normalized ones + self.response_headers = response_headers + if ( content_length_header is None and self.content_length is not None and self.has_body ): content_length_header = str(self.content_length) - response_headers.append(("Content-Length", content_length_header)) - - def close_on_finish(): - if connection_close_header is None: - response_headers.append(("Connection", "close")) - self.close_on_finish = True + self.response_headers.append(("Content-Length", content_length_header)) if version == "1.0": if connection == "keep-alive": if not content_length_header: - close_on_finish() + self.set_close_on_finish() else: - response_headers.append(("Connection", "Keep-Alive")) + self.response_headers.append(("Connection", "Keep-Alive")) else: - close_on_finish() + self.set_close_on_finish() elif version == "1.1": if connection == "close": - close_on_finish() + self.set_close_on_finish() if not content_length_header: # RFC 7230: MUST NOT send Transfer-Encoding or Content-Length # for any response with a status code of 1xx, 204 or 304. if self.has_body: - response_headers.append(("Transfer-Encoding", "chunked")) + self.response_headers.append(("Transfer-Encoding", "chunked")) self.chunked_response = True if not self.close_on_finish: - close_on_finish() + self.set_close_on_finish() # under HTTP 1.1 keep-alive is default, no need to set the header else: @@ -257,14 +264,12 @@ def close_on_finish(): if not server_header: if ident: - response_headers.append(("Server", ident)) + self.response_headers.append(("Server", ident)) else: - response_headers.append(("Via", ident or "waitress")) + self.response_headers.append(("Via", ident or "waitress")) if not date_header: - response_headers.append(("Date", build_http_date(self.start_time))) - - self.response_headers = response_headers + self.response_headers.append(("Date", build_http_date(self.start_time))) first_line = f"HTTP/{self.version} {self.status}" # NB: sorting headers needs to preserve same-named-header order @@ -350,11 +355,7 @@ def execute(self): status, headers, body = e.to_response(ident) self.status = status self.response_headers.extend(headers) - # We need to explicitly tell the remote client we are closing the - # connection, because self.close_on_finish is set, and we are going to - # slam the door in the clients face. - self.response_headers.append(("Connection", "close")) - self.close_on_finish = True + self.set_close_on_finish() self.content_length = len(body) self.write(body) @@ -388,7 +389,7 @@ def start_response(status, headers, exc_info=None): self.complete = True - if not status.__class__ is str: + if status.__class__ is not str: raise AssertionError("status %s is not a string" % status) if "\n" in status or "\r" in status: raise ValueError( @@ -399,11 +400,11 @@ def start_response(status, headers, exc_info=None): # Prepare the headers for output for k, v in headers: - if not k.__class__ is str: + if k.__class__ is not str: raise AssertionError( f"Header name {k!r} is not a string in {(k, v)!r}" ) - if not v.__class__ is str: + if v.__class__ is not str: raise AssertionError( f"Header value {v!r} is not a string in {(k, v)!r}" ) @@ -478,14 +479,14 @@ def start_response(status, headers, exc_info=None): # close the connection so the client isn't sitting around # waiting for more data when there are too few bytes # to service content-length - # unless it's a HEAD request in which case we don't expect - # to return any bytes regardless of the content length - self.close_on_finish = True - self.logger.warning( - "application returned too few bytes (%s) " - "for specified Content-Length (%s) via app_iter" - % (self.content_bytes_written, cl), - ) + self.set_close_on_finish() + if self.request.command != "HEAD": + self.logger.warning( + "application returned too few bytes (%s) " + "for specified Content-Length (%s) via app_iter", + self.content_bytes_written, + cl, + ) finally: if can_close_app_iter and hasattr(app_iter, "close"): app_iter.close()