Skip to content

Commit 3d99cd3

Browse files
MediaPipe Teamcopybara-github
MediaPipe Team
authored andcommitted
Add Back-Edge support in Graph builder.
Inputs of Graph builder nodes from now on have an additional method to mark them as back edges: ``` node.In("INPUT").AsBackEdge() ``` Similarly to `SetName()` it returns the reference to the input and can be used as: ``` node.In("INPUT").AsBackEdge().SetName("in") ``` PiperOrigin-RevId: 736618624
1 parent 70a9f59 commit 3d99cd3

File tree

2 files changed

+218
-5
lines changed

2 files changed

+218
-5
lines changed

mediapipe/framework/api2/builder.h

+73-5
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,7 @@ class PacketGenerator;
9898
struct SourceBase;
9999
struct DestinationBase {
100100
SourceBase* source = nullptr;
101+
bool back_edge = false;
101102
};
102103
struct SourceBase {
103104
std::vector<DestinationBase*> dests_;
@@ -144,11 +145,15 @@ using AllowCast = std::integral_constant<bool, (std::is_same_v<T, AnyType> ||
144145
template <bool IsSide, typename T = internal::Generic>
145146
class SourceImpl;
146147

148+
template <bool IsSide, typename T = internal::Generic>
149+
class DestinationImpl;
150+
147151
// These classes wrap references to the underlying source/destination
148152
// endpoints, adding type information and the user-visible API.
149-
template <bool IsSide, typename T = internal::Generic>
150-
class DestinationImpl {
153+
template <typename T>
154+
class DestinationImpl</*IsSide=*/false, T> {
151155
public:
156+
static constexpr bool kIsSide = false;
152157
using Base = DestinationBase;
153158

154159
explicit DestinationImpl(std::vector<std::unique_ptr<Base>>* vec)
@@ -157,8 +162,54 @@ class DestinationImpl {
157162

158163
template <typename U,
159164
std::enable_if_t<internal_builder::AllowCast<T, U>{}, int> = 0>
160-
DestinationImpl<IsSide, U> Cast() {
161-
return DestinationImpl<IsSide, U>(&base_);
165+
DestinationImpl<kIsSide, U> Cast() {
166+
return DestinationImpl<kIsSide, U>(&base_);
167+
}
168+
169+
// Whether the input stream is a back edge.
170+
//
171+
// By default, MediaPipe requires graphs to be acyclic and treats cycles in a
172+
// graph as errors. To allow MediaPipe to accept a cyclic graph, use/make
173+
// corresponding inputs as back edges. A cyclic graph usually has an obvious
174+
// forward direction, and a back edge goes in the opposite direction. For a
175+
// formal definition of a back edge, please see
176+
// https://en.wikipedia.org/wiki/Depth-first_search.
177+
//
178+
// Equivalent of having "input_stream_info" for an input stream in the config:
179+
// node {
180+
// ...
181+
// input_stream: "TAG:0:stream"
182+
// input_stream_info {
183+
// tag: "TAG:0"
184+
// back_edge: true
185+
// }
186+
// }
187+
DestinationImpl<kIsSide, T>& AsBackEdge() {
188+
base_.back_edge = true;
189+
return *this;
190+
}
191+
192+
private:
193+
DestinationBase& base_;
194+
195+
template <bool Source_IsSide, typename Source_T>
196+
friend class SourceImpl;
197+
};
198+
199+
template <typename T>
200+
class DestinationImpl</*IsSide=*/true, T> {
201+
public:
202+
static constexpr bool kIsSide = true;
203+
using Base = DestinationBase;
204+
205+
explicit DestinationImpl(std::vector<std::unique_ptr<Base>>* vec)
206+
: DestinationImpl(&GetWithAutoGrow(vec, 0)) {}
207+
explicit DestinationImpl(DestinationBase* base) : base_(*base) {}
208+
209+
template <typename U,
210+
std::enable_if_t<internal_builder::AllowCast<T, U>{}, int> = 0>
211+
DestinationImpl<kIsSide, U> Cast() {
212+
return DestinationImpl<kIsSide, U>(&base_);
162213
}
163214

164215
private:
@@ -916,7 +967,15 @@ class Graph {
916967
return absl::OkStatus();
917968
}
918969

919-
std::string TaggedName(const TagIndexLocation& loc, absl::string_view name) {
970+
static std::string TagIndex(const TagIndexLocation& loc) {
971+
if (loc.count <= 1) {
972+
return loc.tag;
973+
}
974+
return absl::StrCat(loc.tag, ":", loc.index);
975+
}
976+
977+
static std::string TaggedName(const TagIndexLocation& loc,
978+
absl::string_view name) {
920979
if (loc.tag.empty()) {
921980
// ParseTagIndexName does not allow using explicit indices without tags,
922981
// while ParseTagIndex does.
@@ -942,6 +1001,11 @@ class Graph {
9421001
<< (loc.tag.empty() ? "(empty)" : loc.tag) << " at index "
9431002
<< loc.index;
9441003
config->add_input_stream(TaggedName(loc, endpoint.source->name_));
1004+
if (endpoint.back_edge) {
1005+
auto* info = config->add_input_stream_info();
1006+
info->set_back_edge(true);
1007+
info->set_tag_index(TagIndex(loc));
1008+
}
9451009
return absl::OkStatus();
9461010
}));
9471011
MP_RETURN_IF_ERROR(node.out_streams_.Visit(
@@ -1034,7 +1098,11 @@ class Graph {
10341098
<< type_ << ": Missing source for graph output stream with tag "
10351099
<< (loc.tag.empty() ? "(empty)" : loc.tag) << " at index "
10361100
<< loc.index;
1101+
RET_CHECK(!endpoint.back_edge)
1102+
<< "Graph output: " << (loc.tag.empty() ? "(empty)" : loc.tag)
1103+
<< " at index " << loc.index << " cannot be a back edge";
10371104
config->add_output_stream(TaggedName(loc, endpoint.source->name_));
1105+
10381106
return absl::OkStatus();
10391107
}));
10401108
MP_RETURN_IF_ERROR(graph_boundary_.out_streams_.Visit(

mediapipe/framework/api2/builder_test.cc

+145
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
#include "mediapipe/framework/api2/builder.h"
22

3+
#include <utility>
34
#include <vector>
45

56
#include "absl/strings/string_view.h"
@@ -256,6 +257,150 @@ TEST(BuilderTest, BuildGraphSettingSourceLayer) {
256257
EXPECT_THAT(graph.GetConfig(), EqualsProto(expected));
257258
}
258259

260+
TEST(BuilderTest, CanUseBackEdges) {
261+
Graph graph;
262+
// Graph inputs.
263+
Stream<AnyType> image = graph.In("IMAGE").SetName("image");
264+
265+
auto [prev_detections, set_prev_detections_fn] = [&]() {
266+
auto* loopback_node = &graph.AddNode("PreviousLoopbackCalculator");
267+
image >> loopback_node->In("MAIN");
268+
auto set_loop_fn = [loopback_node](Stream<AnyType> loop) {
269+
loop >> loopback_node->In("LOOP").AsBackEdge();
270+
};
271+
Stream<AnyType> prev_loop = loopback_node->Out("PREV_LOOP");
272+
return std::pair(prev_loop, set_loop_fn);
273+
}();
274+
275+
Stream<AnyType> detections = [&]() {
276+
auto& detection_node = graph.AddNode("ObjectDetectionCalculator");
277+
image >> detection_node.In("IMAGE");
278+
prev_detections >> detection_node.In("PREV_DETECTIONS");
279+
return detection_node.Out("DETECTIONS");
280+
}();
281+
282+
set_prev_detections_fn(detections);
283+
284+
// Graph outputs.
285+
detections.SetName("detections") >> graph.Out("OUT");
286+
287+
CalculatorGraphConfig expected =
288+
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
289+
node {
290+
calculator: "PreviousLoopbackCalculator"
291+
input_stream: "LOOP:detections"
292+
input_stream: "MAIN:image"
293+
output_stream: "PREV_LOOP:__stream_0"
294+
input_stream_info { tag_index: "LOOP" back_edge: true }
295+
}
296+
node {
297+
calculator: "ObjectDetectionCalculator"
298+
input_stream: "IMAGE:image"
299+
input_stream: "PREV_DETECTIONS:__stream_0"
300+
output_stream: "DETECTIONS:detections"
301+
}
302+
input_stream: "IMAGE:image"
303+
output_stream: "OUT:detections"
304+
)pb");
305+
EXPECT_THAT(graph.GetConfig(), EqualsProto(expected));
306+
}
307+
308+
TEST(BuilderTest, CanUseBackEdgesWithIndex) {
309+
Graph graph;
310+
// Graph inputs.
311+
Stream<AnyType> image = graph.In("IN").SetName("in_data");
312+
313+
auto [processed_data, set_back_edge_fn] = [&]() {
314+
auto* back_edge_node = &graph.AddNode("SomeBackEdgeCalculator");
315+
image >> back_edge_node->In("DATA")[0];
316+
auto set_back_edge_fn = [back_edge_node](Stream<AnyType> loop) {
317+
loop >> back_edge_node->In("DATA")[1].AsBackEdge();
318+
};
319+
Stream<AnyType> processed_data = back_edge_node->Out("PROCESSED_DATA");
320+
return std::pair(processed_data, set_back_edge_fn);
321+
}();
322+
323+
Stream<AnyType> output_data = [&]() {
324+
auto& detection_node = graph.AddNode("SomeOutputDataCalculator");
325+
image >> detection_node.In("IMAGE");
326+
processed_data >> detection_node.In("PROCESSED_DATA");
327+
return detection_node.Out("OUTPUT_DATA");
328+
}();
329+
330+
set_back_edge_fn(output_data);
331+
332+
// Graph outputs.
333+
output_data.SetName("out_data") >> graph.Out("OUT");
334+
335+
CalculatorGraphConfig expected =
336+
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
337+
node {
338+
calculator: "SomeBackEdgeCalculator"
339+
input_stream: "DATA:0:in_data"
340+
input_stream: "DATA:1:out_data"
341+
output_stream: "PROCESSED_DATA:__stream_0"
342+
input_stream_info { tag_index: "DATA:1" back_edge: true }
343+
}
344+
node {
345+
calculator: "SomeOutputDataCalculator"
346+
input_stream: "IMAGE:in_data"
347+
input_stream: "PROCESSED_DATA:__stream_0"
348+
output_stream: "OUTPUT_DATA:out_data"
349+
}
350+
input_stream: "IN:in_data"
351+
output_stream: "OUT:out_data"
352+
)pb");
353+
EXPECT_THAT(graph.GetConfig(), EqualsProto(expected));
354+
}
355+
356+
TEST(BuilderTest, CanUseBackEdgesWithIndexAndNoTag) {
357+
Graph graph;
358+
// Graph inputs.
359+
Stream<AnyType> image = graph.In("IN").SetName("in_data");
360+
361+
auto [processed_data, set_back_edge_fn] = [&]() {
362+
auto* back_edge_node = &graph.AddNode("SomeBackEdgeCalculator");
363+
image >> back_edge_node->In(0);
364+
auto set_back_edge_fn = [back_edge_node](Stream<AnyType> loop) {
365+
loop >> back_edge_node->In(1).AsBackEdge();
366+
};
367+
Stream<AnyType> processed_data = back_edge_node->Out("PROCESSED_DATA");
368+
return std::pair(processed_data, set_back_edge_fn);
369+
}();
370+
371+
Stream<AnyType> output_data = [&]() {
372+
auto& detection_node = graph.AddNode("SomeOutputDataCalculator");
373+
image >> detection_node.In("IMAGE");
374+
processed_data >> detection_node.In("PROCESSED_DATA");
375+
return detection_node.Out("OUTPUT_DATA");
376+
}();
377+
378+
set_back_edge_fn(output_data);
379+
380+
// Graph outputs.
381+
output_data.SetName("out_data") >> graph.Out("OUT");
382+
383+
CalculatorGraphConfig expected =
384+
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
385+
node {
386+
calculator: "SomeBackEdgeCalculator"
387+
input_stream: "in_data"
388+
input_stream: "out_data"
389+
output_stream: "PROCESSED_DATA:__stream_0"
390+
input_stream_info { tag_index: ":1" back_edge: true }
391+
}
392+
node {
393+
calculator: "SomeOutputDataCalculator"
394+
input_stream: "IMAGE:in_data"
395+
input_stream: "PROCESSED_DATA:__stream_0"
396+
output_stream: "OUTPUT_DATA:out_data"
397+
}
398+
input_stream: "IN:in_data"
399+
output_stream: "OUT:out_data"
400+
)pb");
401+
EXPECT_THAT(graph.GetConfig(), EqualsProto(expected));
402+
}
403+
259404
TEST(BuilderTest, CopyableStream) {
260405
Graph graph;
261406
Stream<int> a = graph.In("A").SetName("a").Cast<int>();

0 commit comments

Comments
 (0)