Skip to content

Commit

Permalink
Allow traffic from self in Bastion SG (#23)
Browse files Browse the repository at this point in the history
  • Loading branch information
tjohnes authored Jun 24, 2024
1 parent bf4dee8 commit ec8b22e
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 22 deletions.
10 changes: 10 additions & 0 deletions modules/aws/bastion/main.tf
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,16 @@ resource "aws_security_group" "this" {
}
}

resource "aws_vpc_security_group_ingress_rule" "self" {
security_group_id = aws_security_group.this.id

description = "Ingress traffic from self"
referenced_security_group_id = aws_security_group.this.id
from_port = -1
to_port = -1
ip_protocol = -1
}

resource "aws_vpc_security_group_ingress_rule" "ssh" {
for_each = toset(var.remote_access_cidr)

Expand Down
2 changes: 1 addition & 1 deletion tests/ut/terraform/bastion/variables.tf
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ variable "name" {
variable "remote_access_cidr" {
description = "Allowed CIDR blocks for external SSH access to the Bastion instance"
type = list(string)
default = ["0.0.0.0/0"]
default = []
nullable = false
}

Expand Down
63 changes: 42 additions & 21 deletions tests/ut/test_bastion.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import uuid
from pathlib import Path
from typing import Any
from typing import Any, Sequence

import botocore.exceptions
import pytest
Expand Down Expand Up @@ -78,31 +78,52 @@ def base_vars(subnet: Subnet, key_pair: KeyPair) -> dict[str, Any]:
}


def _assert_sgrs(sg: SecurityGroup, remote_access_cidr: list[str]) -> None:
def _assert_sgrs(
sg: SecurityGroup,
remote_access_cidr: Sequence[str] | None = None,
) -> None:
"""Assert the 'bastion' security group has the expected ingress rules."""
assert len(sg.ip_permissions) == 2
# We expect at least one ingress SGR for traffic from self.
expected_sgr_count = 1

if remote_access_cidr:
# And two more if any remote access CIDRs are specified; one for ICMP,
# and one for SSH traffic.
expected_sgr_count += 2

assert len(sg.ip_permissions) == expected_sgr_count

self_sgr_found = False
icmp_sgr_found = False
ssh_sgr_found = False
for sgr in sg.ip_permissions:
assert {x["CidrIp"] for x in sgr["IpRanges"]} == set(
remote_access_cidr,
)

if sgr["IpProtocol"] == "icmp":
assert not icmp_sgr_found
icmp_sgr_found = True
assert sgr["FromPort"] == -1
assert sgr["ToPort"] == -1

elif sgr["IpProtocol"] == "tcp":
assert not ssh_sgr_found
ssh_sgr_found = True
assert sgr["FromPort"] == 22
assert sgr["ToPort"] == 22

for sgr in sg.ip_permissions:
if sgr["IpProtocol"] == "-1":
assert not self_sgr_found
assert len(sgr["UserIdGroupPairs"]) == 1
assert sgr["UserIdGroupPairs"][0]["GroupId"] == sg.id
self_sgr_found = True
else:
raise AssertionError(f"unexpected protocol '{sgr['IpProtocol']}'")
assert {x["CidrIp"] for x in sgr["IpRanges"]} == set(
remote_access_cidr,
)

if sgr["IpProtocol"] == "icmp":
assert not icmp_sgr_found
icmp_sgr_found = True
assert sgr["FromPort"] == -1
assert sgr["ToPort"] == -1

elif sgr["IpProtocol"] == "tcp":
assert not ssh_sgr_found
ssh_sgr_found = True
assert sgr["FromPort"] == 22
assert sgr["ToPort"] == 22

else:
raise AssertionError(
f"unexpected protocol '{sgr['IpProtocol']}'",
)


def test_defaults(
Expand All @@ -129,7 +150,7 @@ def test_defaults(

assert len(instance.security_groups) == 1
sg = ec2.SecurityGroup(instance.security_groups[0]["GroupId"])
_assert_sgrs(sg, ["0.0.0.0/0"])
_assert_sgrs(sg)


def test_instance_type(
Expand Down

0 comments on commit ec8b22e

Please sign in to comment.