Skip to content

Commit

Permalink
Added RSA SSH key type as default.
Browse files Browse the repository at this point in the history
Added docstring to the function explaining the behavior.
Formatted the file with auto-formatter.
  • Loading branch information
narmaku committed Jan 23, 2025
1 parent 0da7c64 commit 6b18741
Showing 1 changed file with 53 additions and 27 deletions.
80 changes: 53 additions & 27 deletions lib/ssh_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,27 @@
from threading import Thread


def generate_ssh_key_pair(ssh_key_path):
def generate_ssh_key_pair(ssh_key_path, key_type="rsa"):
"""
Generates a new SSH key pair at the specified path.
If a key already exists in the specified path, it will be removed first.
:param ssh_key_path: The path where the private and public keys will be saved.
:type ssh_key_path: str
:param key_type: The type of key to generate. Defaults to 'rsa'.
:type key_type: str, optional
Supported types:
- rsa: 2048-bit RSA key (default).
- rsa1: 512-bit RSA key.
- dsa: 1024-bit DSA key.
- ecdsa: Elliptic curve DSA key (secp256r1, by default).
- ed25519: Ed25519 key.
:return: None
"""
if os.path.exists(ssh_key_path):
os.remove(ssh_key_path)

os.system(f'ssh-keygen -f "{ssh_key_path}" -N "" -q')
os.system(f'ssh-keygen -t {key_type} -f "{ssh_key_path}" -N "" -q')


def generate_instances_ssh_config(ssh_key_path, ssh_config_file, instances):
Expand All @@ -19,16 +35,18 @@ def generate_instances_ssh_config(ssh_key_path, ssh_config_file, instances):
conf = sshconf.empty_ssh_config_file()

for inst in instances.values():
conf.add(inst['address'],
Hostname=inst['address'],
User=inst['username'],
Port=22,
IdentityFile=ssh_key_path,
StrictHostKeyChecking='no',
UserKnownHostsFile='/dev/null',
LogLevel='ERROR',
ConnectTimeout=30,
ConnectionAttempts=5)
conf.add(
inst["address"],
Hostname=inst["address"],
User=inst["username"],
Port=22,
IdentityFile=ssh_key_path,
StrictHostKeyChecking="no",
UserKnownHostsFile="/dev/null",
LogLevel="ERROR",
ConnectTimeout=30,
ConnectionAttempts=5,
)

conf.write(ssh_config_file)

Expand All @@ -44,14 +62,16 @@ def wait_for_host_ssh_up(host_address, timeout_seconds):
while time.time() < start_time + timeout_seconds:
tick = time.time()
if (os.system(f'ssh-keyscan "{host_address}" > /dev/null 2>&1') >> 8) == 0:
print(f'{host_address} SSH is up! ({time.time() - start_time} seconds)')
print(f"{host_address} SSH is up! ({time.time() - start_time} seconds)")
return
else:
time_diff_seconds = int(time.time() - tick)
time.sleep(max(0, (1 - time_diff_seconds)))

print(f'Timeout while waiting for {host_address} to be SSH-ready ({timeout_seconds} seconds).')
print('AWS: Check if this account has the appropiate inbound rules for this region')
print(
f"Timeout while waiting for {host_address} to be SSH-ready ({timeout_seconds} seconds)."
)
print("AWS: Check if this account has the appropiate inbound rules for this region")
exit(1)


Expand All @@ -77,37 +97,43 @@ def add_ssh_keys_to_instances(instances, ssh_config_file):

threads = []
for inst in instances.values():
t = Thread(target=__copy_team_ssh_keys_to_instance,
args=[inst, ssh_config_file, team_ssh_keys])
t = Thread(
target=__copy_team_ssh_keys_to_instance,
args=[inst, ssh_config_file, team_ssh_keys],
)
t.start()
threads.append(t)

[t.join() for t in threads]


def __get_team_ssh_keys_by_path():
keys_dir = 'schutzbot/team_ssh_keys'
keys_dir = "schutzbot/team_ssh_keys"

keys = {}
for p in os.listdir(keys_dir):
key_file_path = os.path.join(keys_dir, p)
with open(key_file_path, 'r') as f:
with open(key_file_path, "r") as f:
keys[key_file_path] = f.read()

return keys


def __copy_team_ssh_keys_to_instance(instance, ssh_config_file, team_ssh_keys):
auth_keys = '~/.ssh/authorized_keys'
instance_address = instance['address']
username = instance['username']
auth_keys = "~/.ssh/authorized_keys"
instance_address = instance["address"]
username = instance["username"]

composed_echo_command = ';'.join([f'echo "{k}" >> {auth_keys}' for k in team_ssh_keys.values()])
composed_echo_command = ";".join(
[f'echo "{k}" >> {auth_keys}' for k in team_ssh_keys.values()]
)

ssh_command = (f'ssh -F "{ssh_config_file}" '
f'{username}@{instance_address} "{composed_echo_command}" > /dev/null 2>&1')
ssh_command = (
f'ssh -F "{ssh_config_file}" '
f'{username}@{instance_address} "{composed_echo_command}" > /dev/null 2>&1'
)

if (os.system(ssh_command) >> 8) == 0:
print(f'[{instance_address}] Public SSH key(s) copied successfully!')
print(f"[{instance_address}] Public SSH key(s) copied successfully!")
else:
print(f'[{instance_address}] WARNING: Could not copy public SSH key(s)')
print(f"[{instance_address}] WARNING: Could not copy public SSH key(s)")

0 comments on commit 6b18741

Please sign in to comment.