Batched Conv2D

Link: https://github.com/cmeraki/vit.triton/blob/main/vit/kernels/conv2d.py

Author: Romit Jain

Tags: Conv

Description: Implements a batched Conv2D kernel for input images of shape (B, 3, H, W) and a convolution kernel of size (K, K). Currently, only a stride of (K, K) with no padding is supported, so the kernel is applied to non-overlapping K x K patches.
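To make the stride and padding restriction concrete, here is a minimal sketch of the output-size arithmetic using the standard convolution formula with stride K and zero padding. The helper name `conv2d_output_shape` is hypothetical, and the (B, kernels, H_out, W_out) ordering follows the PyTorch convention; whether the Triton kernel returns exactly this layout is an assumption.

```python
def conv2d_output_shape(batch, height, width, kernels, k):
    """Output shape of a conv with a (k, k) kernel, stride (k, k), no padding.

    Hypothetical helper for illustration; not part of vit.kernels.conv2d.
    """
    if k <= 0 or height < k or width < k:
        raise ValueError("kernel must be positive and fit inside the image")
    # Standard formula: floor((size - kernel) / stride) + 1, with stride == kernel
    out_h = (height - k) // k + 1
    out_w = (width - k) // k + 1
    # Assumed layout: (B, kernels, H_out, W_out), as in PyTorch
    return (batch, kernels, out_h, out_w)


print(conv2d_output_shape(4, 224, 224, 512, 16))
```

With the sizes used in the usage example, a 224 x 224 image and a 16 x 16 kernel give a 14 x 14 grid of non-overlapping patches per output channel.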

Minimal Usage:

import torch
from vit.kernels.conv2d import conv2d_triton

device = 'cuda:0'
dtype = torch.float32

batch_size=4
height=224
width=224
channels=3

kernels=512
kernel_height=16
kernel_width=16

# Integer values in [0, 10), cast to float32 and moved to the GPU
input = torch.randint(0, 10, (batch_size, channels, height, width)).to(device, dtype)
kernel = torch.randint(0, 10, (kernels, channels, kernel_height, kernel_width)).to(device, dtype)
bias = torch.randn(kernels).to(device, dtype)

# Stride equals the kernel size and no padding is applied
y_triton = conv2d_triton(input, kernel, bias)
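One way to sanity-check the result is to compare it with PyTorch's built-in convolution configured the same way (stride equal to the kernel size, no padding). The sketch below is a reference only; `conv2d_reference` is a hypothetical helper name, and it runs on CPU so it does not depend on Triton.

```python
import torch
import torch.nn.functional as F


def conv2d_reference(x, weight, bias):
    """Reference conv: stride equals the kernel size, no padding.

    Hypothetical helper for comparison; not part of vit.kernels.conv2d.
    """
    k_h, k_w = weight.shape[-2:]
    return F.conv2d(x, weight, bias, stride=(k_h, k_w), padding=0)


# Small CPU example: 2 images, 3 channels, 32x32 pixels, 8 filters of 16x16
x = torch.randn(2, 3, 32, 32)
w = torch.randn(8, 3, 16, 16)
b = torch.randn(8)
print(conv2d_reference(x, w, b).shape)
```

Assuming `conv2d_triton` returns its output in the same (B, kernels, H_out, W_out) layout, a check such as `torch.allclose(y_triton, conv2d_reference(input, kernel, bias), rtol=1e-3, atol=1e-3)` could be used with the tensors above; if the layouts differ, the reference would need to be reshaped or permuted first.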

Triton Version: v2.3.0

Other Notes:
This is not yet the fastest Conv2D implementation; further performance work is planned.

Id in triton index: 0010