Skip to content

Commit 680625b

Browse files
committed
Fix dynamic embedded document updates
In order to allow setting new fields that did not previouly exist on dynamic embedded documents, the `lookup_member` functions on `EmbeddedDocumentField` and `GenericEmbeddedDocumentField` return dynamic fields as appropriate. This enables operations like `A.objects(...).update(foo__newfield="bar")`. Resolves MongoEngine#2486
1 parent e51ee40 commit 680625b

File tree

3 files changed

+56
-0
lines changed

3 files changed

+56
-0
lines changed

docs/changelog.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ Development
2424
- BugFix - Calling .clear on a ListField wasn't being marked as changed (and flushed to db upon .save()) #2858
2525
- Improve error message in case a document assigned to a ReferenceField wasn't saved yet #1955
2626
- BugFix - Take `where()` into account when using `.modify()`, as in MyDocument.objects().where("this[field] >= this[otherfield]").modify(field='new') #2044
27+
- BugFix - Unable to add new fields during `QuerySet.update` on `DynamicEmbeddedDocument` fields #2486
2728

2829
Changes in 0.29.0
2930
=================

mongoengine/fields.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -777,6 +777,12 @@ def lookup_member(self, member_name):
777777
if field:
778778
return field
779779

780+
# DynamicEmbeddedDocuments should always return a field except for positional operators
781+
if any(
782+
doc_type._dynamic for doc_type in doc_and_subclasses
783+
) and member_name not in ("$", "S"):
784+
return DynamicField(db_field=member_name)
785+
780786
def prepare_query_value(self, op, value):
781787
if value is not None and not isinstance(value, self.document_type):
782788
# Short circuit for special operators, returning them as is
@@ -837,6 +843,12 @@ def lookup_member(self, member_name):
837843
if field:
838844
return field
839845

846+
# DynamicEmbeddedDocuments should always return a field except for positional operators
847+
if any(
848+
document_choice._dynamic for document_choice in document_choices
849+
) and member_name not in ("$", "S"):
850+
return DynamicField(db_field=member_name)
851+
840852
def to_mongo(self, document, use_db_field=True, fields=None):
841853
if document is None:
842854
return None

tests/fields/test_embedded_document_field.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
from mongoengine import (
88
Document,
9+
DynamicEmbeddedDocument,
910
EmbeddedDocument,
1011
EmbeddedDocumentField,
1112
EmbeddedDocumentListField,
@@ -224,6 +225,34 @@ class Record(Document):
224225

225226
assert Record.objects(posts__title="foo").count() == 2
226227

228+
def test_update_dynamic_embedded_document_with_new_fields(self):
229+
class Wheel(DynamicEmbeddedDocument):
230+
position = StringField()
231+
232+
class Car(Document):
233+
wheels = EmbeddedDocumentListField(Wheel)
234+
235+
car = Car(
236+
wheels=[
237+
Wheel(position="front-passenger"),
238+
Wheel(position="rear-passenger"),
239+
Wheel(position="front-driver"),
240+
Wheel(position="rear-driver"),
241+
]
242+
).save()
243+
244+
Car.objects(wheels__position="front-driver").update(
245+
set__wheels__S__damaged=True
246+
)
247+
car.reload()
248+
249+
for wheel in car.wheels:
250+
if wheel.position == "front-driver":
251+
assert wheel.damaged
252+
else:
253+
with pytest.raises(AttributeError):
254+
wheel.damaged
255+
227256

228257
class TestGenericEmbeddedDocumentField(MongoDBTestCase):
229258
def test_generic_embedded_document(self):
@@ -455,3 +484,17 @@ class Person(Document):
455484

456485
copied_map_emb_doc = deepcopy(doc.wallet_map)
457486
assert copied_map_emb_doc["test"]._instance is None
487+
488+
def test_update_dynamic_embedded_document_with_new_fields(self):
489+
class Laptop(DynamicEmbeddedDocument):
490+
operating_system = StringField()
491+
492+
class Backpack(Document):
493+
content = GenericEmbeddedDocumentField(choices=[Laptop])
494+
495+
backpack = Backpack(content=Laptop(operating_system="Windows")).save()
496+
497+
Backpack.objects.update(set__content__manufacturer="Acer")
498+
backpack.reload()
499+
500+
assert backpack.content.manufacturer == "Acer"

0 commit comments

Comments
 (0)