Skip to content

Commit

Permalink
refact the logic
Browse files Browse the repository at this point in the history
  • Loading branch information
rickwu666666 committed Dec 18, 2023
1 parent c8ddebf commit a44036c
Showing 1 changed file with 196 additions and 75 deletions.
271 changes: 196 additions & 75 deletions checkbox-provider-ce-oem/bin/tcp_multi_connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,89 @@
)


class PortOutputer():
def __init__(self, port=int,
message=str,
list_status={}):
self.port = port
self.message = message
self.list_status = list_status
self._status = None
self._total_period = None
self._avg_time_period = None
self._max_time_period = None
self._min_time_period = None

@property
def status(self):
if not self.list_status:
return "ERROR"
else:
for check in self.list_status.values():
if check['status'] is False:
self.message = "Received payload incorrect!"
return "FAIL"
self.message = "Received payload correct!"
return "PASS"

@property
def total_period(self):
if not self.list_status:
return self._total_period
else:
total_value = timedelta()
for value in self.list_status.values():
total_value += value.get('time')
return total_value

@property
def avg_time_period(self):
if not self.list_status:
return self._total_period
else:
sum_value = timedelta()
for value in self.list_status.values():
sum_value += value.get('time')
return (sum_value / len(self.list_status))

@property
def max_time_period(self):
if not self.list_status:
return self._total_period
else:
max_value = timedelta()
for value in self.list_status.values():
current_value = value.get('time')
max_value = max(max_value, current_value)
return max_value

@property
def min_time_period(self):
if not self.list_status:
return self._total_period
else:
min_value = None
for value in self.list_status.values():
current_value = value.get('time')
if min_value is None:
min_value = current_value
else:
min_value = min(min_value, current_value)
return min_value

def generate_result(self):
return {
'port': self.port,
'status': self.status,
'message': self.message,
'list_status': self.list_status,
'total_period': self.total_period,
'avg_period': self.avg_time_period,
'max_period': self.max_time_period,
'min_period': self.min_time_period,
}


def server(start_port, end_port):
"""
Start the server to listen on a range of ports.
Expand All @@ -34,33 +117,35 @@ def handle_port(port):
Args:
- port (int): Port to handle connections.
"""
# with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as server_socket:
server = ("0.0.0.0", port)
with socket.create_server(server) as server_socket:
# Set send buffer size to 4096
server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, 4096)
server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
server_socket.listen()

logging.info("Server listening on port {}".format(port))

while True:
conn, addr = server_socket.accept()
try:
with conn:
logging.info("Connected by {}.".format(addr))
while True:
data = conn.recv(4096)
if data:
conn.sendall(data)
else:
break
finally:
conn.close()
try:
with socket.create_server(server) as server_socket:
# Set send buffer size to 4096
server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, 4096)
server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
server_socket.listen()

logging.info("Server listening on port {}".format(port))

while True:
try:
conn, addr = server_socket.accept()
with conn:
logging.info("Connected by {}.".format(addr))
while True:
data = conn.recv(4096)
if data:
conn.sendall(data)
else:
break
except Exception as e:
logging.error("Error handling connection: {}".format(e))
except Exception as e:
logging.error("{}: An unexpected error occurred for port {}"
.format(e, port))


def client(host, start_port, end_port, payload, start_time):
time = datetime.now()
"""
Start the client to connect to a range of server ports.
Expand All @@ -69,36 +154,50 @@ def client(host, start_port, end_port, payload, start_time):
- start_port (int): Starting port for the client.
- end_port (int): Ending port for the client.
- payload (str): Payload to send to the server.
- done_event (threading.Event): Event to signal when the client is done.
- done_event (threading.Event): Event to single when the client is done.
- start_time (datetime): Time until which the client should run.
"""
global global_results
time = datetime.now()
threads = []
for port in range(start_port, end_port + 1):
thread = threading.Thread(target=send_payload,
args=(host, port, payload, start_time))
args=(host,
port,
payload,
start_time))
threads.append(thread)
thread.start()

# Wait for all client threads to finish
for thread in threads:
thread.join()

fail_port = [x for x in results if "FAIL" in x]
error_port = [x for x in results if "ERROR" in x]
if not (fail_port or error_port):
logging.info("TCP connections test pass!")
final = 0
for x in global_results:
if ("FAIL") in x['status']:
final = 1
logging.error("Fail on port {}.\n"
"{}\n"
"Detail:\n{}"
.format(x['port'],
x['message'],
"\n".join("{}: period: {} status: {}"
.format(key,
value['time'],
value['status']) for key,
value in x['list_status'].items()))
)
elif ("ERROR") in x['status']:
final = 1
logging.error("Not able to connect on port {}."
"{}"
.format(x['port'],
x['message']))
if final:
raise RuntimeError("TCP payload test fail!")
else:
if fail_port:
for x in fail_port:
logging.error("Fail on port {}.".format(x.split(":")[0]))
raise RuntimeError("TCP payload test fail!")
if error_port:
for x in error_port:
logging.error("Not able to connect on port {}."
.format(x.split(":")[0]))
raise RuntimeError("TCP connection fail!")
logging.info("Run TCP multi-connections test in {}".
format(datetime.now() - time))
logging.info("Run TCP multi-connections test in {}"
.format(datetime.now() - time))


def send_payload(host, port, payload, start_time):
Expand All @@ -111,39 +210,61 @@ def send_payload(host, port, payload, start_time):
- payload (str): Payload to send to the server.
- start_time (datetime): Time until which the client should run.
"""
try:
server_host = (host, port)
with socket.create_connection(server_host) as client_socket:
# Set send buffer size to 4096
client_socket.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, 4096)
logging.info("Connect to port {}".format(port))
while datetime.now() < start_time:
time.sleep(1)
logging.info("Sending payload to port {}.".format(port))
status = 0
# Sending payload for 30 sec after start sending.
while datetime.now() < start_time + timedelta(seconds=30):
received_data = 0
client_socket.sendall(payload.encode())
while received_data < len(payload):
data = client_socket.recv(4096)
received_data += len(data)
if received_data != len(payload):
status = 1
logging.info("Received payload from {}.".
format(server_host))
if status:
results.append("{}:FAIL".format(port))
else:
results.append("{}:PASS".format(port))
client_socket.close()
except socket.error as e:
logging.error("{} on port {}".format(e, port))
results.append("{}:ERROR".format(port))
except Exception as e:
logging.error("{}: An unexpected error occurred for port {}"
.format(e, port))
results.append("{}:ERROR".format(port))
global global_results
port_result = PortOutputer(port=port)
# Retry connect to server port for 5 times.
for _ in range(5):
try:
server_host = (host, port)
with socket.create_connection(server_host) as client_socket:
# Set send buffer size to 4096
client_socket.setsockopt(socket.SOL_SOCKET,
socket.SO_SNDBUF, 4096)
logging.info("Connect to port {}".format(port))
# Sleep until start time
start_time = start_time - datetime.now()
time.sleep(start_time.total_seconds())
logging.info("Sending payload to port {}.".format(port))
# Sending payload for 10 times
status_all = {}
for x in range(10):
single_start = datetime.now()
client_socket.sendall(payload.encode())
received_data = ""
while len(received_data) < len(payload):
# set socket time out for 30 seconds,
# in case recv hang.
client_socket.settimeout(30)
try:
data = client_socket.recv(4096)
if not data:
break
received_data += data.decode()
except TimeoutError:
break
single_end = datetime.now() - single_start
if received_data != payload:
status_all[x] = {'time': single_end,
'status': False}
else:
status_all[x] = {'time': single_end,
'status': True}
logging.info("Received payload from {}.".
format(server_host))
port_result.port = port
port_result.list_status = status_all
client_socket.close()
break
except socket.error as e:
logging.error("{} on {}".format(e, port))
port_result.message = e
port_result.port = port
except Exception as e:
logging.error("{} on {}".format(e, port))
port_result.message = e
port_result.port = port
time.sleep(5)
global_results.append(port_result.generate_result())


if __name__ == "__main__":
Expand Down Expand Up @@ -204,7 +325,7 @@ def send_payload(host, port, payload, start_time):
help="Ending port for the client")
args = parser.parse_args()

results = []
global_results = []
# Ramp up time to wait until all ports are connected before
# starting to send the payload.
start_time = datetime.now() + timedelta(seconds=20)
Expand Down

0 comments on commit a44036c

Please sign in to comment.