# xfeat_lightglue_onnx/onnx_runner.py

import os
import random
from functools import wraps
from time import perf_counter

import cv2
import numpy as np
import onnxruntime as ort
from matplotlib import pyplot as plt

# Fix random seeds for reproducibility
np.random.seed(0)
random.seed(0)

# Decorator that measures and prints a function's execution time in milliseconds
def timeit(func):
    @wraps(func)
    def wrapper(*args, **kwargs):
        start_time = perf_counter()
        result = func(*args, **kwargs)
        exec_time = (perf_counter() - start_time) * 1000  # elapsed time in ms
        print(f"{func.__name__} execution time: {exec_time:.2f} ms")
        return result
    return wrapper

def resize_img(img, height=640, width=360, keep_aspect_ratio=False):
    if keep_aspect_ratio:
        h, w = img.shape[:2]
        if w > h:
            new_w = width
            new_h = int(h * new_w / w)
        else:
            new_h = height
            new_w = int(w * new_h / h)
        # cv2.resize expects the target size as (width, height)
        return cv2.resize(img, (new_w, new_h))
    return cv2.resize(img, (width, height))

def draw_points(img, points, size=2, color=(255, 0, 0), thickness=-1):
    for p in points:
        cv2.circle(img, (int(p[0]), int(p[1])), size, color, thickness)
    return img

def draw_matches(img0, img1, kpts0, kpts1, matches, scores, threshold=0.1, show_lines=True):
    # Convert images to RGB for matplotlib and stack them side by side
    img0, img1 = cv2.cvtColor(img0, cv2.COLOR_BGR2RGB), cv2.cvtColor(img1, cv2.COLOR_BGR2RGB)
    vis = np.hstack((img0, img1))
    h0, w0 = img0.shape[:2]
    # Sort matches by score
    sorted_idx = np.argsort(scores)
    matches = matches[sorted_idx]
    scores = scores[sorted_idx]
    # Keep only matches above the score threshold
    valid_matches = matches[scores > threshold]
    valid_scores = scores[scores > threshold]
    if len(valid_matches) == 0:
        return valid_matches, valid_scores
    # Normalize scores to [0, 1] for coloring (guard against a zero score range)
    score_range = valid_scores.max() - valid_scores.min()
    norm_score = (valid_scores - valid_scores.min()) / (score_range + 1e-8)
    pts = []
    # Draw the match lines first, then the keypoints on top
    for m, score in zip(valid_matches, norm_score):
        pt0 = tuple(map(int, kpts0[m[0]]))
        pt1 = tuple(map(int, kpts1[m[1]]))
        color = plt.cm.jet(score)
        dot_size = score * 2 + 0.2
        pts.append((pt0, pt1, dot_size, color))
        if show_lines:
            plt.plot((pt0[0], pt1[0] + w0), (pt0[1], pt1[1]), color=color, linewidth=0.1)
    for pt0, pt1, dot_size, color in pts:
        plt.plot(pt0[0], pt0[1], 'o', color=color, markersize=dot_size)
        plt.plot(pt1[0] + w0, pt1[1], 'o', color=color, markersize=dot_size)
    # Display and save the final visualization
    plt.imshow(vis)
    plt.axis('off')
    plt.savefig('matches.png', bbox_inches='tight', pad_inches=0, dpi=300)
    return valid_matches, valid_scores
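
# Note (assumption, not in the original): if draw_matches is called repeatedly, adding
# plt.figure() before plotting and plt.close() after savefig keeps each call on its own
# figure; plt.show() can also be added to display the result interactively.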

def warp_image(img, kpts0, kpts1):
    H = None
    try:
        # Estimate a homography mapping kpts0 onto kpts1
        H, _ = cv2.findHomography(kpts0, kpts1, cv2.RANSAC, 5.0)
        if H is None:
            return img, None
        # Warp the image with the estimated homography
        img_warped = cv2.warpPerspective(img, H, (img.shape[1], img.shape[0]))
        return img_warped, H
    except Exception as e:
        print(f"Error warping image: {e}")
        return img, H

def show_img(img):
    cv2.imshow('img', img)
    cv2.waitKey(0)
    cv2.destroyAllWindows()

def normalize_kpts(kpts, im_height, im_width):
    # Drop the batch dimension if present, then scale coordinates to [0, 1]
    if len(kpts.shape) == 3:
        kpts = kpts.squeeze(0)
    kpts = kpts.copy()
    print('im_height:', im_height, 'im_width:', im_width)
    kpts[:, 0] = kpts[:, 0] / im_width
    kpts[:, 1] = kpts[:, 1] / im_height
    return kpts

class OrtRun:
    def __init__(self, model_path, force_cpu=True, *args, **kwargs):
        self.provider = ['CPUExecutionProvider'] if force_cpu else ort.get_available_providers()
        options = ort.SessionOptions()
        options.intra_op_num_threads = os.cpu_count()
        options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
        self.session = ort.InferenceSession(model_path, options, providers=self.provider)
        self.input_name = self.session.get_inputs()
        self.output_name = self.session.get_outputs()

    def preprocess(self, data):
        return data

    def postprocess(self, data):
        return data

    @timeit
    def infer(self, data):
        return self.session.run(None, data)

    def run(self, data):
        data = self.preprocess(data)
        inputs = {inp.name: data[i] for i, inp in enumerate(self.input_name)}
        res = self.infer(inputs)
        return self.postprocess(res)
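
# Note (assumption, not part of the original file): constructing any OrtRun subclass with
# force_cpu=False lets onnxruntime pick from ort.get_available_providers(), e.g.
# CUDAExecutionProvider when the GPU build is installed, with CPUExecutionProvider as the
# fallback. Example: FeatExtractor('onnx/xfeat_2048_640x360.onnx', force_cpu=False).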

class FeatExtractor(OrtRun):
    def __init__(self, model_path, force_cpu=True, **kwargs):
        super().__init__(model_path, force_cpu, **kwargs)

    def preprocess(self, img):
        # BGR -> RGB, HWC -> CHW, scale to [0, 1], add a batch dimension
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = np.transpose(img, (2, 0, 1)) / 255.0
        img = np.expand_dims(img, axis=0).astype(np.float32)
        return [img]

    def postprocess(self, data):
        kpt0, desc0, score0 = data
        # # Optionally keep only the top half of keypoints by score:
        # idx = np.argsort(score0, axis=0)[::-1]
        # kpt0 = kpt0[idx]
        # desc0 = desc0[idx]
        # kpt0 = kpt0[:len(kpt0) // 2]
        # desc0 = desc0[:len(desc0) // 2]
        return kpt0, desc0

class Matcher(OrtRun):
    def __init__(self, model_path, force_cpu=True, threshold=0.5, **kwargs):
        super().__init__(model_path, force_cpu, **kwargs)
        self.threshold = threshold

    def preprocess(self, data):
        # Add a batch dimension to each input tensor
        data = [np.expand_dims(d, axis=0) for d in data]
        return data

    def postprocess(self, data):
        matches, scores = data
        # Keep only matches whose confidence exceeds the threshold
        keep = scores > self.threshold
        matches, scores = matches[keep], scores[keep]
        if len(matches) == 0:
            return np.array([]), np.array([])
        return matches, scores

if __name__ == '__main__':
    extractor = FeatExtractor('onnx/xfeat_2048_640x360.onnx')
    matcher = Matcher('onnx/lighterglue_L3.onnx', threshold=0.7)

    # im1 = 'assets/hard/image_6.jpg'
    # im2 = 'assets/hard/image_7.jpg'
    im1 = 'assets/001.jpg'
    im2 = 'assets/002.jpg'
    # im1 = 'assets/003.png'
    # im2 = 'assets/004.png'
    # im1 = 'assets/ref.png'
    # im2 = 'assets/tgt.png'
    im1 = cv2.imread(im1)
    im2 = cv2.imread(im2)

    h, w = 640, 360
    # h, w = 1280, 720
    im1 = resize_img(im1, height=h, width=w)
    im2 = resize_img(im2, height=h, width=w)
    print(im1.shape, im2.shape)

    start_time = perf_counter()
    # Extract keypoints and descriptors from both images
    kpt1, desc1 = extractor.run(im1)
    kpt2, desc2 = extractor.run(im2)
    # Normalize keypoint coordinates to [0, 1] before matching
    norm_kpt1 = normalize_kpts(kpt1, im_height=im1.shape[0], im_width=im1.shape[1])
    norm_kpt2 = normalize_kpts(kpt2, im_height=im2.shape[0], im_width=im2.shape[1])
    matches, scores = matcher.run((norm_kpt1, norm_kpt2, desc1, desc2))
    print(f'Shape of matches: {matches.shape}, scores: {scores.shape}')
    print(f'Inference time: {perf_counter() - start_time:.2f} s')

    if len(matches) > 0:
        # vis = draw_points(im1, kpt1, size=2)
        # show_img(vis)
        matches, scores = draw_matches(
            im1, im2,
            kpt1, kpt2,
            matches, scores,
            threshold=0.0,
            show_lines=True
        )
        print(f"Found {len(matches)} matches above threshold")
    else:
        print("No matches found")