Skip to content
This repository was archived by the owner on Feb 26, 2025. It is now read-only.

Commit 7c365f1

Browse files
committed
Refactor: remove some logic from handlers into helpers
1 parent 0431c67 commit 7c365f1

File tree

4 files changed

+60
-20
lines changed

4 files changed

+60
-20
lines changed

hpc_provisioner/src/hpc_provisioner/aws_queries.py

+11-4
Original file line numberDiff line numberDiff line change
@@ -42,10 +42,6 @@ class CouldNotDetermineEFSException(Exception):
4242
"""
4343

4444

45-
class CouldNotDetermineKeyPairException(Exception):
46-
"""Indicates that we either found too many or no keypairs with tags HPC_Goal:compute_cluster"""
47-
48-
4945
def get_cluster_name(vlab_id: str, project_id: str) -> str:
5046
return f"pcluster-{vlab_id}-{project_id}"
5147

@@ -67,6 +63,17 @@ def create_keypair(ec2_client, vlab_id, project_id, tags) -> dict:
6763
)
6864

6965

66+
def store_private_key(sm_client, vlab_id, project_id, ssh_keypair):
67+
if "KeyMaterial" in ssh_keypair:
68+
secret = create_secret(
69+
sm_client, vlab_id, project_id, ssh_keypair["KeyName"], ssh_keypair["KeyMaterial"]
70+
)
71+
else:
72+
secret = get_secret(sm_client, ssh_keypair["KeyName"])
73+
74+
return secret
75+
76+
7077
def create_secret(sm_client, vlab_id, project_id, secret_name, secret_value):
7178
secret = sm_client.create_secret(
7279
Name=secret_name,

hpc_provisioner/src/hpc_provisioner/handlers.py

+2-7
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import boto3
77
from pcluster.api.errors import NotFoundException
88

9-
from hpc_provisioner.aws_queries import create_keypair, create_secret, get_secret
9+
from hpc_provisioner.aws_queries import create_keypair, store_private_key
1010
from hpc_provisioner.constants import (
1111
BILLING_TAG_KEY,
1212
BILLING_TAG_VALUE,
@@ -79,12 +79,7 @@ def pcluster_create_request_handler(event, _context=None):
7979
],
8080
)
8181

82-
if "KeyMaterial" in ssh_keypair:
83-
secret = create_secret(
84-
sm_client, vlab_id, project_id, ssh_keypair["KeyName"], ssh_keypair["KeyMaterial"]
85-
)
86-
else:
87-
secret = get_secret(sm_client, ssh_keypair["KeyName"])
82+
secret = store_private_key(sm_client, vlab_id, project_id, ssh_keypair)
8883

8984
logger.debug("calling create lambda async")
9085
boto3.client("lambda").invoke_async(

hpc_provisioner/tests/test_aws_queries.py

+44
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
get_efs,
1515
get_secret,
1616
get_security_group,
17+
store_private_key,
1718
)
1819
from hpc_provisioner.constants import (
1920
BILLING_TAG_KEY,
@@ -362,3 +363,46 @@ def test_create_secret():
362363
{"Key": BILLING_TAG_KEY, "Value": BILLING_TAG_VALUE},
363364
],
364365
)
366+
367+
368+
@pytest.mark.parametrize("keypair_exists", [True, False])
369+
@pytest.mark.parametrize("secret_exists", [True, False])
370+
def test_store_private_key(keypair_exists, secret_exists):
371+
if secret_exists and not keypair_exists:
372+
pytest.skip(
373+
"New keypair with existing secret: sm_client.create_secret will raise "
374+
"on trying to create a secret that already exists."
375+
)
376+
mock_sm_client = MagicMock()
377+
vlab_id = "testvlab"
378+
project_id = "testproject"
379+
secret_name = key_name = f"pcluster-{vlab_id}-{project_id}"
380+
if not keypair_exists:
381+
ssh_keypair = {"KeyMaterial": "supersecret", "KeyName": key_name}
382+
else:
383+
ssh_keypair = {"KeyName": key_name}
384+
385+
if not keypair_exists:
386+
if not secret_exists:
387+
print("Keypair created, secret does not exist yet")
388+
store_private_key(mock_sm_client, vlab_id, project_id, ssh_keypair)
389+
mock_sm_client.create_secret.assert_called_once_with(
390+
Name=secret_name,
391+
Description=f"SSH Key for cluster for vlab {vlab_id}, project {project_id}",
392+
SecretString=ssh_keypair["KeyMaterial"],
393+
Tags=[
394+
{"Key": VLAB_TAG_KEY, "Value": vlab_id},
395+
{"Key": PROJECT_TAG_KEY, "Value": project_id},
396+
{"Key": BILLING_TAG_KEY, "Value": BILLING_TAG_VALUE},
397+
],
398+
)
399+
elif secret_exists:
400+
print("Both already exist")
401+
mock_sm_client.list_secrets.return_value = {"SecretList": ["somesecret"]}
402+
retrieved_secret = store_private_key(mock_sm_client, vlab_id, project_id, ssh_keypair)
403+
assert retrieved_secret == "somesecret"
404+
else:
405+
print("Keypair already existed but was not stored in secretsmanager yet")
406+
mock_sm_client.list_secrets.return_value = {"SecretList": []}
407+
with pytest.raises(RuntimeError):
408+
store_private_key(mock_sm_client, vlab_id, project_id, ssh_keypair)

hpc_provisioner/tests/test_resource_provisioner.py

+3-9
Original file line numberDiff line numberDiff line change
@@ -172,15 +172,9 @@ def test_post(patched_boto3, post_event, key_exists):
172172
"KeyName": test_cluster_name,
173173
}
174174

175-
with patch("hpc_provisioner.handlers.create_secret") as patched_create_secret:
176-
patched_create_keypair.return_value = {
177-
"KeyMaterial": "secret_stuff",
178-
"KeyName": test_cluster_name,
179-
}
180-
patched_create_secret.return_value = {"ARN": "secret ARN"}
181-
with patch("hpc_provisioner.handlers.get_secret") as patched_get_secret:
182-
patched_get_secret.return_value = {"ARN": "secret ARN"}
183-
actual_response = handlers.pcluster_create_request_handler(post_event)
175+
with patch("hpc_provisioner.handlers.store_private_key") as patched_store_private_key:
176+
patched_store_private_key.return_value = {"ARN": "secret ARN"}
177+
actual_response = handlers.pcluster_create_request_handler(post_event)
184178
mock_client.invoke_async.assert_called_with(
185179
FunctionName="hpc-resource-provisioner-creator",
186180
InvokeArgs=json.dumps(

0 commit comments

Comments
 (0)