|
1 | 1 | #include "hal_core/netlist/gate.h"
|
2 | 2 | #include "hal_core/netlist/net.h"
|
3 | 3 | #include "machine_learning/labels/gate_label.h"
|
| 4 | +#include "netlist_preprocessing/netlist_preprocessing.h" |
| 5 | +#include "nlohmann/json.hpp" |
4 | 6 |
|
5 | 7 | namespace hal
|
6 | 8 | {
|
@@ -73,13 +75,56 @@ namespace hal
|
73 | 75 | {
|
74 | 76 | if (g->get_type()->has_property(gtp))
|
75 | 77 | {
|
| 78 | + if (!g->has_data("preprocessing_information", "multi_bit_indexed_identifiers")) |
| 79 | + { |
| 80 | + log_error("machine_learning", "unable to find indexed identifiers for gate with ID {}", g->get_id()); |
| 81 | + continue; |
| 82 | + } |
| 83 | + |
| 84 | + const std::string json_string = std::get<1>(g->get_data("preprocessing_information", "multi_bit_indexed_identifiers")); |
| 85 | + |
| 86 | + nlohmann::json j = nlohmann::json::parse(json_string); |
| 87 | + std::vector<netlist_preprocessing::indexed_identifier> index_information = j.get<std::vector<netlist_preprocessing::indexed_identifier>>(); |
| 88 | + |
| 89 | + // for each pin, only consider the index information with the least distance |
| 90 | + std::map<std::string, u32> pin_to_min_distance; |
| 91 | + for (const auto& [_name, _index, _origin, pin, _direction, distance] : index_information) |
| 92 | + { |
| 93 | + if (const auto it = pin_to_min_distance.find(pin); it == pin_to_min_distance.end()) |
| 94 | + { |
| 95 | + pin_to_min_distance.insert({pin, distance}); |
| 96 | + } |
| 97 | + else |
| 98 | + { |
| 99 | + pin_to_min_distance.at(pin) = std::min(it->second, distance); |
| 100 | + } |
| 101 | + } |
| 102 | + |
| 103 | + std::map<std::string, std::string> pin_to_net_name; |
| 104 | + for (const auto& [name, _index, _origin, pin, _direction, distance] : index_information) |
| 105 | + { |
| 106 | + if (pin_to_min_distance.at(pin) == distance) |
| 107 | + { |
| 108 | + pin_to_net_name.insert({pin, name}); |
| 109 | + } |
| 110 | + } |
| 111 | + |
76 | 112 | for (const auto& pt : m_pin_types)
|
77 | 113 | {
|
78 | 114 | const auto& pins = g->get_type()->get_pins([&pt](const auto& gt_p) { return gt_p->get_type() == pt; });
|
79 | 115 | for (const auto* p : pins)
|
80 | 116 | {
|
81 |
| - const auto* ep = (p->get_direction() == PinDirection::input) ? g->get_fan_in_endpoint(p) : g->get_fan_out_endpoint(p); |
82 |
| - const auto& net_name = ep->get_net()->get_name(); |
| 117 | + std::string net_name; |
| 118 | + if (const auto it = pin_to_net_name.find(p->get_name()); it != pin_to_net_name.end()) |
| 119 | + { |
| 120 | + net_name = pin_to_net_name.at(p->get_name()); |
| 121 | + } |
| 122 | + else |
| 123 | + { |
| 124 | + const auto* ep = (p->get_direction() == PinDirection::input) ? g->get_fan_in_endpoint(p) : g->get_fan_out_endpoint(p); |
| 125 | + net_name = ep->get_net()->get_name(); |
| 126 | + } |
| 127 | + |
83 | 128 | if (net_name.find(m_key_word) != std::string::npos)
|
84 | 129 | {
|
85 | 130 | return OK(MATCH);
|
|
0 commit comments