YOLOv8 目标检测程序,依赖的库最少,使用onnxruntime推理

YOLOv8 目标检测程序,依赖的库最少,使用onnxruntime推理

flyfish
为了方便理解,加入了注释

"""
YOLOv8 目标检测程序
Author: flyfish
Date: 
Description: 该程序使用ONNX运行时进行YOLOv8模型的目标检测。
      它对输入图像进行预处理,运行推理,并对输出进行后处理,在图像上绘制边界框和标签。
"""
import argparse
import cv2
import numpy as np
import onnxruntime as ort

# 定义类别字典,将类别ID映射到类别名称
classes = {
    0: 'person', 1: 'bicycle', 2: 'car', 3: 'motorcycle', 4: 'airplane', 5: 'bus', 6: 'train', 7: 'truck', 8: 'boat', 
    9: 'traffic light', 10: 'fire hydrant', 11: 'stop sign', 12: 'parking meter', 13: 'bench', 14: 'bird', 15: 'cat', 
    16: 'dog', 17: 'horse', 18: 'sheep', 19: 'cow', 20: 'elephant', 21: 'bear', 22: 'zebra', 23: 'giraffe', 
    24: 'backpack', 25: 'umbrella', 26: 'handbag', 27: 'tie', 28: 'suitcase', 29: 'frisbee', 30: 'skis', 
    31: 'snowboard', 32: 'sports ball', 33: 'kite', 34: 'baseball bat', 35: 'baseball glove', 36: 'skateboard', 
    37: 'surfboard', 38: 'tennis racket', 39: 'bottle', 40: 'wine glass', 41: 'cup', 42: 'fork', 43: 'knife', 
    44: 'spoon', 45: 'bowl', 46: 'banana', 47: 'apple', 48: 'sandwich', 49: 'orange', 50: 'broccoli', 51: 'carrot', 
    52: 'hot dog', 53: 'pizza', 54: 'donut', 55: 'cake', 56: 'chair', 57: 'couch', 58: 'potted plant', 59: 'bed', 
    60: 'dining table', 61: 'toilet', 62: 'tv', 63: 'laptop', 64: 'mouse', 65: 'remote', 66: 'keyboard', 
    67: 'cell phone', 68: 'microwave', 69: 'oven', 70: 'toaster', 71: 'sink', 72: 'refrigerator', 73: 'book', 
    74: 'clock', 75: 'vase', 76: 'scissors', 77: 'teddy bear', 78: 'hair drier', 79: 'toothbrush'
}

class YOLOv8:
    """YOLOv8目标检测模型类,用于处理推理和可视化。"""

    def __init__(self, onnx_model, input_image, confidence_thres, iou_thres):
        """
        初始化YOLOv8类的实例。

        参数:
            onnx_model: ONNX模型的路径。
            input_image: 输入图像的路径。
            confidence_thres: 用于过滤检测结果的置信度阈值。
            iou_thres: 非极大值抑制的IoU(交并比)阈值。
        """
        self.onnx_model = onnx_model
        self.input_image = input_image
        self.confidence_thres = confidence_thres
        self.iou_thres = iou_thres

        # 加载类别名称
        self.classes = classes
     
        # 为每个类别生成一个颜色调色板
        self.color_palette = np.random.uniform(0, 255, size=(len(self.classes), 3))

    def draw_detections(self, img, box, score, class_id):
        """
        在输入图像上绘制检测到的边界框和标签。

        参数:
            img: 要在其上绘制检测结果的输入图像。
            box: 检测到的边界框。
            score: 对应的检测置信度分数。
            class_id: 检测到的对象的类别ID。

        返回值:
            无
        """
        # 提取边界框的坐标
        x1, y1, w, h = box

        # 获取类别ID对应的颜色
        color = self.color_palette[class_id]

        # 在图像上绘制边界框
        cv2.rectangle(img, (int(x1), int(y1)), (int(x1 + w), int(y1 + h)), color, 2)

        # 创建包含类别名称和置信度分数的标签文本
        label = f"{self.classes[class_id]}: {score:.2f}"

        # 计算标签文本的尺寸
        (label_width, label_height), _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)

        # 计算标签文本的位置
        label_x = x1
        label_y = y1 - 10 if y1 - 10 > label_height else y1 + 10

        # 绘制填充矩形作为标签文本的背景
        cv2.rectangle(
            img, (label_x, label_y - label_height), (label_x + label_width, label_y + label_height), color, cv2.FILLED
        )

        # 在图像上绘制标签文本
        cv2.putText(img, label, (label_x, label_y), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 1, cv2.LINE_AA)

    def preprocess(self):
        """
        在执行推理之前预处理输入图像。

        返回值:
            image_data: 预处理后的图像数据,准备进行推理。
        """
        # 使用OpenCV读取输入图像
        self.img = cv2.imread(self.input_image)

        # 获取输入图像的高度和宽度
        self.img_height, self.img_width = self.img.shape[:2]

        # 将图像的颜色空间从BGR转换为RGB
        img = cv2.cvtColor(self.img, cv2.COLOR_BGR2RGB)

        # 调整图像大小以匹配输入形状
        img = cv2.resize(img, (self.input_width, self.input_height))

        # 通过除以255.0来规范化图像数据
        image_data = np.array(img) / 255.0

        # 转置图像,使通道维度为第一个维度
        image_data = np.transpose(image_data, (2, 0, 1))  # 通道优先

        # 扩展图像数据的维度以匹配预期的输入形状
        image_data = np.expand_dims(image_data, axis=0).astype(np.float32)

        # 返回预处理后的图像数据
        return image_data

    def postprocess(self, input_image, output):
        """
        对模型的输出进行后处理,以提取边界框、置信度分数和类别ID。

        参数:
            input_image (numpy.ndarray): 输入图像。
            output (numpy.ndarray): 模型的输出。

        返回值:
            numpy.ndarray: 带有绘制检测结果的输入图像。
        """
        # 转置并压缩输出以匹配预期的形状
        outputs = np.transpose(np.squeeze(output[0]))

        # 获取输出数组中的行数
        rows = outputs.shape[0]

        # 用于存储检测到的边界框、置信度分数和类别ID的列表
        boxes = []
        scores = []
        class_ids = []

        # 计算边界框坐标的缩放因子
        x_factor = self.img_width / self.input_width
        y_factor = self.img_height / self.input_height

        # 遍历输出数组中的每一行
        for i in range(rows):
            # 从当前行中提取类别分数
            classes_scores = outputs[i][4:]

            # 找到类别分数中的最大值
            max_score = np.amax(classes_scores)

            # 如果最大值大于置信度阈值
            if max_score >= self.confidence_thres:
                # 获取具有最高分数的类别ID
                class_id = np.argmax(classes_scores)

                # 从当前行中提取边界框坐标
                x, y, w, h = outputs[i][0], outputs[i][1], outputs[i][2], outputs[i][3]

                # 计算缩放后的边界框坐标
                left = int((x - w / 2) * x_factor)
                top = int((y - h / 2) * y_factor)
                width = int(w * x_factor)
                height = int(h * y_factor)

                # 将类别ID、置信度分数和边界框坐标添加到各自的列表中
                class_ids.append(class_id)
                scores.append(max_score)
                boxes.append([left, top, width, height])

        # 应用非极大值抑制以过滤重叠的边界框
        indices = cv2.dnn.NMSBoxes(boxes, scores, self.confidence_thres, self.iou_thres)

        # 遍历非极大值抑制后选择的索引
        for i in indices:
            # 获取对应于索引的边界框、置信度分数和类别ID
            box = boxes[i]
            score = scores[i]
            class_id = class_ids[i]

            # 在输入图像上绘制检测结果
            self.draw_detections(input_image, box, score, class_id)

        # 返回修改后的输入图像
        return input_image

    def main(self):
        """
        使用ONNX模型执行推理并返回带有绘制检测结果的输出图像。

        返回值:
            output_img: 带有绘制检测结果的输出图像。
        """
        # 使用ONNX模型创建推理会话并指定执行提供程序
        session = ort.InferenceSession(self.onnx_model, providers=["CUDAExecutionProvider", "CPUExecutionProvider"])

        # 获取模型输入
        model_inputs = session.get_inputs()

        # 存储输入的形状以便稍后使用
        input_shape = model_inputs[0].shape
        self.input_width = input_shape[2]
        self.input_height = input_shape[3]

        # 预处理图像数据
        img_data = self.preprocess()

        # 使用预处理后的图像数据进行推理
        outputs = session.run(None, {model_inputs[0].name: img_data})

        # 对输出进行后处理以获得输出图像
        return self.postprocess(self.img, outputs)  # 输出图像


if __name__ == "__main__":
    # 创建一个参数解析器以处理命令行参数
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", type=str, default="yolov8n.onnx", help="输入您的ONNX模型。")
    parser.add_argument("--img", type=str, default="bus.jpg", help="输入图像的路径。")
    parser.add_argument("--conf-thres", type=float, default=0.5, help="置信度阈值")
    parser.add_argument("--iou-thres", type=float, default=0.5, help="非极大值抑制的IoU阈值")
    args = parser.parse_args()

    # 使用指定的参数创建YOLOv8类的实例
    detection = YOLOv8(args.model, args.img, args.conf_thres, args.iou_thres)

    # 执行目标检测并获取输出图像
    output_image = detection.main()

    # 在窗口中显示输出图像
    cv2.namedWindow("Output", cv2.WINDOW_NORMAL)
    cv2.imshow("Output", output_image)

    # 等待按键以退出
    cv2.waitKey(0)

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/761521.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

AQS同步队列、条件队列源码解析

AQS详解 前言AQS几个重要的内部属性字段内部类 Node同步队列 | 阻塞队列等待队列 | 条件队列 重要方法执行链同步队列的获取、阻塞、唤醒加锁代码流程解锁 条件队列的获取、阻塞、唤醒大体流程 调用await()方法1. 将节点加入到条件队列2. 完全释放独占锁3. 等待进入阻塞队列4. …

【刷题汇总--数字统计、两个数组的交集、点击消除(栈)】

C日常刷题积累 今日刷题汇总 - day0011、数字统计1.1、题目1.2、思路1.3、程序实现 2、两个数组的交集2.1、题目2.2、思路2.3、程序实现 3、点击消除(栈)3.1、题目3.2、思路3.3、程序实现 今日刷题汇总 - day001 1、数字统计 1.1、题目 请统计某个给定范围[L, R]的所有整数中…

reactjs18 中使用@reduxjs/toolkit同步异步数据的使用

react18 中使用reduxjs/toolkit 1.安装依赖包 yarn add reduxjs/toolkit react-redux2.创建 store 根目录下面创建 store 文件夹,然后创建 index.js 文件。 import { configureStore } from "reduxjs/toolkit"; import { counterReducer } from "…

【机器学习】语音转文字 - FunASR 的应用与实践(speech to text)

本文将介绍 FunASR,一个多功能语音识别模型,包括其特点、使用方法以及在实际应用中的表现。我们将通过一个简单的示例来展示如何使用 FunASR 将语音转换为文字,并探讨其在语音识别领域的应用前景。 一、引言 随着人工智能技术的不断发展&am…

达梦数据库系列—19. 动态增加实时备库

目录 动态增加实时备库 1、数据准备 2 、配置新备库 2.1配置 dm.ini 2.2配置 dmmal.ini 2.3 配置 dmarch.ini 2.4 配置 dmwatcher.ini 2.5 启动备库 2.6 设置 OGUID 2.7 修改数据库模式 3、 动态添加 MAL 配置 4、 动态添加归档配置 5、 修改监视器 dmmonitor.ini…

软考初级网络管理员__网站单选题

1.以下关于服务器端脚本的说法中,正确的是()。 Script 编写 只能采用VBScript 编写 浏览器不能解释执行 由服务器发送到客户端,客户端负责运行 2.站点首页最常用的文件名是()。 index.html homepage.html resource.html mainfrm.html 3.在HTML…

Vatee万腾平台:引领行业变革,创新未来

在当今这个快速变化的时代,科技的力量正在以前所未有的速度推动着行业的变革。Vatee万腾平台,以其独特的视角和前瞻性的布局,正引领着行业变革的浪潮,创新着未来的发展方向。 Vatee万腾平台是一家专注于科技研发和创新应用的领军企…

面试突击:ConcurrentHashMap 源码详解

本文已收录于:https://github.com/danmuking/all-in-one(持续更新) 前言 哈喽,大家好,我是 DanMu。这篇文章想和大家聊聊 ConcurrentHashMap 相关的知识点。严格来说,ConcurrentHashMap 属于java.lang.cur…

【电源拓扑】PFC

为什么开关电源中都有PFC电路 PFC电路就是功率矫正电路,目的是为了防止杂波对电网产生冲击 AC220V通过整流桥之后电压和电流的波形分析 PFC电路为什么选择是Boost升压电路 PFC电路为什么要把电压升高到400V 为了解决输入电压低于滤波电容电压这个矛盾&#xff0…

LDM-XRNY-102溜槽堵塞开关 JOSEF约瑟 接点容量:5A/380V

工作原理 当物料在溜槽中造成堵塞时,堆积的物料会给溜槽侧壁一个压力,从而推动LDM-XRNY-102溜槽堵塞开关的活动门向外或向内推移(根据具体设计而定)。 当活动门偏转一个设定的角度时,其控制开关会动作,发出…

基于Python的自动化测试框架-Pytest总结-第一弹基础

Pytest总结第一弹基础 入门知识点安装pytest运行pytest测试用例发现规则执行方式命令行执行参数 配置发现规则 如何编写测试Case基础案例断言语句的使用pytest.fail() 和 Exceptions自定义断言函数异常测试测试类形式 pytest的Fixture使用Fixture入门案例使用fixture的Setup、T…

[A133]全志u-boot中的I2C驱动分析

[A133]全志u-boot中的I2C驱动分析 hongxi.zhu 2024-6-27 一、IIC标准读写时序 IIC是高位(MSB)先传输 二、代码流程 2.1主机写数据 brandy/brandy-2.0/u-boot-2018/drivers/i2c/sunxi_i2c.c static int sunxi_i2c_write(struct i2c_adapter *adap, uint8_t chip,uint32_t addr…

深入解析 androidx.databinding.BaseObservable

在现代 Android 开发中,数据绑定 (Data Binding) 是一个重要的技术,它简化了 UI 和数据之间的交互。在数据绑定框架中,androidx.databinding.BaseObservable 是一个关键类,用于实现可观察的数据模型。本文将详细介绍 BaseObservab…

Centos7安装Minio笔记

一、Minio概述 Minio是一款开源的对象存储服务器,可以运行在多种操作系统上,包括Linux、Windows和MacOS等。提供一种简单、可扩展、高可用的对象存储解决方案,支持多种数据格式,包括对象、块和文件等。Minio是一款强大、灵活、可…

基于若依(ruoyi-vue)的周报管理系统

喂wangyinlon 填报人页面 审批人 审批不通过,填报人需要重新填写.

智慧校园新气象:校园气象站

在数字化、智能化的浪潮下,传统校园正在迎来一场革命性的变革。在这场变革中,校园气象站以其独特的功能和魅力,成为推动校园气象科普教育、提升校园品质的重要力量。 一、校园气象站:智慧校园的“气象眼” 校园气象站&#xff0c…

宠物医院管理系统-计算机毕业设计源码07221

目 录 1 绪论 1.1 选题背景和意义 1.2国内外研究现状 1.3论文结构与章节安排 2 宠物医院管理系统系统分析 2.1 可行性分析 2.1.1技术可行性分析 2.1.2 操作可行性分析 2.1.3 法律可行性分析 2.2 系统功能分析 2.2.1 功能性分析 2.2.2 非功能性分析 2.3 系统用例分…

Ansible 最佳实践:现代 IT 运维的利器

Ansible 最佳实践:现代 IT 运维的利器 Ansible 是一种开源的 IT 自动化工具,通过 SSH 协议实现远程节点和管理节点之间的通信,适用于配置管理、应用程序部署、任务自动化等多个场景。本文将介绍 Ansible 的基本架构、主要功能以及最佳实践&a…

为什么80%的码农都做不了架构师?

文章目录 一、技术广度和深度的要求1.1 技术广度1.2 技术深度 二、全局视角和系统思维2.1 全局视角2.2 系统思维 三、沟通能力和团队合作3.1 沟通能力3.2 团队合作 四、业务理解和需求分析4.1 业务理解4.2 需求分析 五、持续学习和创新能力5.1 持续学习5.2 创新能力 六、总结 &…

鸿蒙:页面路由使用

页面路由使用步骤: 1.导入Router模块 2.使用路由功能,以pushUrl模式为例 3.接收参数、返回 4.此时的路由是不能使用的,需要到main_pages.json中进行注册