-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
3 changed files
with
260 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
import onnx | ||
import onnx.helper | ||
import os | ||
import onnx_graphsurgeon as gs | ||
import argparse | ||
|
||
def optimize_onnx(model_path="./Maize.onnx"): | ||
print("Optimizing ONNX model") | ||
|
||
# Load the ONNX model | ||
model = onnx.load(model_path) | ||
graph = gs.import_onnx(model) | ||
|
||
print("Graph nodes before optimization:") | ||
for node in graph.nodes: | ||
print(f"Index: {graph.nodes.index(node)}, Node Type: {node.op}, Node Name: {node.name}") | ||
|
||
# Skip nodes from index 1 to 100 while optimizing others | ||
for idx, node in enumerate(graph.nodes): | ||
if 1 <= idx <= 100: | ||
continue # Skip nodes in this range | ||
|
||
# Example optimization: Fuse arithmetic operations (Add with Mul) | ||
if node.op == "Add": | ||
# Check if the input to this "Add" is a multiplication operation (or any other condition) | ||
# We need to check the node that produces the input tensors, not the tensors themselves | ||
input_node_0 = node.inputs[0].inputs[0] if isinstance(node.inputs[0], gs.Variable) and node.inputs[0].inputs else None | ||
|
||
if input_node_0 and input_node_0.op == "Mul": | ||
# Fuse the Add and Mul operations | ||
fused_node = gs.Node(op="Add", inputs=node.inputs[0].inputs, outputs=node.outputs) | ||
graph.nodes.append(fused_node) | ||
|
||
# Remove the original nodes that are now fused | ||
graph.nodes.remove(input_node_0) # Remove the Mul node | ||
graph.nodes.remove(node) # Remove the Add node | ||
|
||
# Clean up, fold constants, and toposort the graph | ||
graph.cleanup().toposort() | ||
graph.fold_constants() | ||
|
||
# Save the optimized model to a new file | ||
optimized_model_path = model_path.replace(".onnx", "_optimized.onnx") | ||
onnx.save(gs.export_onnx(graph), optimized_model_path) | ||
|
||
print(f"Optimized model saved to {optimized_model_path}") | ||
return optimized_model_path | ||
|
||
if __name__ == "__main__": | ||
print("Usage: python3 ONNX_GS.py --model_path=./Maize.onnx") | ||
|
||
parser = argparse.ArgumentParser(description='Optimize the ONNX model using GraphSurgeon') | ||
parser.add_argument('--model_path', type=str, default="./Maize.onnx", required=False, help='Path to the ONNX model file (.onnx)') | ||
args = parser.parse_args() | ||
|
||
optimize_onnx(args.model_path) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
import onnx | ||
import onnx.helper | ||
import os | ||
import onnx_graphsurgeon as gs | ||
|
||
if os.path.exists("Maize_nodes.txt"): | ||
os.remove("Maize_nodes.txt") | ||
|
||
# Load the ONNX model | ||
model = onnx.load("Maize.onnx") | ||
|
||
# Extract the graph from the model | ||
graph = model.graph | ||
|
||
# Save the nodes in the graph | ||
nodes = [] | ||
for node in graph.node: | ||
nodes.append(f"Node Type: {node.op_type}, Inputs: {node.input}, Outputs: {node.output}") | ||
|
||
# Save the nodes to a text file | ||
with open("Maize_nodes.txt", "w") as file: | ||
for node in nodes: | ||
file.write(node + "\n") | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,179 @@ | ||
import time | ||
import onnxruntime as ort | ||
import numpy as np | ||
import cv2 | ||
import logging | ||
import onnx | ||
|
||
logging.basicConfig(format='%(message)s', level=logging.INFO) | ||
|
||
|
||
class Results: | ||
def __init__(self, model_path, image_path, imgsz=640, gpu=True, precision="fp16"): | ||
""" | ||
Initialize the Results class for ONNX Runtime analysis. | ||
""" | ||
self.model_path = model_path | ||
self.image_path = image_path | ||
self.imgsz = imgsz | ||
self.gpu = gpu | ||
self.precision = precision | ||
self.onnx_session = None | ||
self.results = [] | ||
self.inference_times = [] | ||
self.confidence_scores = [] | ||
|
||
# Load the model during initialization | ||
self._load_model() | ||
|
||
def _load_model(self): | ||
"""Load the ONNX model into an inference session.""" | ||
providers = ["CUDAExecutionProvider"] if self.gpu else ["CPUExecutionProvider"] | ||
self.onnx_session = ort.InferenceSession(self.model_path, providers=providers) | ||
logging.info(f"Model loaded from {self.model_path} using {'GPU' if self.gpu else 'CPU'}.") # use gpu if available | ||
|
||
def _get_images(self): | ||
"""Load images from the specified path.""" | ||
images = [] | ||
for image_src in self.image_path: | ||
image = cv2.imread(image_src) | ||
image = cv2.resize(image, self.imgsz) | ||
images.append(image.astype(np.float16) if self.precision == "fp16" else image.astype(np.float32)) | ||
|
||
logging.info(f"{len(images)} images loaded and preprocessed.") | ||
return images | ||
|
||
def predict(self, confidence_threshold=0.5): | ||
""" | ||
Perform inference and collect results. | ||
""" | ||
images = self._get_images() | ||
input_name = self.onnx_session.get_inputs()[0].name | ||
|
||
self.results.clear() | ||
self.inference_times.clear() | ||
self.confidence_scores.clear() | ||
|
||
for image in images: | ||
image = np.expand_dims(image, axis=0) # Add batch dimension | ||
tic = time.perf_counter_ns() | ||
outputs = self.onnx_session.run(None, {input_name: image}) | ||
toc = time.perf_counter_ns() | ||
|
||
# Record inference time | ||
inference_time = (toc - tic) / 1e6 # Milliseconds | ||
self.inference_times.append(inference_time) | ||
|
||
# Process outputs (assuming the first output contains predictions) | ||
output = outputs[0] # Replace with your post-processing logic | ||
filtered_output = [o for o in output if o[-1] >= confidence_threshold] | ||
self.results.append(filtered_output) | ||
self.confidence_scores.extend(o[-1] for o in filtered_output) | ||
|
||
avg_time = sum(self.inference_times) / len(self.inference_times) | ||
logging.info(f"Average inference time: {avg_time:.2f} ms") | ||
return self.results | ||
|
||
def display(self): | ||
"""Display the results.""" | ||
images= self._get_images() | ||
i = 0 | ||
|
||
while True: | ||
image = images[i % len(images)].copy() | ||
for box in self.results[i % len(self.results)]: | ||
x1, y1, x2, y2, score = box[:5] | ||
color = (0, 255, 0) if score >= 0.5 else (0, 0, 255) | ||
cv2.rectangle(image, (int(x1), int(y1)), (int(x2), int(y2)), color, 2) | ||
cv2.putText(image, f"{score:.2f}", (int(x1), int(y1) - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2) | ||
|
||
cv2.imshow('Inference Results', image) | ||
key = cv2.waitKey(0) | ||
|
||
if key == ord('q'): | ||
cv2.destroyAllWindows() | ||
break | ||
elif key == ord('a'): | ||
i -= 1 | ||
elif key == ord('d'): | ||
i += 1 | ||
|
||
def save(self, output_path): | ||
"""Save the results to a file.""" | ||
with open(output_path, 'w') as f: | ||
for result in self.results: | ||
f.write(','.join(str(r) for r in result) + '\n') | ||
logging.info(f"Results saved to {output_path}.") | ||
|
||
def validate(model_path): | ||
try: | ||
onnx_model = onnx.load(model_path) | ||
onnx.checker.check_model(onnx_model) | ||
print("The model is valid!") | ||
return True | ||
except onnx.checker.ValidationError as e: | ||
print("Model validation failed:", e) | ||
return False | ||
|
||
def compare(self, compared_results): | ||
""" | ||
Compare the current results with the provided results. | ||
""" | ||
pass | ||
|
||
|
||
|
||
# sample usage | ||
|
||
img1 = "IMG_1822_14.JPG" | ||
img2 = "IMG_1828_02.JPG" | ||
|
||
img = [img1, img2] | ||
|
||
model = "Maize.onnx" | ||
model_opt = "Maize_optimized.onnx" | ||
|
||
# check image size | ||
cv2.imread(img1).shape | ||
|
||
# check if optimized model is valid | ||
# if not Results.validate(model_opt): | ||
# raise ValueError("The model is not valid. Please check the model.") | ||
# else: | ||
# logging.info("The model is valid.") | ||
# result_analysis = Results( | ||
# model_path=model, | ||
# image_path=img, | ||
# imgsz=cv2.imread(img1).shape[:2], | ||
# gpu=True, | ||
# precision="fp16" | ||
# ) | ||
# results_opt_analysis = Results( | ||
# model_path=model_opt, | ||
# image_path=img, | ||
# imgsz=cv2.imread(img1).shape[:2], | ||
# gpu=True, | ||
# precision="fp16" | ||
# ) | ||
|
||
# visulaize and compare | ||
|
||
# # Run predictions | ||
# results_pred = result_analysis.predict(confidence_threshold=0.5) | ||
# # results_opt_pred = results_opt_analysis.predict(confidence_threshold=0.5) | ||
|
||
# # Visualize results | ||
# result_analysis.show_results() | ||
# # results_opt_analysis.show_results() | ||
|
||
# checking the original model | ||
result_analysis = Results( | ||
model_path=model, | ||
image_path=img, | ||
imgsz=cv2.imread(img1).shape[:2], | ||
gpu=True, | ||
precision="fp16" | ||
) | ||
|
||
result_analysis.predict(confidence_threshold=0.5) | ||
result_analysis.display() |