Skip to content

Commit ef5957c

Browse files
committed
handle edge case
1 parent 6a6c070 commit ef5957c

File tree

2 files changed

+8
-1
lines changed

2 files changed

+8
-1
lines changed

src/lightning/pytorch/utilities/model_summary/model_summary.py

+1
Original file line numberDiff line numberDiff line change
@@ -343,6 +343,7 @@ def _add_leftover_params_to_summary(self, arrays: List[Tuple[str, List[str]]], t
343343
layer_summaries["Name"].append(LEFTOVER_PARAMS_NAME)
344344
layer_summaries["Type"].append(NOT_APPLICABLE)
345345
layer_summaries["Params"].append(get_human_readable_count(total_leftover_params))
346+
layer_summaries["Mode"].append(NOT_APPLICABLE)
346347
if "In sizes" in layer_summaries:
347348
layer_summaries["In sizes"].append(NOT_APPLICABLE)
348349
if "Out sizes" in layer_summaries:

tests/tests_pytorch/utilities/test_model_summary.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -443,10 +443,16 @@ def test_summary_training_mode():
443443
assert summary_data["Mode"] == [
444444
"train", # branch1
445445
"eval", # branch2
446-
"train" # head
446+
"train", # head
447447
]
448448

449449
summary = summarize(model, max_depth=-1)
450450
expected_eval = {"branch1.1.0", "branch2"}
451451
for name, layer_summary in summary._layer_summary.items():
452452
assert (name in expected_eval) == (not layer_summary.training)
453+
454+
# A model with params not belonging to a layer
455+
model = NonLayerParamsModel()
456+
model.layer.eval()
457+
summary_data = OrderedDict(summarize(model)._get_summary_data())
458+
assert summary_data["Mode"] == ["eval", "n/a"]

0 commit comments

Comments
 (0)