Skip to content

Commit 8563766

Browse files
authored
Merge pull request #21858 from shaod2/py-cp-29
Cherrypick Pure Python recursion limit enforcement to 29.x
2 parents 69cca9b + 05ba1a8 commit 8563766

File tree

5 files changed

+190
-33
lines changed

5 files changed

+190
-33
lines changed

python/google/protobuf/internal/decoder.py

Lines changed: 108 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -184,7 +184,10 @@ def SpecificDecoder(field_number, is_repeated, is_packed, key, new_default,
184184
clear_if_default=False):
185185
if is_packed:
186186
local_DecodeVarint = _DecodeVarint
187-
def DecodePackedField(buffer, pos, end, message, field_dict):
187+
def DecodePackedField(
188+
buffer, pos, end, message, field_dict, current_depth=0
189+
):
190+
del current_depth # unused
188191
value = field_dict.get(key)
189192
if value is None:
190193
value = field_dict.setdefault(key, new_default(message))
@@ -199,11 +202,15 @@ def DecodePackedField(buffer, pos, end, message, field_dict):
199202
del value[-1] # Discard corrupt value.
200203
raise _DecodeError('Packed element was truncated.')
201204
return pos
205+
202206
return DecodePackedField
203207
elif is_repeated:
204208
tag_bytes = encoder.TagBytes(field_number, wire_type)
205209
tag_len = len(tag_bytes)
206-
def DecodeRepeatedField(buffer, pos, end, message, field_dict):
210+
def DecodeRepeatedField(
211+
buffer, pos, end, message, field_dict, current_depth=0
212+
):
213+
del current_depth # unused
207214
value = field_dict.get(key)
208215
if value is None:
209216
value = field_dict.setdefault(key, new_default(message))
@@ -218,9 +225,12 @@ def DecodeRepeatedField(buffer, pos, end, message, field_dict):
218225
if new_pos > end:
219226
raise _DecodeError('Truncated message.')
220227
return new_pos
228+
221229
return DecodeRepeatedField
222230
else:
223-
def DecodeField(buffer, pos, end, message, field_dict):
231+
232+
def DecodeField(buffer, pos, end, message, field_dict, current_depth=0):
233+
del current_depth # unused
224234
(new_value, pos) = decode_value(buffer, pos)
225235
if pos > end:
226236
raise _DecodeError('Truncated message.')
@@ -229,6 +239,7 @@ def DecodeField(buffer, pos, end, message, field_dict):
229239
else:
230240
field_dict[key] = new_value
231241
return pos
242+
232243
return DecodeField
233244

234245
return SpecificDecoder
@@ -364,7 +375,9 @@ def EnumDecoder(field_number, is_repeated, is_packed, key, new_default,
364375
enum_type = key.enum_type
365376
if is_packed:
366377
local_DecodeVarint = _DecodeVarint
367-
def DecodePackedField(buffer, pos, end, message, field_dict):
378+
def DecodePackedField(
379+
buffer, pos, end, message, field_dict, current_depth=0
380+
):
368381
"""Decode serialized packed enum to its value and a new position.
369382
370383
Args:
@@ -377,6 +390,7 @@ def DecodePackedField(buffer, pos, end, message, field_dict):
377390
Returns:
378391
int, new position in serialized data.
379392
"""
393+
del current_depth # unused
380394
value = field_dict.get(key)
381395
if value is None:
382396
value = field_dict.setdefault(key, new_default(message))
@@ -407,11 +421,14 @@ def DecodePackedField(buffer, pos, end, message, field_dict):
407421
# pylint: enable=protected-access
408422
raise _DecodeError('Packed element was truncated.')
409423
return pos
424+
410425
return DecodePackedField
411426
elif is_repeated:
412427
tag_bytes = encoder.TagBytes(field_number, wire_format.WIRETYPE_VARINT)
413428
tag_len = len(tag_bytes)
414-
def DecodeRepeatedField(buffer, pos, end, message, field_dict):
429+
def DecodeRepeatedField(
430+
buffer, pos, end, message, field_dict, current_depth=0
431+
):
415432
"""Decode serialized repeated enum to its value and a new position.
416433
417434
Args:
@@ -424,6 +441,7 @@ def DecodeRepeatedField(buffer, pos, end, message, field_dict):
424441
Returns:
425442
int, new position in serialized data.
426443
"""
444+
del current_depth # unused
427445
value = field_dict.get(key)
428446
if value is None:
429447
value = field_dict.setdefault(key, new_default(message))
@@ -446,9 +464,11 @@ def DecodeRepeatedField(buffer, pos, end, message, field_dict):
446464
if new_pos > end:
447465
raise _DecodeError('Truncated message.')
448466
return new_pos
467+
449468
return DecodeRepeatedField
450469
else:
451-
def DecodeField(buffer, pos, end, message, field_dict):
470+
471+
def DecodeField(buffer, pos, end, message, field_dict, current_depth=0):
452472
"""Decode serialized enum to its value and a new position.
453473
454474
Args:
@@ -461,6 +481,7 @@ def DecodeField(buffer, pos, end, message, field_dict):
461481
Returns:
462482
int, new position in serialized data.
463483
"""
484+
del current_depth # unused
464485
value_start_pos = pos
465486
(enum_value, pos) = _DecodeSignedVarint32(buffer, pos)
466487
if pos > end:
@@ -480,6 +501,7 @@ def DecodeField(buffer, pos, end, message, field_dict):
480501
(tag_bytes, buffer[value_start_pos:pos].tobytes()))
481502
# pylint: enable=protected-access
482503
return pos
504+
483505
return DecodeField
484506

485507

@@ -538,7 +560,10 @@ def _ConvertToUnicode(memview):
538560
tag_bytes = encoder.TagBytes(field_number,
539561
wire_format.WIRETYPE_LENGTH_DELIMITED)
540562
tag_len = len(tag_bytes)
541-
def DecodeRepeatedField(buffer, pos, end, message, field_dict):
563+
def DecodeRepeatedField(
564+
buffer, pos, end, message, field_dict, current_depth=0
565+
):
566+
del current_depth # unused
542567
value = field_dict.get(key)
543568
if value is None:
544569
value = field_dict.setdefault(key, new_default(message))
@@ -553,9 +578,12 @@ def DecodeRepeatedField(buffer, pos, end, message, field_dict):
553578
if buffer[new_pos:pos] != tag_bytes or new_pos == end:
554579
# Prediction failed. Return.
555580
return new_pos
581+
556582
return DecodeRepeatedField
557583
else:
558-
def DecodeField(buffer, pos, end, message, field_dict):
584+
585+
def DecodeField(buffer, pos, end, message, field_dict, current_depth=0):
586+
del current_depth # unused
559587
(size, pos) = local_DecodeVarint(buffer, pos)
560588
new_pos = pos + size
561589
if new_pos > end:
@@ -565,6 +593,7 @@ def DecodeField(buffer, pos, end, message, field_dict):
565593
else:
566594
field_dict[key] = _ConvertToUnicode(buffer[pos:new_pos])
567595
return new_pos
596+
568597
return DecodeField
569598

570599

@@ -579,7 +608,10 @@ def BytesDecoder(field_number, is_repeated, is_packed, key, new_default,
579608
tag_bytes = encoder.TagBytes(field_number,
580609
wire_format.WIRETYPE_LENGTH_DELIMITED)
581610
tag_len = len(tag_bytes)
582-
def DecodeRepeatedField(buffer, pos, end, message, field_dict):
611+
def DecodeRepeatedField(
612+
buffer, pos, end, message, field_dict, current_depth=0
613+
):
614+
del current_depth # unused
583615
value = field_dict.get(key)
584616
if value is None:
585617
value = field_dict.setdefault(key, new_default(message))
@@ -594,9 +626,12 @@ def DecodeRepeatedField(buffer, pos, end, message, field_dict):
594626
if buffer[new_pos:pos] != tag_bytes or new_pos == end:
595627
# Prediction failed. Return.
596628
return new_pos
629+
597630
return DecodeRepeatedField
598631
else:
599-
def DecodeField(buffer, pos, end, message, field_dict):
632+
633+
def DecodeField(buffer, pos, end, message, field_dict, current_depth=0):
634+
del current_depth # unused
600635
(size, pos) = local_DecodeVarint(buffer, pos)
601636
new_pos = pos + size
602637
if new_pos > end:
@@ -606,6 +641,7 @@ def DecodeField(buffer, pos, end, message, field_dict):
606641
else:
607642
field_dict[key] = buffer[pos:new_pos].tobytes()
608643
return new_pos
644+
609645
return DecodeField
610646

611647

@@ -621,7 +657,9 @@ def GroupDecoder(field_number, is_repeated, is_packed, key, new_default):
621657
tag_bytes = encoder.TagBytes(field_number,
622658
wire_format.WIRETYPE_START_GROUP)
623659
tag_len = len(tag_bytes)
624-
def DecodeRepeatedField(buffer, pos, end, message, field_dict):
660+
def DecodeRepeatedField(
661+
buffer, pos, end, message, field_dict, current_depth=0
662+
):
625663
value = field_dict.get(key)
626664
if value is None:
627665
value = field_dict.setdefault(key, new_default(message))
@@ -630,7 +668,13 @@ def DecodeRepeatedField(buffer, pos, end, message, field_dict):
630668
if value is None:
631669
value = field_dict.setdefault(key, new_default(message))
632670
# Read sub-message.
633-
pos = value.add()._InternalParse(buffer, pos, end)
671+
current_depth += 1
672+
if current_depth > _recursion_limit:
673+
raise _DecodeError(
674+
'Error parsing message: too many levels of nesting.'
675+
)
676+
pos = value.add()._InternalParse(buffer, pos, end, current_depth)
677+
current_depth -= 1
634678
# Read end tag.
635679
new_pos = pos+end_tag_len
636680
if buffer[pos:new_pos] != end_tag_bytes or new_pos > end:
@@ -640,19 +684,26 @@ def DecodeRepeatedField(buffer, pos, end, message, field_dict):
640684
if buffer[new_pos:pos] != tag_bytes or new_pos == end:
641685
# Prediction failed. Return.
642686
return new_pos
687+
643688
return DecodeRepeatedField
644689
else:
645-
def DecodeField(buffer, pos, end, message, field_dict):
690+
691+
def DecodeField(buffer, pos, end, message, field_dict, current_depth=0):
646692
value = field_dict.get(key)
647693
if value is None:
648694
value = field_dict.setdefault(key, new_default(message))
649695
# Read sub-message.
650-
pos = value._InternalParse(buffer, pos, end)
696+
current_depth += 1
697+
if current_depth > _recursion_limit:
698+
raise _DecodeError('Error parsing message: too many levels of nesting.')
699+
pos = value._InternalParse(buffer, pos, end, current_depth)
700+
current_depth -= 1
651701
# Read end tag.
652702
new_pos = pos+end_tag_len
653703
if buffer[pos:new_pos] != end_tag_bytes or new_pos > end:
654704
raise _DecodeError('Missing group end tag.')
655705
return new_pos
706+
656707
return DecodeField
657708

658709

@@ -666,7 +717,9 @@ def MessageDecoder(field_number, is_repeated, is_packed, key, new_default):
666717
tag_bytes = encoder.TagBytes(field_number,
667718
wire_format.WIRETYPE_LENGTH_DELIMITED)
668719
tag_len = len(tag_bytes)
669-
def DecodeRepeatedField(buffer, pos, end, message, field_dict):
720+
def DecodeRepeatedField(
721+
buffer, pos, end, message, field_dict, current_depth=0
722+
):
670723
value = field_dict.get(key)
671724
if value is None:
672725
value = field_dict.setdefault(key, new_default(message))
@@ -677,18 +730,29 @@ def DecodeRepeatedField(buffer, pos, end, message, field_dict):
677730
if new_pos > end:
678731
raise _DecodeError('Truncated message.')
679732
# Read sub-message.
680-
if value.add()._InternalParse(buffer, pos, new_pos) != new_pos:
733+
current_depth += 1
734+
if current_depth > _recursion_limit:
735+
raise _DecodeError(
736+
'Error parsing message: too many levels of nesting.'
737+
)
738+
if (
739+
value.add()._InternalParse(buffer, pos, new_pos, current_depth)
740+
!= new_pos
741+
):
681742
# The only reason _InternalParse would return early is if it
682743
# encountered an end-group tag.
683744
raise _DecodeError('Unexpected end-group tag.')
745+
current_depth -= 1
684746
# Predict that the next tag is another copy of the same repeated field.
685747
pos = new_pos + tag_len
686748
if buffer[new_pos:pos] != tag_bytes or new_pos == end:
687749
# Prediction failed. Return.
688750
return new_pos
751+
689752
return DecodeRepeatedField
690753
else:
691-
def DecodeField(buffer, pos, end, message, field_dict):
754+
755+
def DecodeField(buffer, pos, end, message, field_dict, current_depth=0):
692756
value = field_dict.get(key)
693757
if value is None:
694758
value = field_dict.setdefault(key, new_default(message))
@@ -698,11 +762,16 @@ def DecodeField(buffer, pos, end, message, field_dict):
698762
if new_pos > end:
699763
raise _DecodeError('Truncated message.')
700764
# Read sub-message.
701-
if value._InternalParse(buffer, pos, new_pos) != new_pos:
765+
current_depth += 1
766+
if current_depth > _recursion_limit:
767+
raise _DecodeError('Error parsing message: too many levels of nesting.')
768+
if value._InternalParse(buffer, pos, new_pos, current_depth) != new_pos:
702769
# The only reason _InternalParse would return early is if it encountered
703770
# an end-group tag.
704771
raise _DecodeError('Unexpected end-group tag.')
772+
current_depth -= 1
705773
return new_pos
774+
706775
return DecodeField
707776

708777

@@ -851,7 +920,8 @@ def MapDecoder(field_descriptor, new_default, is_message_map):
851920
# Can't read _concrete_class yet; might not be initialized.
852921
message_type = field_descriptor.message_type
853922

854-
def DecodeMap(buffer, pos, end, message, field_dict):
923+
def DecodeMap(buffer, pos, end, message, field_dict, current_depth=0):
924+
del current_depth # Unused.
855925
submsg = message_type._concrete_class()
856926
value = field_dict.get(key)
857927
if value is None:
@@ -934,7 +1004,16 @@ def _SkipGroup(buffer, pos, end):
9341004
pos = new_pos
9351005

9361006

937-
def _DecodeUnknownFieldSet(buffer, pos, end_pos=None):
1007+
# Nesting limit applied when SetRecursionLimit() has not been called.
DEFAULT_RECURSION_LIMIT = 100
# Limit currently in force. The group, message and unknown-field decoders in
# this module compare their `current_depth` against this value and raise
# _DecodeError('Error parsing message: too many levels of nesting.') when the
# depth is too large. Rebound only through SetRecursionLimit().
_recursion_limit = DEFAULT_RECURSION_LIMIT


def SetRecursionLimit(new_limit):
  """Replaces the module-wide nesting limit checked while decoding.

  The value is stored as-is (no validation is performed) and takes effect for
  every subsequent depth check in this module.

  Args:
    new_limit: The new maximum nesting depth. The group and message decoders
      reject input once their depth is greater than this value; the
      unknown-group decoder rejects input once its depth is greater than or
      equal to it.
  """
  global _recursion_limit
  _recursion_limit = new_limit
1014+
1015+
1016+
def _DecodeUnknownFieldSet(buffer, pos, end_pos=None, current_depth=0):
9381017
"""Decode UnknownFieldSet. Returns the UnknownFieldSet and new position."""
9391018

9401019
unknown_field_set = containers.UnknownFieldSet()
@@ -944,14 +1023,16 @@ def _DecodeUnknownFieldSet(buffer, pos, end_pos=None):
9441023
field_number, wire_type = wire_format.UnpackTag(tag)
9451024
if wire_type == wire_format.WIRETYPE_END_GROUP:
9461025
break
947-
(data, pos) = _DecodeUnknownField(buffer, pos, wire_type)
1026+
(data, pos) = _DecodeUnknownField(buffer, pos, wire_type, current_depth)
9481027
# pylint: disable=protected-access
9491028
unknown_field_set._add(field_number, wire_type, data)
9501029

9511030
return (unknown_field_set, pos)
9521031

9531032

954-
def _DecodeUnknownField(buffer, pos, wire_type):
1033+
def _DecodeUnknownField(
1034+
buffer, pos, wire_type, current_depth=0
1035+
):
9551036
"""Decode an unknown field. Returns the UnknownField and new position."""
9561037

9571038
if wire_type == wire_format.WIRETYPE_VARINT:
@@ -965,7 +1046,11 @@ def _DecodeUnknownField(buffer, pos, wire_type):
9651046
data = buffer[pos:pos+size].tobytes()
9661047
pos += size
9671048
elif wire_type == wire_format.WIRETYPE_START_GROUP:
968-
(data, pos) = _DecodeUnknownFieldSet(buffer, pos)
1049+
current_depth += 1
1050+
if current_depth >= _recursion_limit:
1051+
raise _DecodeError('Error parsing message: too many levels of nesting.')
1052+
data, pos = _DecodeUnknownFieldSet(buffer, pos, None, current_depth)
1053+
current_depth -= 1
9691054
elif wire_type == wire_format.WIRETYPE_END_GROUP:
9701055
return (0, -1)
9711056
else:

python/google/protobuf/internal/decoder_test.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,10 @@
1111
import io
1212
import unittest
1313

14+
from google.protobuf import message
1415
from google.protobuf.internal import decoder
1516
from google.protobuf.internal import testing_refleaks
17+
from google.protobuf.internal import wire_format
1618

1719

1820
_INPUT_BYTES = b'\x84r\x12'
@@ -52,6 +54,18 @@ def test_decode_varint_bytesio_empty(self):
5254
size = decoder._DecodeVarint(input_io)
5355
self.assertEqual(size, None)
5456

57+
def test_decode_unknown_group_field_too_many_levels(self):
  """Deeply nested unknown groups must fail with DecodeError.

  Feeds decoder._DecodeUnknownField a 5,000,000-byte buffer made of the single
  byte 0x13 repeated (presumably a start-group tag for field 2, so every byte
  opens one more nested group -- confirm against the wire format) and expects
  the decoder to raise message.DecodeError instead of recursing through the
  whole buffer.
  """
  nested_groups = memoryview(b'\023' * 5_000_000)
  with self.assertRaisesRegex(message.DecodeError, 'Error parsing message'):
    # Positional arguments: buffer, pos=1, wire_type, current_depth=1.
    decoder._DecodeUnknownField(
        nested_groups, 1, wire_format.WIRETYPE_START_GROUP, 1
    )
68+
5569

5670
if __name__ == '__main__':
5771
unittest.main()

0 commit comments

Comments
 (0)