Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make it possible to use interface types in STL containers relying on comparisons #552

Merged
merged 3 commits into from
Feb 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 29 additions & 4 deletions python/podio_gen/cpp_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,10 +72,14 @@ def __init__(
self.root_schema_component_names = set()
self.root_schema_datatype_names = set()
self.root_schema_iorules = set()
# a map of datatypes that are used in interfaces populated by pre_process
self.types_in_interfaces = {}

def pre_process(self):
"""The necessary specific pre-processing for cpp code generation"""
self._pre_process_schema_evolution()
self.types_in_interfaces = self._invert_interfaces()

return {}

def post_process(self, _):
Expand Down Expand Up @@ -120,6 +124,7 @@ def do_process_component(self, name, component):
def do_process_datatype(self, name, datatype):
"""Do the cpp specific processing of a datatype"""
datatype["includes_data"] = self._get_member_includes(datatype["Members"])
datatype["using_interface_types"] = self.types_in_interfaces.get(name, [])
self._preprocess_for_class(datatype)
self._preprocess_for_obj(datatype)
self._preprocess_for_collection(datatype)
Expand Down Expand Up @@ -201,7 +206,7 @@ def print_report(self):
def _preprocess_for_class(self, datatype):
"""Do the preprocessing that is necessary for the classes and Mutable classes"""
includes = set(datatype["includes_data"])
fwd_declarations = {}
fwd_declarations = defaultdict(list)
includes_cc = set()

for member in datatype["Members"]:
Expand All @@ -212,10 +217,8 @@ def _preprocess_for_class(self, datatype):
if self._is_interface(relation.full_type):
relation.interface_types = self.datamodel.interfaces[relation.full_type]["Types"]
if self._needs_include(relation.full_type):
if relation.namespace not in fwd_declarations:
fwd_declarations[relation.namespace] = []
fwd_declarations[relation.namespace].append(relation.bare_type)
fwd_declarations[relation.namespace].append("Mutable" + relation.bare_type)
fwd_declarations[relation.namespace].append(f"Mutable{relation.bare_type}")
includes_cc.add(self._build_include(relation))

if datatype["VectorMembers"] or datatype["OneToManyRelations"]:
Expand Down Expand Up @@ -246,6 +249,13 @@ def _preprocess_for_class(self, datatype):
except KeyError:
pass

# Make sure that all using interface types are properly forward declared
# to make it possible to declare them as friends so that they can access
# internals more easily
for interface in datatype["using_interface_types"]:
if_type = DataType(interface)
fwd_declarations[if_type.namespace].append(if_type.bare_type)

datatype["includes"] = self._sort_includes(includes)
datatype["includes_cc"] = self._sort_includes(includes_cc)
datatype["forward_declarations"] = fwd_declarations
Expand Down Expand Up @@ -381,6 +391,21 @@ def _pre_process_schema_evolution(self):
# add whatever is relevant to our ROOT schema evolution
self.root_schema_dict.setdefault(item.klassname, []).append(item)

def _invert_interfaces(self):
"""'Invert' the interfaces to have a mapping of types and their usage in
interfaces.

This is necessary to declare the interface types as friends of the
classes they wrap in order to more easily access some internals.
"""
types_in_interfaces = defaultdict(list)
for name, interface in self.datamodel.interfaces.items():
print(f"preprocessing interface {name}")
for if_type in interface["Types"]:
types_in_interfaces[if_type.full_type].append(name)

return types_in_interfaces

def _prepare_iorules(self):
"""Prepare the IORules to be put in the Reflex dictionary"""
for type_name, schema_changes in self.root_schema_dict.items():
Expand Down
11 changes: 10 additions & 1 deletion python/templates/Interface.h.jinja2
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ class {{ class.bare_type }} {
{{ macros.member_getters_concept(Members, use_get_syntax) }}
virtual const std::type_info& typeInfo() const = 0;
virtual bool equal(const Concept* rhs) const = 0;
virtual const void* objAddress() const = 0;
};

template<typename ValueT>
Expand All @@ -65,7 +66,7 @@ class {{ class.bare_type }} {

void unlink() final { m_value.unlink(); }
bool isAvailable() const final { return m_value.isAvailable(); }
podio::ObjectID getObjectID() const { return m_value.getObjectID(); }
podio::ObjectID getObjectID() const final { return m_value.getObjectID(); }

const std::type_info& typeInfo() const final { return typeid(ValueT); }

Expand All @@ -76,6 +77,10 @@ class {{ class.bare_type }} {
return false;
}

const void* objAddress() const final {
return m_value.m_obj.get();
}

{{ macros.member_getters_model(Members, use_get_syntax) }}

ValueT m_value{};
Expand Down Expand Up @@ -144,6 +149,10 @@ public:
return !(lhs == rhs);
}

friend bool operator<(const {{ class.bare_type }}& lhs, const {{ class.bare_type }}& rhs) {
return lhs.m_self->objAddress() < rhs.m_self->objAddress();
}

{{ macros.member_getters(Members, use_get_syntax) }}

friend std::ostream& operator<<(std::ostream& os, const {{ class.bare_type }}& value) {
Expand Down
3 changes: 3 additions & 0 deletions python/templates/Object.h.jinja2
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ class {{ class.bare_type }} {
friend class {{ class.bare_type }}Collection;
friend class {{ class.full_type }}CollectionData;
friend class {{ class.bare_type }}CollectionIterator;
{% for interface in using_interface_types %}
friend class {{ interface }};
{% endfor %}

public:
using mutable_type = Mutable{{ class.bare_type }};
Expand Down
21 changes: 21 additions & 0 deletions tests/unittests/interface_types.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
#include "datamodel/ExampleClusterCollection.h"
#include "datamodel/ExampleHitCollection.h"
#include "datamodel/TypeWithEnergy.h"

#include <map>
#include <stdexcept>

TEST_CASE("InterfaceTypes basic functionality", "[interface-types][basics]") {
Expand Down Expand Up @@ -45,6 +47,25 @@ TEST_CASE("InterfaceTypes basic functionality", "[interface-types][basics]") {
REQUIRE(wrapper1.id() == podio::ObjectID{0, 42});
}

TEST_CASE("InterfaceTypes STL usage", "[interface-types][basics]") {
// Make sure that interface types can be used with STL map and set
std::map<TypeWithEnergy, int> counterMap{};

auto empty = TypeWithEnergy::makeEmpty();
counterMap[empty]++;

ExampleHit hit{};
auto wrapper = TypeWithEnergy{hit};
counterMap[wrapper]++;

// No way this implicit conversion could ever lead to a subtle bug ;)
counterMap[hit]++;

REQUIRE(counterMap[empty] == 1);
REQUIRE(counterMap[hit] == 2);
REQUIRE(counterMap[wrapper] == 2);
}

TEST_CASE("InterfaceType from immutable", "[interface-types][basics]") {
using WrapperT = TypeWithEnergy;

Expand Down
Loading