Skip to content
This repository was archived by the owner on Feb 26, 2025. It is now read-only.

Commit 0431c67

Browse files
committed
Fix tests, new tests, code cleanup
1 parent 467b871 commit 0431c67

File tree

4 files changed

+139
-63
lines changed

4 files changed

+139
-63
lines changed

hpc_provisioner/src/hpc_provisioner/aws_queries.py

+31
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,12 @@
66
import boto3
77
from botocore.exceptions import ClientError
88

9+
from hpc_provisioner.constants import (
10+
BILLING_TAG_KEY,
11+
BILLING_TAG_VALUE,
12+
PROJECT_TAG_KEY,
13+
VLAB_TAG_KEY,
14+
)
915
from hpc_provisioner.dynamodb_actions import (
1016
SubnetAlreadyRegisteredException,
1117
dynamodb_client,
@@ -61,6 +67,31 @@ def create_keypair(ec2_client, vlab_id, project_id, tags) -> dict:
6167
)
6268

6369

70+
def create_secret(sm_client, vlab_id, project_id, secret_name, secret_value):
71+
secret = sm_client.create_secret(
72+
Name=secret_name,
73+
Description=f"SSH Key for cluster for vlab {vlab_id}, project {project_id}",
74+
SecretString=secret_value,
75+
Tags=[
76+
{"Key": VLAB_TAG_KEY, "Value": vlab_id},
77+
{"Key": PROJECT_TAG_KEY, "Value": project_id},
78+
{"Key": BILLING_TAG_KEY, "Value": BILLING_TAG_VALUE},
79+
],
80+
)
81+
82+
return secret
83+
84+
85+
def get_secret(sm_client, secret_name):
86+
existing_secrets = sm_client.list_secrets(Filters=[{"Key": "name", "Values": [secret_name]}])
87+
if secret_list := existing_secrets.get("SecretList", []):
88+
secret = secret_list[0]
89+
else:
90+
raise RuntimeError(f"Secret {secret_name} does not exist in SecretsManager")
91+
92+
return secret
93+
94+
6495
def get_efs(efs_client) -> str:
6596
"""
6697
Get the ID for the EFS for pclusters

hpc_provisioner/src/hpc_provisioner/handlers.py

+4-20
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import boto3
77
from pcluster.api.errors import NotFoundException
88

9-
from hpc_provisioner.aws_queries import create_keypair
9+
from hpc_provisioner.aws_queries import create_keypair, create_secret, get_secret
1010
from hpc_provisioner.constants import (
1111
BILLING_TAG_KEY,
1212
BILLING_TAG_VALUE,
@@ -80,27 +80,11 @@ def pcluster_create_request_handler(event, _context=None):
8080
)
8181

8282
if "KeyMaterial" in ssh_keypair:
83-
secret = sm_client.create_secret(
84-
Name=ssh_keypair["KeyName"],
85-
Description=f"SSH Key for cluster for vlab {vlab_id}, project {project_id}",
86-
SecretString=ssh_keypair["KeyMaterial"],
87-
Tags=[
88-
{"Key": VLAB_TAG_KEY, "Value": vlab_id},
89-
{"Key": PROJECT_TAG_KEY, "Value": project_id},
90-
{"Key": BILLING_TAG_KEY, "Value": BILLING_TAG_VALUE},
91-
],
83+
secret = create_secret(
84+
sm_client, vlab_id, project_id, ssh_keypair["KeyName"], ssh_keypair["KeyMaterial"]
9285
)
9386
else:
94-
existing_secrets = sm_client.list_secrets(
95-
Filters=[{"Key": "name", "Values": [ssh_keypair["KeyName"]]}]
96-
)
97-
if secret_list := existing_secrets["SecretList"]:
98-
secret = secret_list[0]
99-
else:
100-
raise RuntimeError(
101-
f"SSH Keypair {ssh_keypair['KeyName']} already exists in EC2 but "
102-
"was not stored in SecretsManager - unable to retrieve private key"
103-
)
87+
secret = get_secret(sm_client, ssh_keypair["KeyName"])
10488

10589
logger.debug("calling create lambda async")
10690
boto3.client("lambda").invoke_async(

hpc_provisioner/tests/test_aws_queries.py

+61-32
Original file line numberDiff line numberDiff line change
@@ -2,17 +2,25 @@
22
from unittest.mock import MagicMock, call, patch
33

44
import pytest
5+
from botocore.exceptions import ClientError
56
from hpc_provisioner.aws_queries import (
67
CouldNotDetermineEFSException,
7-
CouldNotDetermineKeyPairException,
88
CouldNotDetermineSecurityGroupException,
99
OutOfSubnetsException,
1010
claim_subnet,
11+
create_keypair,
12+
create_secret,
1113
get_available_subnet,
1214
get_efs,
13-
get_keypair,
15+
get_secret,
1416
get_security_group,
1517
)
18+
from hpc_provisioner.constants import (
19+
BILLING_TAG_KEY,
20+
BILLING_TAG_VALUE,
21+
PROJECT_TAG_KEY,
22+
VLAB_TAG_KEY,
23+
)
1624
from hpc_provisioner.dynamodb_actions import SubnetAlreadyRegisteredException
1725

1826
logger = logging.getLogger("test_logger")
@@ -28,36 +36,6 @@
2836
logger.setLevel(logging.DEBUG)
2937

3038

31-
@pytest.mark.parametrize(
32-
"keypairs",
33-
[
34-
{"KeyPairs": [{"KeyName": "keypair-1"}]},
35-
],
36-
)
37-
def test_get_keypair(keypairs):
38-
mock_ec2_client = MagicMock()
39-
mock_ec2_client.describe_key_pairs.return_value = keypairs
40-
keypair = get_keypair(mock_ec2_client)
41-
assert keypair == keypairs["KeyPairs"][0]["KeyName"]
42-
43-
44-
@pytest.mark.parametrize(
45-
"keypairs",
46-
[
47-
{"KeyPairs": []},
48-
{"KeyPairs": ["keypair-1", "keypair-2"]},
49-
],
50-
)
51-
def test_get_keypair_fails(keypairs):
52-
mock_ec2_client = MagicMock()
53-
mock_ec2_client.describe_key_pairs.return_value = keypairs
54-
with pytest.raises(
55-
CouldNotDetermineKeyPairException,
56-
match=str(keypairs["KeyPairs"]).replace("[", "\\["),
57-
):
58-
get_keypair(mock_ec2_client)
59-
60-
6139
@pytest.mark.parametrize(
6240
"filesystems",
6341
[
@@ -333,3 +311,54 @@ def test_get_available_subnet(mock_dynamodb_client, mock_claim_subnet):
333311
mock_dynamodb_client(), ec2_subnets["Subnets"], cluster_name
334312
)
335313
assert subnet == "sub-1"
314+
315+
316+
def test_create_keypair():
317+
mock_ec2_client = MagicMock()
318+
mock_ec2_client.describe_key_pairs.side_effect = ClientError(
319+
error_response={"Error": {"Code": 1, "Message": "It failed"}},
320+
operation_name="describe_key_pairs",
321+
)
322+
mock_ec2_client.create_key_pair.return_value = "key created"
323+
vlab_id = "test_vlab"
324+
project_id = "test_project"
325+
tags = [{"Key": "tagkey", "Value": "tagvalue"}]
326+
create_keypair(mock_ec2_client, vlab_id, project_id, tags)
327+
mock_ec2_client.create_key_pair.assert_called_once_with(
328+
KeyName=f"pcluster-{vlab_id}-{project_id}",
329+
TagSpecifications=[{"ResourceType": "key-pair", "Tags": tags}],
330+
)
331+
332+
333+
def test_get_secret():
334+
secret_value = "supersecret"
335+
mock_sm_client = MagicMock()
336+
mock_sm_client.list_secrets.return_value = {"SecretList": [secret_value]}
337+
retrieved_secret = get_secret(mock_sm_client, "mysecret")
338+
assert retrieved_secret == secret_value
339+
340+
341+
def test_get_secret_not_found():
342+
mock_sm_client = MagicMock()
343+
mock_sm_client.list_secrets.return_value = {}
344+
with pytest.raises(RuntimeError):
345+
get_secret(mock_sm_client, "mysecret")
346+
347+
348+
def test_create_secret():
349+
mock_sm_client = MagicMock()
350+
vlab_id = "test_vlab"
351+
project_id = "test_project"
352+
secret_name = "mysecret"
353+
secret_value = "supersecret"
354+
create_secret(mock_sm_client, vlab_id, project_id, secret_name, secret_value)
355+
mock_sm_client.create_secret.assert_called_once_with(
356+
Name=secret_name,
357+
Description=f"SSH Key for cluster for vlab {vlab_id}, project {project_id}",
358+
SecretString=secret_value,
359+
Tags=[
360+
{"Key": VLAB_TAG_KEY, "Value": vlab_id},
361+
{"Key": PROJECT_TAG_KEY, "Value": project_id},
362+
{"Key": BILLING_TAG_KEY, "Value": BILLING_TAG_VALUE},
363+
],
364+
)

hpc_provisioner/tests/test_resource_provisioner.py

+43-11
Original file line numberDiff line numberDiff line change
@@ -156,22 +156,48 @@ def test_get_all_clusters(data):
156156

157157

158158
@patch("hpc_provisioner.handlers.boto3")
159-
def test_post(patched_boto3, post_event):
159+
@pytest.mark.parametrize("key_exists", [True, False])
160+
def test_post(patched_boto3, post_event, key_exists):
161+
test_cluster_name = cluster_name(post_event["vlab_id"], post_event["project_id"])
160162
mock_client = MagicMock()
161163
patched_boto3.client.return_value = mock_client
162-
actual_response = handlers.pcluster_create_request_handler(post_event)
163-
mock_client.invoke_async.assert_called_once_with(
164+
with patch("hpc_provisioner.handlers.create_keypair") as patched_create_keypair:
165+
if key_exists:
166+
patched_create_keypair.return_value = {
167+
"KeyName": test_cluster_name,
168+
}
169+
else:
170+
patched_create_keypair.return_value = {
171+
"KeyMaterial": "secret_stuff",
172+
"KeyName": test_cluster_name,
173+
}
174+
175+
with patch("hpc_provisioner.handlers.create_secret") as patched_create_secret:
176+
patched_create_keypair.return_value = {
177+
"KeyMaterial": "secret_stuff",
178+
"KeyName": test_cluster_name,
179+
}
180+
patched_create_secret.return_value = {"ARN": "secret ARN"}
181+
with patch("hpc_provisioner.handlers.get_secret") as patched_get_secret:
182+
patched_get_secret.return_value = {"ARN": "secret ARN"}
183+
actual_response = handlers.pcluster_create_request_handler(post_event)
184+
mock_client.invoke_async.assert_called_with(
164185
FunctionName="hpc-resource-provisioner-creator",
165186
InvokeArgs=json.dumps(
166-
{"vlab_id": post_event["vlab_id"], "project_id": post_event["project_id"]}
187+
{
188+
"vlab_id": post_event["vlab_id"],
189+
"project_id": post_event["project_id"],
190+
"keyname": f"pcluster-{post_event['vlab_id']}-{post_event['project_id']}",
191+
}
167192
),
168193
)
169194
expected_response = expected_response_template(
170195
text=json.dumps(
171196
{
172197
"cluster": {
173-
"clusterName": cluster_name(post_event["vlab_id"], post_event["project_id"]),
198+
"clusterName": test_cluster_name,
174199
"clusterStatus": "CREATE_REQUEST_RECEIVED",
200+
"private_ssh_key_arn": "secret ARN",
175201
}
176202
}
177203
)
@@ -183,7 +209,10 @@ def test_post(patched_boto3, post_event):
183209
"hpc_provisioner.aws_queries.dynamodb_client",
184210
)
185211
@patch("hpc_provisioner.aws_queries.free_subnet")
186-
def test_delete(patched_free_subnet, patched_dynamodb_client, data, delete_event):
212+
@patch("hpc_provisioner.pcluster_manager.remove_key")
213+
def test_delete(
214+
patched_remove_key, patched_free_subnet, patched_dynamodb_client, data, delete_event
215+
):
187216
mock_client = MagicMock()
188217
patched_dynamodb_client.return_value = mock_client
189218
with patch(
@@ -204,6 +233,7 @@ def test_delete(patched_free_subnet, patched_dynamodb_client, data, delete_event
204233
expected_response = expected_response_template(text=json.dumps(data["deletingCluster"]))
205234
assert actual_response == expected_response
206235
patched_get_registered_subnets.assert_called_once()
236+
patched_remove_key.assert_called_once()
207237
call1 = call(mock_client, "subnet-123")
208238
call2 = call(mock_client, "subnet-234")
209239
patched_free_subnet.assert_has_calls([call1, call2], any_order=True)
@@ -235,7 +265,8 @@ def test_get_internal_server_error(get_event):
235265
@patch(
236266
"hpc_provisioner.aws_queries.dynamodb_client",
237267
)
238-
def test_delete_not_found(patched_dynamodb_client, delete_event):
268+
@patch("hpc_provisioner.pcluster_manager.remove_key")
269+
def test_delete_not_found(patched_remove_key, patched_dynamodb_client, delete_event):
239270
error_message = f"Cluster {delete_event['vlab_id']}-{delete_event['project_id']} does not exist"
240271
with patch(
241272
"hpc_provisioner.pcluster_manager.pc.delete_cluster",
@@ -245,12 +276,14 @@ def test_delete_not_found(patched_dynamodb_client, delete_event):
245276
delete_cluster.assert_called_once()
246277
assert result == {"statusCode": 404, "body": error_message}
247278
patched_dynamodb_client.assert_called_once()
279+
patched_remove_key.assert_called_once()
248280

249281

250282
@patch(
251283
"hpc_provisioner.aws_queries.dynamodb_client",
252284
)
253-
def test_delete_internal_server_error(patched_dynamodb_client, delete_event):
285+
@patch("hpc_provisioner.pcluster_manager.remove_key")
286+
def test_delete_internal_server_error(patched_remove_key, patched_dynamodb_client, delete_event):
254287
with patch(
255288
"hpc_provisioner.pcluster_manager.pc.delete_cluster",
256289
side_effect=RuntimeError,
@@ -259,6 +292,7 @@ def test_delete_internal_server_error(patched_dynamodb_client, delete_event):
259292
patched_delete_cluster.assert_called_once()
260293
assert result == {"statusCode": 500, "body": "<class 'RuntimeError'>"}
261294
patched_dynamodb_client.assert_called_once()
295+
patched_remove_key.assert_called_once()
262296

263297

264298
@patch("hpc_provisioner.pcluster_manager.pc.create_cluster")
@@ -278,9 +312,7 @@ def test_do_create_already_exists(patched_boto3, patched_create_cluster, post_ev
278312
@patch("hpc_provisioner.pcluster_manager.get_available_subnet", return_value="subnet-123")
279313
@patch("hpc_provisioner.pcluster_manager.get_security_group", return_value="sg-123")
280314
@patch("hpc_provisioner.pcluster_manager.get_efs", return_value="efs-123")
281-
@patch("hpc_provisioner.pcluster_manager.get_keypair", return_value="keypair-123")
282315
def test_do_create(
283-
patched_get_keypair,
284316
patched_get_efs,
285317
patched_get_security_group,
286318
patched_get_available_subnet,
@@ -299,13 +331,13 @@ def test_do_create(
299331
"ec2": mock_ec2_client,
300332
"efs": mock_efs_client,
301333
}[x]
334+
post_event["keyname"] = cluster_name(post_event["vlab_id"], post_event["project_id"])
302335
handlers.pcluster_do_create_handler(post_event)
303336
patched_create_cluster.assert_called_once()
304337
assert patched_create_cluster.call_args.kwargs["cluster_name"] == cluster_name(
305338
post_event["vlab_id"], post_event["project_id"]
306339
)
307340
assert "tmp" in patched_create_cluster.call_args.kwargs["cluster_configuration"]
308-
patched_get_keypair.assert_called_once()
309341
patched_get_efs.assert_called_once()
310342
patched_get_security_group.assert_called_once()
311343
patched_get_available_subnet.assert_called_once()

0 commit comments

Comments
 (0)