Skip to content

Commit

Permalink
Use consistent source for branch offsets in reading and writing
Browse files Browse the repository at this point in the history
  • Loading branch information
tmadlener committed Jun 13, 2024
1 parent 8d27bd0 commit 9f09b64
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 44 deletions.
41 changes: 7 additions & 34 deletions src/ROOTReader.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,48 +26,21 @@ std::tuple<std::vector<root_utils::CollectionBranches>, std::vector<std::pair<st
createCollectionBranchesIndexBased(TChain* chain, const podio::CollectionIDTable& idTable,
const std::vector<root_utils::CollectionWriteInfoT>& collInfo);

/// Helper struct to get the negative offsets from the end of the branches
/// vector for the different types of generic parameters.
template <typename T>
struct TypeToBranchIndexOffset;

template <>
struct TypeToBranchIndexOffset<int> {
constexpr static int keys = 8;
constexpr static int values = 7;
};

template <>
struct TypeToBranchIndexOffset<float> {
constexpr static int keys = 6;
constexpr static int values = 5;
};

template <>
struct TypeToBranchIndexOffset<double> {
constexpr static int keys = 4;
constexpr static int values = 3;
};

template <>
struct TypeToBranchIndexOffset<std::string> {
constexpr static int keys = 2;
constexpr static int values = 1;
};

template <typename T>
void ROOTReader::readParams(ROOTReader::CategoryInfo& catInfo, podio::GenericParameters& params, bool reloadBranches,
unsigned int localEntry) {
const auto nBranches = catInfo.branches.size();
const auto collBranchIdx = catInfo.branches.size() - root_utils::nParamBranches - 1;
constexpr auto brOffset = root_utils::getGPBranchOffsets<T>();

if (reloadBranches) {
auto& keyBranch = catInfo.branches[nBranches - TypeToBranchIndexOffset<T>::keys].data;
auto& keyBranch = catInfo.branches[collBranchIdx + brOffset.keys].data;
keyBranch = root_utils::getBranch(catInfo.chain.get(), root_utils::getGPKeyName<T>());
auto& valueBranch = catInfo.branches[nBranches - TypeToBranchIndexOffset<T>::values].data;
auto& valueBranch = catInfo.branches[collBranchIdx + brOffset.values].data;
valueBranch = root_utils::getBranch(catInfo.chain.get(), root_utils::getGPValueName<T>());
}

auto keyBranch = catInfo.branches[nBranches - TypeToBranchIndexOffset<T>::keys].data;
auto valueBranch = catInfo.branches[nBranches - TypeToBranchIndexOffset<T>::values].data;
auto keyBranch = catInfo.branches[collBranchIdx + brOffset.keys].data;
auto valueBranch = catInfo.branches[collBranchIdx + brOffset.values].data;

root_utils::ParamStorage<T> storage;
keyBranch->SetAddress(storage.keysPtr());
Expand Down
28 changes: 18 additions & 10 deletions src/ROOTWriter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,7 @@ ROOTWriter::CategoryInfo& ROOTWriter::getCategoryInfo(const std::string& categor

void ROOTWriter::initBranches(CategoryInfo& catInfo, const std::vector<root_utils::StoreCollection>& collections,
/*const*/ podio::GenericParameters& parameters) {
catInfo.branches.reserve(collections.size() +
std::tuple_size_v<podio::SupportedGenericDataTypes> * 2); // collections + parameters
catInfo.branches.reserve(collections.size() + root_utils::nParamBranches); // collections + parameters

// First collections
for (auto& [name, coll] : collections) {
Expand Down Expand Up @@ -126,6 +125,8 @@ void ROOTWriter::initBranches(CategoryInfo& catInfo, const std::vector<root_util
}

fillParams(catInfo, parameters);
// NOTE: The order in which these are created is codified for later use in
// root_utils::getGPBranchOffsets
catInfo.branches.emplace_back(catInfo.tree->Branch(root_utils::intKeyName, &catInfo.intParams.keys));
catInfo.branches.emplace_back(catInfo.tree->Branch(root_utils::intValueName, &catInfo.intParams.values));

Expand All @@ -147,18 +148,25 @@ void ROOTWriter::resetBranches(CategoryInfo& categoryInfo,
root_utils::setCollectionAddresses(coll->getBuffers(), collBranches);
iColl++;
}
// Correct index to point to the last branch of collection data for symmetric
// handling of the offsets in reading and writing
iColl--;

categoryInfo.branches[iColl].data->SetAddress(categoryInfo.intParams.keysPtr());
categoryInfo.branches[iColl + 1].data->SetAddress(categoryInfo.intParams.valuesPtr());
constexpr auto intOffset = root_utils::getGPBranchOffsets<int>();
categoryInfo.branches[iColl + intOffset.keys].data->SetAddress(categoryInfo.intParams.keysPtr());
categoryInfo.branches[iColl + intOffset.values].data->SetAddress(categoryInfo.intParams.valuesPtr());

categoryInfo.branches[iColl + 2].data->SetAddress(categoryInfo.floatParams.keysPtr());
categoryInfo.branches[iColl + 3].data->SetAddress(categoryInfo.floatParams.valuesPtr());
constexpr auto floatOffset = root_utils::getGPBranchOffsets<float>();
categoryInfo.branches[iColl + floatOffset.keys].data->SetAddress(categoryInfo.floatParams.keysPtr());
categoryInfo.branches[iColl + floatOffset.values].data->SetAddress(categoryInfo.floatParams.valuesPtr());

categoryInfo.branches[iColl + 4].data->SetAddress(categoryInfo.doubleParams.keysPtr());
categoryInfo.branches[iColl + 5].data->SetAddress(categoryInfo.doubleParams.valuesPtr());
constexpr auto doubleOffset = root_utils::getGPBranchOffsets<double>();
categoryInfo.branches[iColl + doubleOffset.keys].data->SetAddress(categoryInfo.doubleParams.keysPtr());
categoryInfo.branches[iColl + doubleOffset.values].data->SetAddress(categoryInfo.doubleParams.valuesPtr());

categoryInfo.branches[iColl + 6].data->SetAddress(categoryInfo.stringParams.keysPtr());
categoryInfo.branches[iColl + 7].data->SetAddress(categoryInfo.stringParams.valuesPtr());
constexpr auto stringOffset = root_utils::getGPBranchOffsets<std::string>();
categoryInfo.branches[iColl + stringOffset.keys].data->SetAddress(categoryInfo.stringParams.keysPtr());
categoryInfo.branches[iColl + stringOffset.values].data->SetAddress(categoryInfo.stringParams.valuesPtr());
}

void ROOTWriter::finish() {
Expand Down
30 changes: 30 additions & 0 deletions src/rootUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#define PODIO_ROOT_UTILS_H // NOLINT(llvm-header-guard): internal headers confuse clang-tidy

#include "podio/CollectionIDTable.h"
#include "podio/GenericParameters.h"
#include "podio/utilities/RootHelpers.h"

#include "TBranch.h"
Expand Down Expand Up @@ -80,6 +81,35 @@ constexpr auto getGPValueName() {
}
}

/// Small helper struct to get info on the offsets of the branches holding
/// GenericParameter keys and values for a given parameter type
struct GPBranchOffsets {
int keys{-1};
int values{-1};
};

/// The number of branches that we create on top of the collection branches per
/// category
constexpr auto nParamBranches = std::tuple_size_v<podio::SupportedGenericDataTypes> * 2;

/// Get the branch offsets for a given parameter type. In this case it is
/// assumed that the integer branches start immediately after the branche for
/// the collections
template <typename T>
constexpr auto getGPBranchOffsets() {
if constexpr (std::is_same_v<T, int>) {
return GPBranchOffsets{1, 2};
} else if constexpr (std::is_same_v<T, float>) {
return GPBranchOffsets{3, 4};
} else if constexpr (std::is_same_v<T, double>) {
return GPBranchOffsets{5, 6};
} else if constexpr (std::is_same_v<T, std::string>) {
return GPBranchOffsets{7, 8};
} else {
static_assert(sizeof(T) == 0, "Unsupported type for generic parameters");
}
}

/**
* Name of the field with the list of categories for RNTuples
*/
Expand Down

0 comments on commit 9f09b64

Please sign in to comment.