febrero 17, 2023

OJOS – Detección de ojos

Written by

Tras investigar varios días por Internet la información de la función de detección de poses, no conseguimos encontrar nada claro. Nos ponemos a leer los diferentes códigos del programa y conseguimos averiguar lo necesario del funcionamiento, que se aclara a continuación:

  • La matriz de puntos detectados se llama output. Esta matriz se obtiene directamente de pytorch y se refina después con el sistema ‘Non Maximum Supression’.
  • La matriz contiene tantas filas como personas detectadas, que se comprueba en el valor output.shape[0].
  • Cada fila de la matriz contiene puntos situados sobre la imagen, pero los 7 primeros valores no contienen información importante para la pose y se eliminan.
  • Los puntos restantes tras la eliminación contienen la coordenada X, Y y el valor de la fiabilidad de la detección. Estos puntos están ordenados de acuerdo al patrón del esqueleto. El primer punto es la nariz y los dos siguientes son los ojos.

A continuación dejamos el código principal modificado del programa que muestra sólo los tres primeros puntos de las caras en lugar de dibujar todo el esqueleto.

import time
import torch
import cv2
import numpy as np
from torchvision import transforms

from utils.datasets import letterbox
from utils.general  import non_max_suppression_kpt
from utils.plots    import output_to_keypoint, plot_skeleton_kpts


def pose_video(frame):
    mapped_img = frame.copy()
    # Letterbox resizing.
    img = letterbox(frame, input_size, stride=64, auto=True)[0]
    #print(img.shape)
    img_ = img.copy()
    # Convert the array to 4D.
    img = transforms.ToTensor()(img)
    # Convert the array to Tensor.
    img = torch.tensor(np.array([img.numpy()]))
    # Load the image into the computation device.
    img = img.to(device)
    
    # Gradients are stored during training, not required while inference.
    with torch.no_grad():
        t1 = time.time()
        output, _ = model(img)
        t2 = time.time()
        fps = 1/(t2 - t1)
        output = non_max_suppression_kpt(output, 
                                         0.25,    # Conf. Threshold.
                                         0.65,    # IoU Threshold.
                                         nc=1,   # Number of classes.
                                         nkpt=17, # Number of keypoints.
                                         kpt_label=True)
        
        output = output_to_keypoint(output)

    # Change format [b, c, h, w] to [h, w, c] for displaying the image.
    nimg = img[0].permute(1, 2, 0) * 255
    nimg = nimg.cpu().numpy().astype(np.uint8)
    nimg = cv2.cvtColor(nimg, cv2.COLOR_RGB2BGR)

    for idx in range(output.shape[0]):
        cv2.circle(nimg, (int(output[idx,7]), int(output[idx,8])), 5, (255,0, 0), -1)
        cv2.circle(nimg, (int(output[idx,10]), int(output[idx,11])), 5, (255,255, 0), -1)
        cv2.circle(nimg, (int(output[idx,13]), int(output[idx,14])), 5, (255,255, 0), -1)
#        plot_skeleton_kpts(nimg, output[idx, 7:].T, 3)
        
    return nimg, fps


# Change forward pass input size.
input_size = 512
device = torch.device("cuda:0")

# Load keypoint detection model.
weights = torch.load('yolov7-w6-pose.pt', map_location=device)
model = weights['model']
# Load the model in evaluation mode.
_ = model.float().eval()
# Load the model to computation device [cpu/gpu/tpu]
model.to(device)

cap = cv2.VideoCapture(0)
fps = int(cap.get(cv2.CAP_PROP_FPS))
ret, frame = cap.read()
h, w, _ = frame.shape

if __name__ == '__main__':
        while True:
            ret, frame = cap.read()
            
            if not ret:
                print('Unable to read frame. Exiting ..')
                break

            img, fps_ = pose_video(frame)

            cv2.putText(img, 'FPS : {:.2f}'.format(fps_), (120, 40), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0,255,0), 2, cv2.LINE_AA)
            cv2.putText(img, 'YOLOv7', (20, 40), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0,255,0), 2, cv2.LINE_AA)

            cv2.imshow('Output', img[...,::-1])
            if (cv2.waitKey(1) == ord('s')):
                break

        cap.release()
        cv2.destroyAllWindows()

Category : OJOS

Deja una respuesta

Tu dirección de correo electrónico no será publicada. Los campos obligatorios están marcados con *

Proudly powered by WordPress and Sweet Tech Theme