Skip to content

Commit ffae194

Browse files
committedApr 12, 2024
feat: add ego information handlings
Signed-off-by: ktro2828 <kotaro.uetake@tier4.jp>
1 parent 4d54696 commit ffae194

File tree

8 files changed

+74
-20
lines changed

8 files changed

+74
-20
lines changed
 

‎perception/tensorrt_mtr/CMakeLists.txt

+2-2
Original file line numberDiff line numberDiff line change
@@ -23,14 +23,14 @@ if (CUDA_FOUND)
2323
PATH_SUFFIXES lib lib64 bin
2424
DOC "CUDNN library.")
2525
else()
26-
message(FAITAL_ERROR "Can not find CUDA")
26+
message(FATAL_ERROR "Can not find CUDA")
2727
endif()
2828

2929
list(APPEND TRT_PLUGINS "nvinfer")
3030
list(APPEND TRT_PLUGINS "nvonnxparser")
3131
list(APPEND TRT_PLUGINS "nvparsers")
3232
foreach(libName ${TRT_PLUGINS})
33-
find_library(${libName}_lib NAMES ${libName} "/usr" PATH_SUFFIES lib)
33+
find_library(${libName}_lib NAMES ${libName} "/usr" PATH_SUFFIXES lib)
3434
list(APPEND TRT_PLUGINS ${${libName}_lib})
3535
endforeach()
3636

‎perception/tensorrt_mtr/include/tensorrt_mtr/node.hpp

+10-1
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
#include <map>
4343
#include <memory>
4444
#include <string>
45+
#include <utility>
4546
#include <vector>
4647

4748
namespace trt_mtr
@@ -56,6 +57,11 @@ using autoware_auto_perception_msgs::msg::TrackedObject;
5657
using autoware_auto_perception_msgs::msg::TrackedObjects;
5758
using nav_msgs::msg::Odometry;
5859

60+
// TODO(ktro2828): use received ego size topic
61+
constexpr float EGO_LENGTH = 4.0f;
62+
constexpr float EGO_WIDTH = 2.0f;
63+
constexpr float EGO_HEIGHT = 1.0f;
64+
5965
class PolylineTypeMap
6066
{
6167
public:
@@ -139,6 +145,8 @@ class MTRNode : public rclcpp::Node
139145
void updateAgentHistory(
140146
const float current_time, const TrackedObjects::ConstSharedPtr objects_msg);
141147

148+
AgentState extractNearestEgo(const float current_time) const;
149+
142150
/**
143151
* @brief Extract target agents and return corresponding indices.
144152
*
@@ -179,10 +187,11 @@ class MTRNode : public rclcpp::Node
179187
tier4_autoware_utils::TransformListener transform_listener_;
180188

181189
// MTR parameters
182-
std::unique_ptr<MtrConfig> config_ptr_;
190+
std::unique_ptr<MTRConfig> config_ptr_;
183191
std::unique_ptr<TrtMTR> model_ptr_;
184192
PolylineTypeMap polyline_type_map_;
185193
std::shared_ptr<PolylineData> polyline_ptr_;
194+
std::vector<std::pair<float, AgentState>> ego_states_;
186195
std::vector<float> timestamps_;
187196
}; // class MTRNode
188197
} // namespace trt_mtr

‎perception/tensorrt_mtr/include/tensorrt_mtr/trt_mtr.hpp

+7-7
Original file line numberDiff line numberDiff line change
@@ -36,10 +36,10 @@ namespace trt_mtr
3636
/**
3737
* @brief A configuration of MTR.
3838
*/
39-
struct MtrConfig
39+
struct MTRConfig
4040
{
4141
/**
42-
* @brief Construct a new Mtr Config object
42+
* @brief Construct a new instance.
4343
*
4444
* @param target_labels An array of target label names.
4545
* @param num_mode The number of modes.
@@ -49,7 +49,7 @@ struct MtrConfig
4949
* @param intention_point_filepath The path to intention points file.
5050
* @param num_intention_point_cluster The number of clusters for intension points.
5151
*/
52-
MtrConfig(
52+
MTRConfig(
5353
const std::vector<std::string> & target_labels = {"VEHICLE", "PEDESTRIAN", "CYCLIST"},
5454
const size_t num_past = 10, const size_t num_mode = 6, const size_t num_future = 80,
5555
const size_t max_num_polyline = 768, const size_t max_num_point = 20,
@@ -81,7 +81,7 @@ struct MtrConfig
8181
std::array<float, 2> offset_xy;
8282
std::string intention_point_filepath;
8383
size_t num_intention_point_cluster;
84-
};
84+
}; // struct MTRConfig
8585

8686
/**
8787
* @brief A class to inference with MTR.
@@ -101,7 +101,7 @@ class TrtMTR
101101
*/
102102
TrtMTR(
103103
const std::string & model_path, const std::string & precision,
104-
const MtrConfig & config = MtrConfig(), const BatchConfig & batch_config = {1, 1, 1},
104+
const MTRConfig & config = MTRConfig(), const BatchConfig & batch_config = {1, 1, 1},
105105
const size_t max_workspace_size = (1ULL << 30),
106106
const BuildConfig & build_config = BuildConfig());
107107

@@ -122,7 +122,7 @@ class TrtMTR
122122
*
123123
* @return const MtrConfig& The model configuration which can not be updated.
124124
*/
125-
const MtrConfig & config() const { return config_; }
125+
const MTRConfig & config() const { return config_; }
126126

127127
private:
128128
/**
@@ -152,7 +152,7 @@ class TrtMTR
152152
bool postProcess(AgentData & agent_data, std::vector<PredictedTrajectory> & trajectories);
153153

154154
// model parameters
155-
MtrConfig config_;
155+
MTRConfig config_;
156156

157157
std::unique_ptr<MTRBuilder> builder_;
158158
cudaStream_t stream_{nullptr};

‎perception/tensorrt_mtr/lib/include/postprocess/postprocess_kernel.cuh

+1-1
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
*
2424
* @param B The number of target agents.
2525
* @param M The number of modes.
26-
* @param T The number of future timestmaps.
26+
* @param T The number of future timestamps.
2727
* @param inDim The number of input agent state dimensions.
2828
* @param targetState Source target agent states at latest timestamp, in shape [B*inDim].
2929
* @param outDim The number of output state dimensions.

‎perception/tensorrt_mtr/lib/include/preprocess/polyline_preprocess_kernel.cuh

+1-1
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ __global__ void calculatePolylineCenterKernel(
8282

8383
/**
8484
* @brief In cases of the number of batch polylines (L) is greater than K,
85-
* extacts the topK elements.
85+
* extracts the topK elements.
8686
*
8787
* @param L The number of source polylines.
8888
* @param K The number of polylines expected as the model input.

‎perception/tensorrt_mtr/package.xml

+1-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
<package format="3">
44
<name>tensorrt_mtr</name>
55
<version>0.1.0</version>
6-
<description>ROS 2 Node of Motion Transfomer(a.k.a MTR).</description>
6+
<description>ROS 2 Node of Motion Transfromer(a.k.a MTR).</description>
77
<maintainer email="kotaro.uetake@tier4.jp">kotarouetake</maintainer>
88
<license>Apache-2.0</license>
99

‎perception/tensorrt_mtr/src/node.cpp

+51-6
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323

2424
#include <algorithm>
2525
#include <cmath>
26-
#include <utility>
2726

2827
namespace trt_mtr
2928
{
@@ -182,7 +181,7 @@ MTRNode::MTRNode(const rclcpp::NodeOptions & node_options)
182181
declare_parameter<std::string>("intention_point_filepath");
183182
const auto num_intention_point_cluster =
184183
static_cast<size_t>(declare_parameter<int>("num_intention_point_cluster"));
185-
config_ptr_ = std::make_unique<MtrConfig>(
184+
config_ptr_ = std::make_unique<MTRConfig>(
186185
target_labels, num_past, num_mode, num_future, max_num_polyline, max_num_point,
187186
point_break_distance, offset_xy, intention_point_filepath, num_intention_point_cluster);
188187
model_ptr_ = std::make_unique<TrtMTR>(model_path, precision, *config_ptr_.get());
@@ -210,7 +209,7 @@ void MTRNode::callback(const TrackedObjects::ConstSharedPtr object_msg)
210209
return; // No polyline
211210
}
212211

213-
const auto current_time = rclcpp::Time(object_msg->header.stamp).seconds();
212+
const auto current_time = static_cast<float>(rclcpp::Time(object_msg->header.stamp).seconds());
214213

215214
timestamps_.emplace_back(current_time);
216215
// TODO(ktro2828): update timestamps
@@ -289,7 +288,30 @@ void MTRNode::onMap(const HADMapBin::ConstSharedPtr map_msg)
289288

290289
void MTRNode::onEgo(const Odometry::ConstSharedPtr ego_msg)
291290
{
292-
RCLCPP_INFO_STREAM(get_logger(), "Ego msg is received: " << ego_msg->header.frame_id);
291+
const auto current_time = static_cast<float>(rclcpp::Time(ego_msg->header.stamp).seconds());
292+
const auto & position = ego_msg->pose.pose.position;
293+
const auto & twist = ego_msg->twist.twist;
294+
const auto yaw = static_cast<float>(tf2::getYaw(ego_msg->pose.pose.orientation));
295+
float ax = 0.0f, ay = 0.0f;
296+
if (!ego_states_.empty()) {
297+
const auto & latest_state = ego_states_.back();
298+
const auto time_diff = current_time - latest_state.first;
299+
ax = (static_cast<float>(twist.linear.x) - latest_state.second.vx()) / (time_diff + 1e-10f);
300+
ay = static_cast<float>(twist.linear.y) - latest_state.second.vy() / (time_diff + 1e-10f);
301+
}
302+
303+
// TODO(ktro2828): use received ego size topic
304+
ego_states_.emplace_back(std::make_pair(
305+
current_time,
306+
AgentState(
307+
static_cast<float>(position.x), static_cast<float>(position.y),
308+
static_cast<float>(position.z), EGO_LENGTH, EGO_WIDTH, EGO_HEIGHT, yaw,
309+
static_cast<float>(twist.linear.x), static_cast<float>(twist.linear.y), ax, ay, true)));
310+
311+
constexpr size_t max_buffer_size = 100;
312+
if (max_buffer_size < ego_states_.size()) {
313+
ego_states_.erase(ego_states_.begin(), ego_states_.begin());
314+
}
293315
}
294316

295317
bool MTRNode::convertLaneletToPolyline()
@@ -362,19 +384,24 @@ bool MTRNode::convertLaneletToPolyline()
362384
void MTRNode::removeAncientAgentHistory(
363385
const float current_time, const TrackedObjects::ConstSharedPtr objects_msg)
364386
{
365-
// TODO(ktro2828): use ego info
387+
constexpr float time_threshold = 1.0f; // TODO(ktro2828): use parameter
366388
for (const auto & object : objects_msg->objects) {
367389
const auto & object_id = tier4_autoware_utils::toHexString(object.object_id);
368390
if (agent_history_map_.count(object_id) == 0) {
369391
continue;
370392
}
371393

372-
constexpr float time_threshold = 1.0f; // TODO(ktro2828): use parameter
373394
const auto & history = agent_history_map_.at(object_id);
374395
if (history.is_ancient(current_time, time_threshold)) {
375396
agent_history_map_.erase(object_id);
376397
}
377398
}
399+
400+
if (
401+
agent_history_map_.count(EGO_ID) != 0 &&
402+
agent_history_map_.at(EGO_ID).is_ancient(current_time, time_threshold)) {
403+
agent_history_map_.erase(EGO_ID);
404+
}
378405
}
379406

380407
void MTRNode::updateAgentHistory(
@@ -402,6 +429,15 @@ void MTRNode::updateAgentHistory(
402429
}
403430
}
404431

432+
auto ego_state = extractNearestEgo(current_time);
433+
if (agent_history_map_.count(EGO_ID) == 0) {
434+
AgentHistory history(EGO_ID, AgentLabel::VEHICLE, config_ptr_->num_past);
435+
history.update(current_time, ego_state);
436+
} else {
437+
agent_history_map_.at(EGO_ID).update(current_time, ego_state);
438+
}
439+
observed_ids.emplace_back(EGO_ID);
440+
405441
// update unobserved histories with empty
406442
for (auto & [object_id, history] : agent_history_map_) {
407443
if (std::find(observed_ids.cbegin(), observed_ids.cend(), object_id) != observed_ids.cend()) {
@@ -411,6 +447,15 @@ void MTRNode::updateAgentHistory(
411447
}
412448
}
413449

450+
AgentState MTRNode::extractNearestEgo(const float current_time) const
451+
{
452+
auto state = std::min_element(
453+
ego_states_.cbegin(), ego_states_.cend(), [&](const auto & s1, const auto & s2) {
454+
return std::abs(s1.first - current_time) < std::abs(s2.first - current_time);
455+
});
456+
return state->second;
457+
}
458+
414459
std::vector<size_t> MTRNode::extractTargetAgent(const std::vector<AgentHistory> & histories)
415460
{
416461
std::vector<std::pair<size_t, float>> distances;

‎perception/tensorrt_mtr/src/trt_mtr.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
namespace trt_mtr
2222
{
2323
TrtMTR::TrtMTR(
24-
const std::string & model_path, const std::string & precision, const MtrConfig & config,
24+
const std::string & model_path, const std::string & precision, const MTRConfig & config,
2525
const BatchConfig & batch_config, const size_t max_workspace_size,
2626
const BuildConfig & build_config)
2727
: config_(config),

0 commit comments

Comments
 (0)