From 3d99cd389ea1001238105eab3d54c93f6fa57681 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Thu, 13 Mar 2025 13:24:29 -0700 Subject: [PATCH] Add Back-Edge support in Graph builder. Inputs of Graph builder nodes from now on have an additional method to mark them as back edges: ``` node.In("INPUT").AsBackEdge() ``` Similarly to `SetName()` it returns the reference to the input and can be used as: ``` node.In("INPUT").AsBackEdge().SetName("in") ``` PiperOrigin-RevId: 736618624 --- mediapipe/framework/api2/builder.h | 78 +++++++++++- mediapipe/framework/api2/builder_test.cc | 145 +++++++++++++++++++++++ 2 files changed, 218 insertions(+), 5 deletions(-) diff --git a/mediapipe/framework/api2/builder.h b/mediapipe/framework/api2/builder.h index ce803c175b..f1299892ee 100644 --- a/mediapipe/framework/api2/builder.h +++ b/mediapipe/framework/api2/builder.h @@ -98,6 +98,7 @@ class PacketGenerator; struct SourceBase; struct DestinationBase { SourceBase* source = nullptr; + bool back_edge = false; }; struct SourceBase { std::vector dests_; @@ -144,11 +145,15 @@ using AllowCast = std::integral_constant || template class SourceImpl; +template +class DestinationImpl; + // These classes wrap references to the underlying source/destination // endpoints, adding type information and the user-visible API. -template -class DestinationImpl { +template +class DestinationImpl { public: + static constexpr bool kIsSide = false; using Base = DestinationBase; explicit DestinationImpl(std::vector>* vec) @@ -157,8 +162,54 @@ class DestinationImpl { template {}, int> = 0> - DestinationImpl Cast() { - return DestinationImpl(&base_); + DestinationImpl Cast() { + return DestinationImpl(&base_); + } + + // Whether the input stream is a back edge. + // + // By default, MediaPipe requires graphs to be acyclic and treats cycles in a + // graph as errors. To allow MediaPipe to accept a cyclic graph, use/make + // corresponding inputs as back edges. A cyclic graph usually has an obvious + // forward direction, and a back edge goes in the opposite direction. For a + // formal definition of a back edge, please see + // https://en.wikipedia.org/wiki/Depth-first_search. + // + // Equivalent of having "input_stream_info" for an input stream in the config: + // node { + // ... + // input_stream: "TAG:0:stream" + // input_stream_info { + // tag: "TAG:0" + // back_edge: true + // } + // } + DestinationImpl& AsBackEdge() { + base_.back_edge = true; + return *this; + } + + private: + DestinationBase& base_; + + template + friend class SourceImpl; +}; + +template +class DestinationImpl { + public: + static constexpr bool kIsSide = true; + using Base = DestinationBase; + + explicit DestinationImpl(std::vector>* vec) + : DestinationImpl(&GetWithAutoGrow(vec, 0)) {} + explicit DestinationImpl(DestinationBase* base) : base_(*base) {} + + template {}, int> = 0> + DestinationImpl Cast() { + return DestinationImpl(&base_); } private: @@ -916,7 +967,15 @@ class Graph { return absl::OkStatus(); } - std::string TaggedName(const TagIndexLocation& loc, absl::string_view name) { + static std::string TagIndex(const TagIndexLocation& loc) { + if (loc.count <= 1) { + return loc.tag; + } + return absl::StrCat(loc.tag, ":", loc.index); + } + + static std::string TaggedName(const TagIndexLocation& loc, + absl::string_view name) { if (loc.tag.empty()) { // ParseTagIndexName does not allow using explicit indices without tags, // while ParseTagIndex does. @@ -942,6 +1001,11 @@ class Graph { << (loc.tag.empty() ? "(empty)" : loc.tag) << " at index " << loc.index; config->add_input_stream(TaggedName(loc, endpoint.source->name_)); + if (endpoint.back_edge) { + auto* info = config->add_input_stream_info(); + info->set_back_edge(true); + info->set_tag_index(TagIndex(loc)); + } return absl::OkStatus(); })); MP_RETURN_IF_ERROR(node.out_streams_.Visit( @@ -1034,7 +1098,11 @@ class Graph { << type_ << ": Missing source for graph output stream with tag " << (loc.tag.empty() ? "(empty)" : loc.tag) << " at index " << loc.index; + RET_CHECK(!endpoint.back_edge) + << "Graph output: " << (loc.tag.empty() ? "(empty)" : loc.tag) + << " at index " << loc.index << " cannot be a back edge"; config->add_output_stream(TaggedName(loc, endpoint.source->name_)); + return absl::OkStatus(); })); MP_RETURN_IF_ERROR(graph_boundary_.out_streams_.Visit( diff --git a/mediapipe/framework/api2/builder_test.cc b/mediapipe/framework/api2/builder_test.cc index 653ec85723..6afd4cb98d 100644 --- a/mediapipe/framework/api2/builder_test.cc +++ b/mediapipe/framework/api2/builder_test.cc @@ -1,5 +1,6 @@ #include "mediapipe/framework/api2/builder.h" +#include #include #include "absl/strings/string_view.h" @@ -256,6 +257,150 @@ TEST(BuilderTest, BuildGraphSettingSourceLayer) { EXPECT_THAT(graph.GetConfig(), EqualsProto(expected)); } +TEST(BuilderTest, CanUseBackEdges) { + Graph graph; + // Graph inputs. + Stream image = graph.In("IMAGE").SetName("image"); + + auto [prev_detections, set_prev_detections_fn] = [&]() { + auto* loopback_node = &graph.AddNode("PreviousLoopbackCalculator"); + image >> loopback_node->In("MAIN"); + auto set_loop_fn = [loopback_node](Stream loop) { + loop >> loopback_node->In("LOOP").AsBackEdge(); + }; + Stream prev_loop = loopback_node->Out("PREV_LOOP"); + return std::pair(prev_loop, set_loop_fn); + }(); + + Stream detections = [&]() { + auto& detection_node = graph.AddNode("ObjectDetectionCalculator"); + image >> detection_node.In("IMAGE"); + prev_detections >> detection_node.In("PREV_DETECTIONS"); + return detection_node.Out("DETECTIONS"); + }(); + + set_prev_detections_fn(detections); + + // Graph outputs. + detections.SetName("detections") >> graph.Out("OUT"); + + CalculatorGraphConfig expected = + mediapipe::ParseTextProtoOrDie(R"pb( + node { + calculator: "PreviousLoopbackCalculator" + input_stream: "LOOP:detections" + input_stream: "MAIN:image" + output_stream: "PREV_LOOP:__stream_0" + input_stream_info { tag_index: "LOOP" back_edge: true } + } + node { + calculator: "ObjectDetectionCalculator" + input_stream: "IMAGE:image" + input_stream: "PREV_DETECTIONS:__stream_0" + output_stream: "DETECTIONS:detections" + } + input_stream: "IMAGE:image" + output_stream: "OUT:detections" + )pb"); + EXPECT_THAT(graph.GetConfig(), EqualsProto(expected)); +} + +TEST(BuilderTest, CanUseBackEdgesWithIndex) { + Graph graph; + // Graph inputs. + Stream image = graph.In("IN").SetName("in_data"); + + auto [processed_data, set_back_edge_fn] = [&]() { + auto* back_edge_node = &graph.AddNode("SomeBackEdgeCalculator"); + image >> back_edge_node->In("DATA")[0]; + auto set_back_edge_fn = [back_edge_node](Stream loop) { + loop >> back_edge_node->In("DATA")[1].AsBackEdge(); + }; + Stream processed_data = back_edge_node->Out("PROCESSED_DATA"); + return std::pair(processed_data, set_back_edge_fn); + }(); + + Stream output_data = [&]() { + auto& detection_node = graph.AddNode("SomeOutputDataCalculator"); + image >> detection_node.In("IMAGE"); + processed_data >> detection_node.In("PROCESSED_DATA"); + return detection_node.Out("OUTPUT_DATA"); + }(); + + set_back_edge_fn(output_data); + + // Graph outputs. + output_data.SetName("out_data") >> graph.Out("OUT"); + + CalculatorGraphConfig expected = + mediapipe::ParseTextProtoOrDie(R"pb( + node { + calculator: "SomeBackEdgeCalculator" + input_stream: "DATA:0:in_data" + input_stream: "DATA:1:out_data" + output_stream: "PROCESSED_DATA:__stream_0" + input_stream_info { tag_index: "DATA:1" back_edge: true } + } + node { + calculator: "SomeOutputDataCalculator" + input_stream: "IMAGE:in_data" + input_stream: "PROCESSED_DATA:__stream_0" + output_stream: "OUTPUT_DATA:out_data" + } + input_stream: "IN:in_data" + output_stream: "OUT:out_data" + )pb"); + EXPECT_THAT(graph.GetConfig(), EqualsProto(expected)); +} + +TEST(BuilderTest, CanUseBackEdgesWithIndexAndNoTag) { + Graph graph; + // Graph inputs. + Stream image = graph.In("IN").SetName("in_data"); + + auto [processed_data, set_back_edge_fn] = [&]() { + auto* back_edge_node = &graph.AddNode("SomeBackEdgeCalculator"); + image >> back_edge_node->In(0); + auto set_back_edge_fn = [back_edge_node](Stream loop) { + loop >> back_edge_node->In(1).AsBackEdge(); + }; + Stream processed_data = back_edge_node->Out("PROCESSED_DATA"); + return std::pair(processed_data, set_back_edge_fn); + }(); + + Stream output_data = [&]() { + auto& detection_node = graph.AddNode("SomeOutputDataCalculator"); + image >> detection_node.In("IMAGE"); + processed_data >> detection_node.In("PROCESSED_DATA"); + return detection_node.Out("OUTPUT_DATA"); + }(); + + set_back_edge_fn(output_data); + + // Graph outputs. + output_data.SetName("out_data") >> graph.Out("OUT"); + + CalculatorGraphConfig expected = + mediapipe::ParseTextProtoOrDie(R"pb( + node { + calculator: "SomeBackEdgeCalculator" + input_stream: "in_data" + input_stream: "out_data" + output_stream: "PROCESSED_DATA:__stream_0" + input_stream_info { tag_index: ":1" back_edge: true } + } + node { + calculator: "SomeOutputDataCalculator" + input_stream: "IMAGE:in_data" + input_stream: "PROCESSED_DATA:__stream_0" + output_stream: "OUTPUT_DATA:out_data" + } + input_stream: "IN:in_data" + output_stream: "OUT:out_data" + )pb"); + EXPECT_THAT(graph.GetConfig(), EqualsProto(expected)); +} + TEST(BuilderTest, CopyableStream) { Graph graph; Stream a = graph.In("A").SetName("a").Cast();