VScode 里面使用 python 去直接调用 CUDA

上一个 帖子主要分享了如何 去将 C++ 程序 打包成一个package。 我们最后的 目的实际上是想把 CUDA 的程序 打包成 一个 Package , C++ 程序只是起到了桥梁的作用

首先:CUDA 程序 和 C++ 的程序一样, 都有一个 .cu 的源文件和 一个 .h 的头文件

我们的文件 包含 Cpp 文件组成,负责当作 CUDA 和 Python 的桥梁。 还有 对应的 CUDA 的源代码文件和 头文件。将这个cpp 文件命名成 ext.cpp.

#include <torch/extension.h>
#include "interpolation_kernel.h"  ## CUDA 的头文件

PYBIND11_MODULE(TORCH_EXTENSION_NAME,m){
    m.def("trilinear_interpolation",&trilinear_interpolation);
}

cpp_properities.json 配置文件

{
    "configurations": [
        {
            "name": "Linux",
            "includePath": [
                "${workspaceFolder}/**",
                "/home/smiao/anaconda3/envs/Gen_3DGS/lib/python3.8",
                "/home/smiao/anaconda3/envs/Gen_3DGS/lib/python3.8/site-packages/torch/include/",
                "/home/smiao/anaconda3/envs/Gen_3DGS/lib/python3.8/site-packages/torch/include/torch/csrc/api/include/"
            ],
            "defines": [],
            "compilerPath": "/usr/bin/gcc",
            "cStandard": "c17",
            "cppStandard": "gnu++14",
            "intelliSenseMode": "linux-gcc-x64"
        }
    ],
    "version": 4

CUDA 部分:

CUDA 的头文件 *** interpolation_kernel.h ***

#include <torch/extension.h>

#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x "must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)

torch::Tensor trilinear_interpolation(torch::Tensor feats, torch::Tensor point);

对应的 源代码 文件*** interpolation_kernel.cu ***

include 的 头文件 和源代码文件 尽量放在同一级的 目录

#include <torch/extension.h>
#include "interpolation_kernel.h"
template <typename scalar_t>
__global__ void trilinear_fw_kernel(
            const torch::PackedTensorAccessor<scalar_t, 3, torch::RestrictPtrTraits, size_t> feats,
            const torch::PackedTensorAccessor<scalar_t, 2, torch::RestrictPtrTraits, size_t> points,
            torch::PackedTensorAccessor<scalar_t, 2, torch::RestrictPtrTraits, size_t> feat_interp
    const int n = blockIdx.x * blockDim.x + threadIdx.x; // 每个 Thread 都有一个 独特的Id
    const int f = blockIdx.y * blockDim.y + threadIdx.y;

    // 不参与计算的 thread 之间返回即可
    if(n >= feats.size(0) && f >= feats.size(2))
        return;
    
    const scalar_t u = (points[n][0]+1) / 2;
    const scalar_t v = (points[n][1]+1) / 2;
    const scalar_t w = (points[n][2]+1) / 2;

    const scalar_t a = (1-v) * (1-w);
    const scalar_t b = (1-v) * w;
    const scalar_t c = v * (1-w);
    const scalar_t d = 1-a-b-c;

    feat_interp[n][f] = (1-u) * ( a*feats[n][0][f]+
                               b*feats[n][1][f]+
                               c*feats[n][2][f]+
                               d*feats[n][3][f]) + 
                            u*(a*feats[n][4][f]+
                               b*feats[n][5][f]+
                               c*feats[n][6][f]+
                               d*feats[n][7][f]);
}

// 编写启动函数 如下:
torch::Tensor trilinear_interpolation(torch::Tensor feats, torch::Tensor points){
    CHECK_CUDA(feats);
    CHECK_CUDA(points);
	const int N = feats.size(0), F = feats.size(2);
    torch::Tensor feat_interp = torch::zeros({N,F}, feats.options());

    // 定义 几个 Block 和 Thread 的数量 
    const dim3 threads(16,16);
    const dim3 blocks((N + threads.x-1)/threads.x,(F + threads.y-1)/threads.y);

    //  启动 Kernel 函数, Kernel 的 函数类型一定是 Void  不会包含任何返回类型
    AT_DISPATCH_FLOATING_TYPES(feats.type(), "trilinear_interpolation", 
    ([&] {
        trilinear_fw_kernel<scalar_t><<<blocks, threads>>>(
            feats.packed_accessor<scalar_t, 3, torch::RestrictPtrTraits, size_t>(),
            points.packed_accessor<scalar_t, 2, torch::RestrictPtrTraits, size_t>(),
            feat_interp.packed_accessor<scalar_t, 2, torch::RestrictPtrTraits, size_t>()
        );
    }));
    return feats;
}

配置文件 setup.py 部分:

配置文件的 包含 ** *.cpp 文件 和 *.cu 文件 **
其他的部分应该 尽量不去改变。

from setuptools import setup
from torch.utils.cpp_extension import CUDAExtension, BuildExtension
import os
import glob

os.path.dirname(os.path.abspath(__file__))

setup(
    name="cuda_tutorial",
    version='1.0',
    ext_modules=[
        CUDAExtension(
            name='cuda_tutorial',
            sources=["interpolation_kernel.cu","ext.cpp"], 
              extra_compile_args={'cxx': ['-O2'],
                                'nvcc': ['-O2']}
        )
    ],
    cmdclass={
        'build_ext': BuildExtension
    }
)

最后是安装

pip install .

编写 test.py 的 python 代码对于 cuda 代码进行调用:

if __name__ == '__main__':
    N,F = 65536,256
    feats = torch.rand(N,8,F,device='cuda')
    points = torch.rand(N,3,device='cuda')*2-1
	
	## C++ 没有 feat=feat 这样传递 参数的形式
    t1 = time.time() 
    out_cuda = cuda_tutorial.trilinear_interpolation(feats,points)
    print(f"cuda time {time.time()-t1} s")

    t2 = time.time() 
    out_py = trilinear_interpolation_py(feats=feats,points=points)
    print(f"cuda time {time.time()-t2} s")
    print(torch.allclose(out_cuda,out_py))
    print(out_cuda.shape)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/576573.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Python脚本实现PC端大麦网自动购票(Selenium自动化测试工具)

文章目录 Selenium 简介Selenium webdriver 文档chromedriver&#xff08;谷歌浏览器驱动&#xff09;chromedriver 下载配置环境变量 大麦网购票脚本网页 dom 元素 启用远程调试&#xff08;操作已打开的窗口&#xff09; Selenium 简介 Selenium 是一个用于自动化测试的工具…

目标检测与追踪AI算法模型及边缘计算智能分析网关V4的算法应用

目标检测与追踪是计算机视觉领域中的一个重要任务&#xff0c;主要用于识别图像或视频中的目标&#xff0c;并跟踪它们的运动轨迹。针对这一任务&#xff0c;有许多先进的AI算法模型&#xff0c;例如&#xff1a; YOLO&#xff08;You Only Look Once&#xff09;&#xff1a;…

linux 设备树-of_address_to_resource

实例分析-reg 属性解析(基于ranges属性) /{#address-cells <0x01>;#size-cells <0x01>;soc {compatible "simple-bus";#address-cells <0x01>;#size-cells <0x01>;ranges <0x7e000000 0x3f000000 0x1000000 0x40000000 0x40000000…

怎么把图片转换为二维码?3个步骤轻松制作图片二维码

图片的二维码是怎么做成的呢&#xff1f;现在很多场景下的二维码&#xff0c;用手机扫码可以展现出对应的图片信息。通过这种方式来传递图片对于用户体验与很好的效果&#xff0c;而且也方便制作者通过这种方式来快速完成图片信息的传递&#xff0c;与传统方式相比成本更低&…

客服专用的宝藏快捷回复软件

提起客服&#xff0c;很多人会觉得现在智能机器人的自动回复功能就可以将其替代。我始终不是这么认为的。人工客服始终能为一家店铺乃至一个企业起着非常关键重要的作用。今天就来给大家推荐一个客服专用的宝藏软件——客服宝聊天助手。感兴趣的话&#xff0c;可以发给你的客服…

面向对象开发技术(第三周)

回顾 上一堂课主要学习了面向对象编程与非面向对象编程&#xff08;面向功能、过程编程&#xff09;&#xff0c;本节课就重点来看看面向对象编程中的一个具体思想——抽象 面向对象编程的特性&#xff1a;1、封装性 2、继承性 3、多态性 封装&#xff1a;意味着提供服务接口…

异常风云:解码 Java 异常机制

哈喽&#xff0c;各位小伙伴们&#xff0c;你们好呀&#xff0c;我是喵手。 今天我要给大家分享一些自己日常学习到的一些知识点&#xff0c;并以文字的形式跟大家一起交流&#xff0c;互相学习&#xff0c;一个人虽可以走的更快&#xff0c;但一群人可以走的更远。 我是一名后…

【论文阅读】《Octopus v2: On-device language model for super agent》,端侧大模型的应用案例

今年LLM的发展趋势之一&#xff0c;就是端侧LLM快速发展&#xff0c;超级APP入口之争异常激烈。不过&#xff0c;端侧LLM如何应用&#xff0c;不知道细节就很难理解。正好&#xff0c;《Octopus v2: On-device language model for super agent》这篇文章可以解惑。 对比部署在…

JTAG访问xilinx FPGA的IDCODE

之前调试过xilinx的XVC&#xff08;Xilinx virtual cable&#xff09;&#xff0c;突然看到有人搞wifi-JTAG&#xff08;感兴趣可以参考https://github.com/kholia/xvc-esp8266&#xff09;&#xff0c;也挺有趣的。就突然想了解一下JTAG是如何运作的&#xff0c;例如器件识别&…

普通话水平测试用朗读作品60篇-(练习版)

普通话考试题型有读单音节字词、读多音节字词、朗读作品和命题说话。 具体分值如下&#xff1a; 1、读单音节字词100个&#xff0c;占10分&#xff1b;目的考查应试人普通话声母、韵母和声调的发音。 2、读双音节词语50个&#xff0c;占20分&#xff1b;目的是除了考查应试人声…

骨传导耳机怎么选?精心挑选热销排行前五的骨传导耳机推荐!

近几年&#xff0c;骨传导耳机作为新型蓝牙耳机款式&#xff0c;已经得到大家有效认可&#xff0c;可以说已经适用于日常中的各种场景中&#xff0c;比如运动场景&#xff0c;凭借舒适的佩戴体验和保护运动安全的特点深受到运动爱好者的欢迎&#xff0c;作为一个经验丰富的数码…

Linux网络—DNS域名解析服务

目录 一、BIND域名服务基础 1、DNS系统的作用及类型 DNS系统的作用 DNS系统类型 DNS域名解析工作原理&#xff1a; DNS域名解析查询方式&#xff1a; 2、BIND服务 二、使用BIND构建域名服务器 1、构建主、从域名服务器 1&#xff09;主服务器配置&#xff1a; 2&…

Windows主机入侵检测与防御内核技术深入解析

第2章 模块防御的设计思想 2.1 执行与模块执行 本章内容为介绍模块执行防御。在此我将先介绍“执行”分类&#xff0c;以及“模块执行”在“执行”中的位置和重要性。 2.1.1 初次执行 恶意代码&#xff08;或者行为&#xff09;要在被攻击的机器上执行起来&#xff0c;看起…

C语言----单链表的实现

前面向大家介绍了顺序表以及它的实现&#xff0c;今天我们再来向大家介绍链表中的单链表。 1.链表的概念和结构 1.1 链表的概念 链表是一种在物理结构上非连续&#xff0c;非顺序的一种存储结构。链表中的数据的逻辑结构是由链表中的指针链接起来的。 1.2 链表的结构 链表…

茴香豆:搭建你的RAG智能助理-笔记三

本次课程由书生浦语社区贡献者【北辰】老师讲解【茴香豆&#xff1a;搭建你的 RAG 智能助理】课程 课程视频&#xff1a;https://www.bilibili.com/video/BV1QA4m1F7t4/ 课程文档&#xff1a;Tutorial/huixiangdou/readme.md at camp2 InternLM/Tutorial GitHub 该课程&…

江苏开放大学2024年春《会计基础 050266》第二次任务:第二次过程性考核参考答案

电大搜题 多的用不完的题库&#xff0c;支持文字、图片搜题&#xff0c;包含国家开放大学、广东开放大学、超星等等多个平台题库&#xff0c;考试作业必备神器。 公众号 答案&#xff1a;更多答案&#xff0c;请关注【电大搜题】微信公众号 答案&#xff1a;更多答案&#…

记账本React案例(Redux管理状态)

文章目录 整体架构流程 环境搭建 创建项目 技术细节 一、别名路径配置 1.路径解析配置&#xff08;webpack&#xff09; &#xff0c;将/解析为src/ 2.路径联想配置&#xff08;vsCode&#xff09;&#xff0c;使用vscode编辑器时&#xff0c;自动联想出来src文件夹下的…

【java数据结构-优先级队列向下调整Topk问题,堆的常用的接口详解】

&#x1f308;个人主页&#xff1a;努力学编程’ ⛅个人推荐&#xff1a;基于java提供的ArrayList实现的扑克牌游戏 |C贪吃蛇详解 ⚡学好数据结构&#xff0c;刷题刻不容缓&#xff1a;点击一起刷题 &#x1f319;心灵鸡汤&#xff1a;总有人要赢&#xff0c;为什么不能是我呢 …

SCI一区级 | Matlab实现BES-CNN-GRU-Mutilhead-Attention多变量时间序列预测

SCI一区级 | Matlab实现BES-CNN-GRU-Mutilhead-Attention秃鹰算法优化卷积门控循环单元融合多头注意力机制多变量时间序列预测 目录 SCI一区级 | Matlab实现BES-CNN-GRU-Mutilhead-Attention秃鹰算法优化卷积门控循环单元融合多头注意力机制多变量时间序列预测预测效果基本介绍…

【C++杂货铺】多态

目录 &#x1f308;前言&#x1f308; &#x1f4c1;多态的概念 &#x1f4c1; 多态的定义及实现 &#x1f4c2; 多态的构成条件 &#x1f4c2; 虚函数 &#x1f4c2; 虚函数重写 &#x1f4c2; C11 override 和 final &#x1f4c2; 重载&#xff0c;覆盖&#xff08;重写…
最新文章