Skip to content

Commit

Permalink
Merge branch 'tcp_proxy'
Browse files Browse the repository at this point in the history
  • Loading branch information
cortesi committed Feb 6, 2014
2 parents 404d4bb + 7fc544b commit 3d52d16
Show file tree
Hide file tree
Showing 7 changed files with 144 additions and 113 deletions.
2 changes: 1 addition & 1 deletion netlib/certutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,7 @@ def altnames(self):


def get_remote_cert(host, port, sni):
c = tcp.TCPClient(host, port)
c = tcp.TCPClient((host, port))
c.connect()
c.convert_to_ssl(sni=sni)
return c.cert
4 changes: 4 additions & 0 deletions netlib/odict.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import re, copy


def safe_subn(pattern, repl, target, *args, **kwargs):
"""
There are Unicode conversion problems with re.subn. We try to smooth
Expand Down Expand Up @@ -98,6 +99,9 @@ def items(self):
def _get_state(self):
return [tuple(i) for i in self.lst]

def _load_state(self, state):
self.list = [list(i) for i in state]

@classmethod
def _from_state(klass, state):
return klass([list(i) for i in state])
Expand Down
186 changes: 104 additions & 82 deletions netlib/tcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,16 +173,88 @@ def readline(self, size = None):
return result


class TCPClient:
class Address(object):
"""
This class wraps an IPv4/IPv6 tuple to provide named attributes and ipv6 information.
"""
def __init__(self, address, use_ipv6=False):
self.address = tuple(address)
self.use_ipv6 = use_ipv6

@classmethod
def wrap(cls, t):
if isinstance(t, cls):
return t
else:
return cls(t)

def __call__(self):
return self.address

@property
def host(self):
return self.address[0]

@property
def port(self):
return self.address[1]

@property
def use_ipv6(self):
return self.family == socket.AF_INET6

@use_ipv6.setter
def use_ipv6(self, b):
self.family = socket.AF_INET6 if b else socket.AF_INET

def __eq__(self, other):
other = Address.wrap(other)
return (self.address, self.family) == (other.address, other.family)


class SocketCloseMixin(object):
def finish(self):
self.finished = True
try:
if not getattr(self.wfile, "closed", False):
self.wfile.flush()
self.close()
self.wfile.close()
self.rfile.close()
except (socket.error, NetLibDisconnect):
# Remote has disconnected
pass

def close(self):
"""
Does a hard close of the socket, i.e. a shutdown, followed by a close.
"""
try:
if self.ssl_established:
self.connection.shutdown()
self.connection.sock_shutdown(socket.SHUT_WR)
else:
self.connection.shutdown(socket.SHUT_WR)
#Section 4.2.2.13 of RFC 1122 tells us that a close() with any pending readable data could lead to an immediate RST being sent.
#http://ia600609.us.archive.org/22/items/TheUltimateSo_lingerPageOrWhyIsMyTcpNotReliable/the-ultimate-so_linger-page-or-why-is-my-tcp-not-reliable.html
while self.connection.recv(4096):
pass
self.connection.close()
except (socket.error, SSL.Error, IOError):
# Socket probably already closed
pass


class TCPClient(SocketCloseMixin):
rbufsize = -1
wbufsize = -1
def __init__(self, host, port, source_address=None, use_ipv6=False):
self.host, self.port = host, port
self.source_address = source_address
self.use_ipv6 = use_ipv6
def __init__(self, address, source_address=None):
self.address = Address.wrap(address)
self.source_address = Address.wrap(source_address) if source_address else None
self.connection, self.rfile, self.wfile = None, None, None
self.cert = None
self.ssl_established = False
self.sni = None

def convert_to_ssl(self, cert=None, sni=None, method=TLSv1_METHOD, options=None):
"""
Expand All @@ -200,6 +272,7 @@ def convert_to_ssl(self, cert=None, sni=None, method=TLSv1_METHOD, options=None)
self.connection = SSL.Connection(context, self.connection)
self.ssl_established = True
if sni:
self.sni = sni
self.connection.set_tlsext_host_name(sni)
self.connection.set_connect_state()
try:
Expand All @@ -212,14 +285,14 @@ def convert_to_ssl(self, cert=None, sni=None, method=TLSv1_METHOD, options=None)

def connect(self):
try:
connection = socket.socket(socket.AF_INET6 if self.use_ipv6 else socket.AF_INET, socket.SOCK_STREAM)
connection = socket.socket(self.address.family, socket.SOCK_STREAM)
if self.source_address:
connection.bind(self.source_address)
connection.connect((self.host, self.port))
connection.bind(self.source_address())
connection.connect(self.address())
self.rfile = Reader(connection.makefile('rb', self.rbufsize))
self.wfile = Writer(connection.makefile('wb', self.wbufsize))
except (socket.error, IOError), err:
raise NetLibError('Error connecting to "%s": %s' % (self.host, err))
raise NetLibError('Error connecting to "%s": %s' % (self.address.host, err))
self.connection = connection

def settimeout(self, n):
Expand All @@ -228,43 +301,24 @@ def settimeout(self, n):
def gettimeout(self):
return self.connection.gettimeout()

def close(self):
"""
Does a hard close of the socket, i.e. a shutdown, followed by a close.
"""
try:
if self.ssl_established:
self.connection.shutdown()
self.connection.sock_shutdown(socket.SHUT_WR)
else:
self.connection.shutdown(socket.SHUT_WR)
#Section 4.2.2.13 of RFC 1122 tells us that a close() with any pending readable data could lead to an immediate RST being sent.
#http://ia600609.us.archive.org/22/items/TheUltimateSo_lingerPageOrWhyIsMyTcpNotReliable/the-ultimate-so_linger-page-or-why-is-my-tcp-not-reliable.html
while self.connection.recv(4096):
pass
self.connection.close()
except (socket.error, SSL.Error, IOError):
# Socket probably already closed
pass


class BaseHandler:
class BaseHandler(SocketCloseMixin):
"""
The instantiator is expected to call the handle() and finish() methods.
"""
rbufsize = -1
wbufsize = -1
def __init__(self, connection, client_address, server):

def __init__(self, connection, address, server):
self.connection = connection
self.address = Address.wrap(address)
self.server = server
self.rfile = Reader(self.connection.makefile('rb', self.rbufsize))
self.wfile = Writer(self.connection.makefile('wb', self.wbufsize))

self.client_address = client_address
self.server = server
self.finished = False
self.ssl_established = False

self.clientcert = None

def convert_to_ssl(self, cert, key, method=SSLv23_METHOD, options=None, handle_sni=None, request_client_cert=False, cipher_list=None):
Expand Down Expand Up @@ -318,66 +372,34 @@ def ver(*args):
self.rfile.set_descriptor(self.connection)
self.wfile.set_descriptor(self.connection)

def finish(self):
self.finished = True
try:
if not getattr(self.wfile, "closed", False):
self.wfile.flush()
self.close()
self.wfile.close()
self.rfile.close()
except (socket.error, NetLibDisconnect):
# Remote has disconnected
pass

def handle(self): # pragma: no cover
raise NotImplementedError

def settimeout(self, n):
self.connection.settimeout(n)

def close(self):
"""
Does a hard close of the socket, i.e. a shutdown, followed by a close.
"""
try:
if self.ssl_established:
self.connection.shutdown()
self.connection.sock_shutdown(socket.SHUT_WR)
else:
self.connection.shutdown(socket.SHUT_WR)
# Section 4.2.2.13 of RFC 1122 tells us that a close() with any
# pending readable data could lead to an immediate RST being sent.
# http://ia600609.us.archive.org/22/items/TheUltimateSo_lingerPageOrWhyIsMyTcpNotReliable/the-ultimate-so_linger-page-or-why-is-my-tcp-not-reliable.html
while self.connection.recv(4096):
pass
except (socket.error, SSL.Error):
# Socket probably already closed
pass
self.connection.close()


class TCPServer:
request_queue_size = 20
def __init__(self, server_address, use_ipv6=False):
self.server_address = server_address
self.use_ipv6 = use_ipv6
def __init__(self, address):
self.address = Address.wrap(address)
self.__is_shut_down = threading.Event()
self.__shutdown_request = False
self.socket = socket.socket(socket.AF_INET6 if self.use_ipv6 else socket.AF_INET, socket.SOCK_STREAM)
self.socket = socket.socket(self.address.family, socket.SOCK_STREAM)
self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
self.socket.bind(self.server_address)
self.server_address = self.socket.getsockname()
self.port = self.server_address[1]
self.socket.bind(self.address())
self.address = Address.wrap(self.socket.getsockname())
self.socket.listen(self.request_queue_size)

def request_thread(self, request, client_address):
def connection_thread(self, connection, client_address):
client_address = Address(client_address)
try:
self.handle_connection(request, client_address)
request.close()
self.handle_client_connection(connection, client_address)
except:
self.handle_error(request, client_address)
request.close()
self.handle_error(connection, client_address)
finally:
connection.close()

def serve_forever(self, poll_interval=0.1):
self.__is_shut_down.clear()
Expand All @@ -391,10 +413,10 @@ def serve_forever(self, poll_interval=0.1):
else:
raise
if self.socket in r:
request, client_address = self.socket.accept()
connection, client_address = self.socket.accept()
t = threading.Thread(
target = self.request_thread,
args = (request, client_address)
target = self.connection_thread,
args = (connection, client_address)
)
t.setDaemon(1)
t.start()
Expand All @@ -410,18 +432,18 @@ def shutdown(self):

def handle_error(self, request, client_address, fp=sys.stderr):
"""
Called when handle_connection raises an exception.
Called when handle_client_connection raises an exception.
"""
# If a thread has persisted after interpreter exit, the module might be
# none.
if traceback:
exc = traceback.format_exc()
print >> fp, '-'*40
print >> fp, "Error in processing of request from %s:%s"%client_address
print >> fp, "Error in processing of request from %s:%s" % (client_address.host, client_address.port)
print >> fp, exc
print >> fp, '-'*40

def handle_connection(self, request, client_address): # pragma: no cover
def handle_client_connection(self, conn, client_address): # pragma: no cover
"""
Called after client connection.
"""
Expand Down
11 changes: 5 additions & 6 deletions netlib/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,19 +17,18 @@ class ServerTestBase:
ssl = None
handler = None
addr = ("localhost", 0)
use_ipv6 = False

@classmethod
def setupAll(cls):
cls.q = Queue.Queue()
s = cls.makeserver()
cls.port = s.port
cls.port = s.address.port
cls.server = ServerThread(s)
cls.server.start()

@classmethod
def makeserver(cls):
return TServer(cls.ssl, cls.q, cls.handler, cls.addr, cls.use_ipv6)
return TServer(cls.ssl, cls.q, cls.handler, cls.addr)

@classmethod
def teardownAll(cls):
Expand All @@ -41,16 +40,16 @@ def last_handler(self):


class TServer(tcp.TCPServer):
def __init__(self, ssl, q, handler_klass, addr, use_ipv6):
def __init__(self, ssl, q, handler_klass, addr):
"""
ssl: A {cert, key, v3_only} dict.
"""
tcp.TCPServer.__init__(self, addr, use_ipv6=use_ipv6)
tcp.TCPServer.__init__(self, addr)
self.ssl, self.q = ssl, q
self.handler_klass = handler_klass
self.last_handler = None

def handle_connection(self, request, client_address):
def handle_client_connection(self, request, client_address):
h = self.handler_klass(request, client_address, self)
self.last_handler = h
if self.ssl:
Expand Down
15 changes: 10 additions & 5 deletions netlib/wsgi.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,22 @@
import cStringIO, urllib, time, traceback
import odict
import odict, tcp


class ClientConn:
def __init__(self, address):
self.address = address
self.address = tcp.Address.wrap(address)


class Flow:
def __init__(self, client_conn):
self.client_conn = client_conn


class Request:
def __init__(self, client_conn, scheme, method, path, headers, content):
self.scheme, self.method, self.path = scheme, method, path
self.headers, self.content = headers, content
self.client_conn = client_conn
self.flow = Flow(client_conn)


def date_time_string():
Expand Down Expand Up @@ -60,8 +65,8 @@ def make_environ(self, request, errsoc, **extra):
'SERVER_PROTOCOL': "HTTP/1.1",
}
environ.update(extra)
if request.client_conn.address:
environ["REMOTE_ADDR"], environ["REMOTE_PORT"] = request.client_conn.address
if request.flow.client_conn.address:
environ["REMOTE_ADDR"], environ["REMOTE_PORT"] = request.flow.client_conn.address()

for key, value in request.headers.items():
key = 'HTTP_' + key.upper().replace('-', '_')
Expand Down
2 changes: 1 addition & 1 deletion test/test_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,7 @@ class TestReadResponseNoContentLength(test.ServerTestBase):
handler = NoContentLengthHTTPHandler

def test_no_content_length(self):
c = tcp.TCPClient("127.0.0.1", self.port)
c = tcp.TCPClient(("127.0.0.1", self.port))
c.connect()
httpversion, code, msg, headers, content = http.read_response(c.rfile, "GET", None)
assert content == "bar\r\n\r\n"
Expand Down
Loading

0 comments on commit 3d52d16

Please sign in to comment.