diff --git a/include/podio/ROOTLegacyReader.h b/include/podio/ROOTLegacyReader.h index 9383ca51b..9693bfdb0 100644 --- a/include/podio/ROOTLegacyReader.h +++ b/include/podio/ROOTLegacyReader.h @@ -75,20 +75,23 @@ class ROOTLegacyReader { /// Read the next data entry from which a Frame can be constructed. /// /// @note the category name has to be "events" in this case, as only that - /// category is available for legacy files. + /// category is available for legacy files. Also the collections to read + /// argument will be ignored. /// /// @returns FrameData from which a podio::Frame can be constructed if there /// are still entries left to read. Otherwise a nullptr - std::unique_ptr readNextEntry(const std::string&); + std::unique_ptr readNextEntry(const std::string&, const std::vector& = {}); /// Read the desired data entry from which a Frame can be constructed. /// /// @note the category name has to be "events" in this case, as only that - /// category is available for legacy files. + /// category is available for legacy files. Also the collections to read + /// argument will be ignored. /// /// @returns FrameData from which a podio::Frame can be constructed if the /// desired entry exists. Otherwise a nullptr - std::unique_ptr readEntry(const std::string&, const unsigned entry); + std::unique_ptr readEntry(const std::string&, const unsigned entry, + const std::vector& = {}); /// Get the number of entries for the given name /// diff --git a/include/podio/SIOLegacyReader.h b/include/podio/SIOLegacyReader.h index d490fcd5e..8ea86627e 100644 --- a/include/podio/SIOLegacyReader.h +++ b/include/podio/SIOLegacyReader.h @@ -40,20 +40,23 @@ class SIOLegacyReader { /// there are no more entries left, this returns a nullptr. /// /// @note the category name has to be "events" in this case, as only that - /// category is available for legacy files. + /// category is available for legacy files. Also the collections to read + /// argument will be ignored. 
/// /// @returns FrameData from which a podio::Frame can be constructed if there /// are still entries left to read. Otherwise a nullptr - std::unique_ptr readNextEntry(const std::string&); + std::unique_ptr readNextEntry(const std::string&, const std::vector& = {}); /// Read the desired data entry from which a Frame can be constructed. /// /// @note the category name has to be "events" in this case, as only that - /// category is available for legacy files. + /// category is available for legacy files. Also the collections to read + /// argument will be ignored. /// /// @returns FrameData from which a podio::Frame can be constructed if the /// desired entry exists. Otherwise a nullptr - std::unique_ptr readEntry(const std::string&, const unsigned entry); + std::unique_ptr readEntry(const std::string&, const unsigned entry, + const std::vector& = {}); /// Get the number of entries for the given name /// diff --git a/python/podio/base_reader.py b/python/podio/base_reader.py index e078aa4ad..d834c1922 100644 --- a/python/podio/base_reader.py +++ b/python/podio/base_reader.py @@ -35,17 +35,21 @@ def categories(self): """ return self._categories - def get(self, category): + def get(self, category, coll_names=None): """Get an iterator with access functionality for a given category. 
Args: category (str): The name of the desired category + coll_names (list[str]): The list of collections to read (optional, + all available collections will be read by default) Returns: FrameCategoryIterator: The iterator granting access to all Frames of the desired category """ - return FrameCategoryIterator(self._reader, category) + if self.is_legacy and coll_names: + raise ValueError("Legacy readers do not support selective reading") + return FrameCategoryIterator(self._reader, category, coll_names) @property def is_legacy(self): diff --git a/python/podio/frame_iterator.py b/python/podio/frame_iterator.py index 1fa97d828..2ba11960b 100644 --- a/python/podio/frame_iterator.py +++ b/python/podio/frame_iterator.py @@ -11,15 +11,18 @@ class FrameCategoryIterator: reader as well as accessing specific entries """ - def __init__(self, reader, category): + def __init__(self, reader, category, coll_names=None): """Construct the iterator from the reader and the category. Args: reader (Reader): Any podio reader offering access to Frames category (str): The category name of the Frames to be iterated over + coll_names (list[str]): The list of collections to read (optional, + all available collections will be read by default) """ self._reader = reader self._category = category + self._coll_names = coll_names or [] def __iter__(self): """The trivial implementation for the iterator protocol.""" @@ -27,9 +30,12 @@ def __iter__(self): def __next__(self): """Get the next available Frame or stop.""" - frame_data = self._reader.readNextEntry(self._category) - if frame_data: - return Frame(std.move(frame_data)) + try: + frame_data = self._reader.readNextEntry(self._category, self._coll_names) + if frame_data: + return Frame(std.move(frame_data)) + except std.invalid_argument as e: + raise ValueError(e.what()) from e raise StopIteration @@ -52,7 +58,7 @@ def __getitem__(self, entry): raise IndexError try: - frame_data = self._reader.readEntry(self._category, entry) + frame_data = 
self._reader.readEntry(self._category, entry, self._coll_names) except std.bad_function_call: print( "Error: Unable to read an entry of the input file. This can " @@ -62,6 +68,8 @@ def __getitem__(self, entry): "library folder with your data model\n" ) raise + except std.invalid_argument as e: + raise ValueError(e.what()) from e if frame_data: return Frame(std.move(frame_data)) diff --git a/python/podio/test_Reader.py b/python/podio/test_Reader.py index d6fc90640..24fb39a06 100644 --- a/python/podio/test_Reader.py +++ b/python/podio/test_Reader.py @@ -89,6 +89,24 @@ def test_invalid_datamodel_version(self): with self.assertRaises(KeyError): self.reader.current_file_version("non-existant-model") + def test_limited_collections(self): + """Make sure only reading a subset of collections works""" + # We only do bare checks here as more extensive tests are already done + # on the c++ side + event = self.reader.get("events", ["hits", "info", "links"])[0] + self.assertEqual(set(event.getAvailableCollections()), {"hits", "info", "links"}) + + def test_invalid_limited_collections(self): + """Ensure that requesting non-existent collections raises a ValueError""" + with self.assertRaises(ValueError): + events = self.reader.get("events", ["non-existent-collection"]) + _ = events[0] + + with self.assertRaises(ValueError): + events = self.reader.get("events", ["non-existent-collection"]) + for _ in events: + pass + class LegacyReaderTestCaseMixin: """Common test cases for the legacy readers python bindings. 
diff --git a/src/ROOTLegacyReader.cc b/src/ROOTLegacyReader.cc index 87f7a4e1b..d94303aa8 100644 --- a/src/ROOTLegacyReader.cc +++ b/src/ROOTLegacyReader.cc @@ -17,14 +17,16 @@ namespace podio { -std::unique_ptr ROOTLegacyReader::readNextEntry(const std::string& name) { +std::unique_ptr ROOTLegacyReader::readNextEntry(const std::string& name, + const std::vector&) { if (name != m_categoryName) { return nullptr; } return readEntry(); } -std::unique_ptr ROOTLegacyReader::readEntry(const std::string& name, unsigned entry) { +std::unique_ptr ROOTLegacyReader::readEntry(const std::string& name, unsigned entry, + const std::vector&) { if (name != m_categoryName) { return nullptr; } diff --git a/src/SIOLegacyReader.cc b/src/SIOLegacyReader.cc index dcf3d9d18..dae37b964 100644 --- a/src/SIOLegacyReader.cc +++ b/src/SIOLegacyReader.cc @@ -23,7 +23,7 @@ void SIOLegacyReader::openFile(const std::string& filename) { readCollectionIDTable(); } -std::unique_ptr SIOLegacyReader::readNextEntry(const std::string& name) { +std::unique_ptr SIOLegacyReader::readNextEntry(const std::string& name, const std::vector&) { if (name != m_categoryName) { return nullptr; } @@ -47,7 +47,8 @@ std::unique_ptr SIOLegacyReader::readNextEntry(const std::string& m_tableUncLength); } -std::unique_ptr SIOLegacyReader::readEntry(const std::string& name, const unsigned entry) { +std::unique_ptr SIOLegacyReader::readEntry(const std::string& name, const unsigned entry, + const std::vector&) { if (name != m_categoryName) { return nullptr; }