@@ -184,7 +184,10 @@ def SpecificDecoder(field_number, is_repeated, is_packed, key, new_default,
184
184
clear_if_default = False ):
185
185
if is_packed :
186
186
local_DecodeVarint = _DecodeVarint
187
- def DecodePackedField (buffer , pos , end , message , field_dict ):
187
+ def DecodePackedField (
188
+ buffer , pos , end , message , field_dict , current_depth = 0
189
+ ):
190
+ del current_depth # unused
188
191
value = field_dict .get (key )
189
192
if value is None :
190
193
value = field_dict .setdefault (key , new_default (message ))
@@ -199,11 +202,15 @@ def DecodePackedField(buffer, pos, end, message, field_dict):
199
202
del value [- 1 ] # Discard corrupt value.
200
203
raise _DecodeError ('Packed element was truncated.' )
201
204
return pos
205
+
202
206
return DecodePackedField
203
207
elif is_repeated :
204
208
tag_bytes = encoder .TagBytes (field_number , wire_type )
205
209
tag_len = len (tag_bytes )
206
- def DecodeRepeatedField (buffer , pos , end , message , field_dict ):
210
+ def DecodeRepeatedField (
211
+ buffer , pos , end , message , field_dict , current_depth = 0
212
+ ):
213
+ del current_depth # unused
207
214
value = field_dict .get (key )
208
215
if value is None :
209
216
value = field_dict .setdefault (key , new_default (message ))
@@ -218,9 +225,12 @@ def DecodeRepeatedField(buffer, pos, end, message, field_dict):
218
225
if new_pos > end :
219
226
raise _DecodeError ('Truncated message.' )
220
227
return new_pos
228
+
221
229
return DecodeRepeatedField
222
230
else :
223
- def DecodeField (buffer , pos , end , message , field_dict ):
231
+
232
+ def DecodeField (buffer , pos , end , message , field_dict , current_depth = 0 ):
233
+ del current_depth # unused
224
234
(new_value , pos ) = decode_value (buffer , pos )
225
235
if pos > end :
226
236
raise _DecodeError ('Truncated message.' )
@@ -229,6 +239,7 @@ def DecodeField(buffer, pos, end, message, field_dict):
229
239
else :
230
240
field_dict [key ] = new_value
231
241
return pos
242
+
232
243
return DecodeField
233
244
234
245
return SpecificDecoder
@@ -364,7 +375,9 @@ def EnumDecoder(field_number, is_repeated, is_packed, key, new_default,
364
375
enum_type = key .enum_type
365
376
if is_packed :
366
377
local_DecodeVarint = _DecodeVarint
367
- def DecodePackedField (buffer , pos , end , message , field_dict ):
378
+ def DecodePackedField (
379
+ buffer , pos , end , message , field_dict , current_depth = 0
380
+ ):
368
381
"""Decode serialized packed enum to its value and a new position.
369
382
370
383
Args:
@@ -377,6 +390,7 @@ def DecodePackedField(buffer, pos, end, message, field_dict):
377
390
Returns:
378
391
int, new position in serialized data.
379
392
"""
393
+ del current_depth # unused
380
394
value = field_dict .get (key )
381
395
if value is None :
382
396
value = field_dict .setdefault (key , new_default (message ))
@@ -407,11 +421,14 @@ def DecodePackedField(buffer, pos, end, message, field_dict):
407
421
# pylint: enable=protected-access
408
422
raise _DecodeError ('Packed element was truncated.' )
409
423
return pos
424
+
410
425
return DecodePackedField
411
426
elif is_repeated :
412
427
tag_bytes = encoder .TagBytes (field_number , wire_format .WIRETYPE_VARINT )
413
428
tag_len = len (tag_bytes )
414
- def DecodeRepeatedField (buffer , pos , end , message , field_dict ):
429
+ def DecodeRepeatedField (
430
+ buffer , pos , end , message , field_dict , current_depth = 0
431
+ ):
415
432
"""Decode serialized repeated enum to its value and a new position.
416
433
417
434
Args:
@@ -424,6 +441,7 @@ def DecodeRepeatedField(buffer, pos, end, message, field_dict):
424
441
Returns:
425
442
int, new position in serialized data.
426
443
"""
444
+ del current_depth # unused
427
445
value = field_dict .get (key )
428
446
if value is None :
429
447
value = field_dict .setdefault (key , new_default (message ))
@@ -446,9 +464,11 @@ def DecodeRepeatedField(buffer, pos, end, message, field_dict):
446
464
if new_pos > end :
447
465
raise _DecodeError ('Truncated message.' )
448
466
return new_pos
467
+
449
468
return DecodeRepeatedField
450
469
else :
451
- def DecodeField (buffer , pos , end , message , field_dict ):
470
+
471
+ def DecodeField (buffer , pos , end , message , field_dict , current_depth = 0 ):
452
472
"""Decode serialized repeated enum to its value and a new position.
453
473
454
474
Args:
@@ -461,6 +481,7 @@ def DecodeField(buffer, pos, end, message, field_dict):
461
481
Returns:
462
482
int, new position in serialized data.
463
483
"""
484
+ del current_depth # unused
464
485
value_start_pos = pos
465
486
(enum_value , pos ) = _DecodeSignedVarint32 (buffer , pos )
466
487
if pos > end :
@@ -480,6 +501,7 @@ def DecodeField(buffer, pos, end, message, field_dict):
480
501
(tag_bytes , buffer [value_start_pos :pos ].tobytes ()))
481
502
# pylint: enable=protected-access
482
503
return pos
504
+
483
505
return DecodeField
484
506
485
507
@@ -538,7 +560,10 @@ def _ConvertToUnicode(memview):
538
560
tag_bytes = encoder .TagBytes (field_number ,
539
561
wire_format .WIRETYPE_LENGTH_DELIMITED )
540
562
tag_len = len (tag_bytes )
541
- def DecodeRepeatedField (buffer , pos , end , message , field_dict ):
563
+ def DecodeRepeatedField (
564
+ buffer , pos , end , message , field_dict , current_depth = 0
565
+ ):
566
+ del current_depth # unused
542
567
value = field_dict .get (key )
543
568
if value is None :
544
569
value = field_dict .setdefault (key , new_default (message ))
@@ -553,9 +578,12 @@ def DecodeRepeatedField(buffer, pos, end, message, field_dict):
553
578
if buffer [new_pos :pos ] != tag_bytes or new_pos == end :
554
579
# Prediction failed. Return.
555
580
return new_pos
581
+
556
582
return DecodeRepeatedField
557
583
else :
558
- def DecodeField (buffer , pos , end , message , field_dict ):
584
+
585
+ def DecodeField (buffer , pos , end , message , field_dict , current_depth = 0 ):
586
+ del current_depth # unused
559
587
(size , pos ) = local_DecodeVarint (buffer , pos )
560
588
new_pos = pos + size
561
589
if new_pos > end :
@@ -565,6 +593,7 @@ def DecodeField(buffer, pos, end, message, field_dict):
565
593
else :
566
594
field_dict [key ] = _ConvertToUnicode (buffer [pos :new_pos ])
567
595
return new_pos
596
+
568
597
return DecodeField
569
598
570
599
@@ -579,7 +608,10 @@ def BytesDecoder(field_number, is_repeated, is_packed, key, new_default,
579
608
tag_bytes = encoder .TagBytes (field_number ,
580
609
wire_format .WIRETYPE_LENGTH_DELIMITED )
581
610
tag_len = len (tag_bytes )
582
- def DecodeRepeatedField (buffer , pos , end , message , field_dict ):
611
+ def DecodeRepeatedField (
612
+ buffer , pos , end , message , field_dict , current_depth = 0
613
+ ):
614
+ del current_depth # unused
583
615
value = field_dict .get (key )
584
616
if value is None :
585
617
value = field_dict .setdefault (key , new_default (message ))
@@ -594,9 +626,12 @@ def DecodeRepeatedField(buffer, pos, end, message, field_dict):
594
626
if buffer [new_pos :pos ] != tag_bytes or new_pos == end :
595
627
# Prediction failed. Return.
596
628
return new_pos
629
+
597
630
return DecodeRepeatedField
598
631
else :
599
- def DecodeField (buffer , pos , end , message , field_dict ):
632
+
633
+ def DecodeField (buffer , pos , end , message , field_dict , current_depth = 0 ):
634
+ del current_depth # unused
600
635
(size , pos ) = local_DecodeVarint (buffer , pos )
601
636
new_pos = pos + size
602
637
if new_pos > end :
@@ -606,6 +641,7 @@ def DecodeField(buffer, pos, end, message, field_dict):
606
641
else :
607
642
field_dict [key ] = buffer [pos :new_pos ].tobytes ()
608
643
return new_pos
644
+
609
645
return DecodeField
610
646
611
647
@@ -621,7 +657,9 @@ def GroupDecoder(field_number, is_repeated, is_packed, key, new_default):
621
657
tag_bytes = encoder .TagBytes (field_number ,
622
658
wire_format .WIRETYPE_START_GROUP )
623
659
tag_len = len (tag_bytes )
624
- def DecodeRepeatedField (buffer , pos , end , message , field_dict ):
660
+ def DecodeRepeatedField (
661
+ buffer , pos , end , message , field_dict , current_depth = 0
662
+ ):
625
663
value = field_dict .get (key )
626
664
if value is None :
627
665
value = field_dict .setdefault (key , new_default (message ))
@@ -630,7 +668,13 @@ def DecodeRepeatedField(buffer, pos, end, message, field_dict):
630
668
if value is None :
631
669
value = field_dict .setdefault (key , new_default (message ))
632
670
# Read sub-message.
633
- pos = value .add ()._InternalParse (buffer , pos , end )
671
+ current_depth += 1
672
+ if current_depth > _recursion_limit :
673
+ raise _DecodeError (
674
+ 'Error parsing message: too many levels of nesting.'
675
+ )
676
+ pos = value .add ()._InternalParse (buffer , pos , end , current_depth )
677
+ current_depth -= 1
634
678
# Read end tag.
635
679
new_pos = pos + end_tag_len
636
680
if buffer [pos :new_pos ] != end_tag_bytes or new_pos > end :
@@ -640,19 +684,26 @@ def DecodeRepeatedField(buffer, pos, end, message, field_dict):
640
684
if buffer [new_pos :pos ] != tag_bytes or new_pos == end :
641
685
# Prediction failed. Return.
642
686
return new_pos
687
+
643
688
return DecodeRepeatedField
644
689
else :
645
- def DecodeField (buffer , pos , end , message , field_dict ):
690
+
691
+ def DecodeField (buffer , pos , end , message , field_dict , current_depth = 0 ):
646
692
value = field_dict .get (key )
647
693
if value is None :
648
694
value = field_dict .setdefault (key , new_default (message ))
649
695
# Read sub-message.
650
- pos = value ._InternalParse (buffer , pos , end )
696
+ current_depth += 1
697
+ if current_depth > _recursion_limit :
698
+ raise _DecodeError ('Error parsing message: too many levels of nesting.' )
699
+ pos = value ._InternalParse (buffer , pos , end , current_depth )
700
+ current_depth -= 1
651
701
# Read end tag.
652
702
new_pos = pos + end_tag_len
653
703
if buffer [pos :new_pos ] != end_tag_bytes or new_pos > end :
654
704
raise _DecodeError ('Missing group end tag.' )
655
705
return new_pos
706
+
656
707
return DecodeField
657
708
658
709
@@ -666,7 +717,9 @@ def MessageDecoder(field_number, is_repeated, is_packed, key, new_default):
666
717
tag_bytes = encoder .TagBytes (field_number ,
667
718
wire_format .WIRETYPE_LENGTH_DELIMITED )
668
719
tag_len = len (tag_bytes )
669
- def DecodeRepeatedField (buffer , pos , end , message , field_dict ):
720
+ def DecodeRepeatedField (
721
+ buffer , pos , end , message , field_dict , current_depth = 0
722
+ ):
670
723
value = field_dict .get (key )
671
724
if value is None :
672
725
value = field_dict .setdefault (key , new_default (message ))
@@ -677,18 +730,29 @@ def DecodeRepeatedField(buffer, pos, end, message, field_dict):
677
730
if new_pos > end :
678
731
raise _DecodeError ('Truncated message.' )
679
732
# Read sub-message.
680
- if value .add ()._InternalParse (buffer , pos , new_pos ) != new_pos :
733
+ current_depth += 1
734
+ if current_depth > _recursion_limit :
735
+ raise _DecodeError (
736
+ 'Error parsing message: too many levels of nesting.'
737
+ )
738
+ if (
739
+ value .add ()._InternalParse (buffer , pos , new_pos , current_depth )
740
+ != new_pos
741
+ ):
681
742
# The only reason _InternalParse would return early is if it
682
743
# encountered an end-group tag.
683
744
raise _DecodeError ('Unexpected end-group tag.' )
745
+ current_depth -= 1
684
746
# Predict that the next tag is another copy of the same repeated field.
685
747
pos = new_pos + tag_len
686
748
if buffer [new_pos :pos ] != tag_bytes or new_pos == end :
687
749
# Prediction failed. Return.
688
750
return new_pos
751
+
689
752
return DecodeRepeatedField
690
753
else :
691
- def DecodeField (buffer , pos , end , message , field_dict ):
754
+
755
+ def DecodeField (buffer , pos , end , message , field_dict , current_depth = 0 ):
692
756
value = field_dict .get (key )
693
757
if value is None :
694
758
value = field_dict .setdefault (key , new_default (message ))
@@ -698,11 +762,16 @@ def DecodeField(buffer, pos, end, message, field_dict):
698
762
if new_pos > end :
699
763
raise _DecodeError ('Truncated message.' )
700
764
# Read sub-message.
701
- if value ._InternalParse (buffer , pos , new_pos ) != new_pos :
765
+ current_depth += 1
766
+ if current_depth > _recursion_limit :
767
+ raise _DecodeError ('Error parsing message: too many levels of nesting.' )
768
+ if value ._InternalParse (buffer , pos , new_pos , current_depth ) != new_pos :
702
769
# The only reason _InternalParse would return early is if it encountered
703
770
# an end-group tag.
704
771
raise _DecodeError ('Unexpected end-group tag.' )
772
+ current_depth -= 1
705
773
return new_pos
774
+
706
775
return DecodeField
707
776
708
777
@@ -851,7 +920,8 @@ def MapDecoder(field_descriptor, new_default, is_message_map):
851
920
# Can't read _concrete_class yet; might not be initialized.
852
921
message_type = field_descriptor .message_type
853
922
854
- def DecodeMap (buffer , pos , end , message , field_dict ):
923
+ def DecodeMap (buffer , pos , end , message , field_dict , current_depth = 0 ):
924
+ del current_depth # Unused.
855
925
submsg = message_type ._concrete_class ()
856
926
value = field_dict .get (key )
857
927
if value is None :
@@ -934,7 +1004,16 @@ def _SkipGroup(buffer, pos, end):
934
1004
pos = new_pos
935
1005
936
1006
937
- def _DecodeUnknownFieldSet (buffer , pos , end_pos = None ):
1007
+ DEFAULT_RECURSION_LIMIT = 100
1008
+ _recursion_limit = DEFAULT_RECURSION_LIMIT
1009
+
1010
+
1011
+ def SetRecursionLimit (new_limit ):
1012
+ global _recursion_limit
1013
+ _recursion_limit = new_limit
1014
+
1015
+
1016
+ def _DecodeUnknownFieldSet (buffer , pos , end_pos = None , current_depth = 0 ):
938
1017
"""Decode UnknownFieldSet. Returns the UnknownFieldSet and new position."""
939
1018
940
1019
unknown_field_set = containers .UnknownFieldSet ()
@@ -944,14 +1023,16 @@ def _DecodeUnknownFieldSet(buffer, pos, end_pos=None):
944
1023
field_number , wire_type = wire_format .UnpackTag (tag )
945
1024
if wire_type == wire_format .WIRETYPE_END_GROUP :
946
1025
break
947
- (data , pos ) = _DecodeUnknownField (buffer , pos , wire_type )
1026
+ (data , pos ) = _DecodeUnknownField (buffer , pos , wire_type , current_depth )
948
1027
# pylint: disable=protected-access
949
1028
unknown_field_set ._add (field_number , wire_type , data )
950
1029
951
1030
return (unknown_field_set , pos )
952
1031
953
1032
954
- def _DecodeUnknownField (buffer , pos , wire_type ):
1033
+ def _DecodeUnknownField (
1034
+ buffer , pos , wire_type , current_depth = 0
1035
+ ):
955
1036
"""Decode a unknown field. Returns the UnknownField and new position."""
956
1037
957
1038
if wire_type == wire_format .WIRETYPE_VARINT :
@@ -965,7 +1046,11 @@ def _DecodeUnknownField(buffer, pos, wire_type):
965
1046
data = buffer [pos :pos + size ].tobytes ()
966
1047
pos += size
967
1048
elif wire_type == wire_format .WIRETYPE_START_GROUP :
968
- (data , pos ) = _DecodeUnknownFieldSet (buffer , pos )
1049
+ current_depth += 1
1050
+ if current_depth >= _recursion_limit :
1051
+ raise _DecodeError ('Error parsing message: too many levels of nesting.' )
1052
+ data , pos = _DecodeUnknownFieldSet (buffer , pos , None , current_depth )
1053
+ current_depth -= 1
969
1054
elif wire_type == wire_format .WIRETYPE_END_GROUP :
970
1055
return (0 , - 1 )
971
1056
else :
0 commit comments