Skip to content

Commit

Permalink
Refactor code for improved readability by adjusting line breaks and f…
Browse files Browse the repository at this point in the history
…ormatting in layout and test files
  • Loading branch information
LeiWang1999 committed Feb 9, 2025
1 parent fd8f421 commit bf1fdf7
Show file tree
Hide file tree
Showing 6 changed files with 67 additions and 38 deletions.
7 changes: 5 additions & 2 deletions examples/plot_layout/fragment_mma_load_a.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,10 @@
from tvm.tir import IndexMap
from tilelang.intrinsics.utils import get_mma_micro_size

def make_mma_load_base_layout(dtype: str = "float16", matrix: Literal["A", "B"] = "A", transposed: bool = False) -> T.Fragment:

def make_mma_load_base_layout(dtype: str = "float16",
matrix: Literal["A", "B"] = "A",
transposed: bool = False) -> T.Fragment:
"""
Create a layout function for storing MMA results into a fragment buffer.
This layout is used in conjunction with `inverse_mma_store_layout` to
Expand Down Expand Up @@ -61,7 +64,6 @@ def make_mma_load_base_layout(dtype: str = "float16", matrix: Literal["A", "B"]

transform_func: Callable = transform_func_sr if is_sr_axis_order else transform_func_rs


micro_size_s, _, micro_size_r = get_mma_micro_size(dtype)

transform_func = transform_func
Expand All @@ -88,6 +90,7 @@ def forward_index(i: int, j: int) -> int:
)
return base_fragment


block_rows = 2
block_cols = 2
warp_rows = 4
Expand Down
25 changes: 18 additions & 7 deletions testing/python/primitives/test_tilelang_primitives_mma.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,13 +104,24 @@ def ref_program(A, B):


def test_gemm_f16f16f16_nt_ssr():
run_matmul_ssr(16, 16, 16, False, True,
"float16", "float16", "float16", 16, 16, 16, 0, num_threads=32)
run_matmul_ssr(128, 128, 128, False, True,
"float16", "float16", "float16", 32, 32, 32, 0, num_threads=64)
run_matmul_ssr(1024, 1024, 1024, False, True,
"float16", "float16", "float16", 128, 128, 32, 2, num_threads=128)

run_matmul_ssr(
16, 16, 16, False, True, "float16", "float16", "float16", 16, 16, 16, 0, num_threads=32)
run_matmul_ssr(
128, 128, 128, False, True, "float16", "float16", "float16", 32, 32, 32, 0, num_threads=64)
run_matmul_ssr(
1024,
1024,
1024,
False,
True,
"float16",
"float16",
"float16",
128,
128,
32,
2,
num_threads=128)


def matmul_rsr(
Expand Down
25 changes: 12 additions & 13 deletions tilelang/layout/fragment.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from tilelang.layout import Layout
from typing import List


@tvm._ffi.register_object("tl.Fragment")
class Fragment(Layout):
"""
Expand All @@ -22,14 +23,12 @@ class Fragment(Layout):
# Disable the linter warning about not calling super().__init__()
# because this object is created via TVM's FFI constructor mechanism.
# pylint: disable=super-init-not-called
def __init__(
self,
shape,
forward_fn=None,
forward_thread_fn=None,
replicate=1,
forward_index_fn=None
):
def __init__(self,
shape,
forward_fn=None,
forward_thread_fn=None,
replicate=1,
forward_index_fn=None):
"""
Initialize the Fragment with iteration variables and optional thread replication.
Expand Down Expand Up @@ -121,7 +120,10 @@ def get_thread_size(self):
"""
return _ffi_api.Fragment_thread_size(self)

def repeat(self, repeats, repeat_on_thread: bool = False, lower_dim_first: bool = True) -> "Fragment":
def repeat(self,
repeats,
repeat_on_thread: bool = False,
lower_dim_first: bool = True) -> "Fragment":
"""
Returns a new Fragment that repeats the iteration space a given number of times.
Expand Down Expand Up @@ -190,10 +192,7 @@ def map_forward_thread(self, indices: List[PrimExpr]) -> PrimExpr:
forward_thread = self.thread
# Construct an IndexMap to map the provided args into the final thread index
index_map = IndexMap(
initial_indices=forward_vars,
final_indices=[forward_thread],
inverse_index_map=None
)
initial_indices=forward_vars, final_indices=[forward_thread], inverse_index_map=None)
return index_map.map_indices(indices)

def __repr__(self):
Expand Down
11 changes: 6 additions & 5 deletions tilelang/layout/layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from tilelang import _ffi_api
from typing import List


# Register the Layout class as a TVM object under the name "tl.Layout"
@tvm._ffi.register_object("tl.Layout")
class Layout(Node):
Expand All @@ -25,7 +26,7 @@ def __init__(self, shape, forward_fn):
A function that maps index variables to their computed forward index.
"""
forward_vars = [] # List to store IterVars corresponding to each shape dimension

# Create an IterVar for each dimension in the shape
for idx, size in enumerate(shape):
# Define an IterVar over the range [0, size) with an associated variable name
Expand All @@ -34,10 +35,10 @@ def __init__(self, shape, forward_fn):

# Extract the variable references from the IterVars
vars = [iv.var for iv in forward_vars]

# Compute the forward index using the provided forward function
forward_index = forward_fn(*vars)

# Ensure forward_index is a list (to handle cases where a single expression is returned)
if isinstance(forward_index, PrimExpr):
forward_index = [forward_index]
Expand Down Expand Up @@ -106,10 +107,10 @@ def map_forward_index(self, indices: List[PrimExpr]) -> PrimExpr:
"""
# Retrieve the iteration variables used in the layout transformation
forward_vars = self.get_forward_vars()

# Retrieve the computed forward index expressions
forward_indexes = self.index

# Construct an IndexMap to map the input indices to the computed output indices
index_map = IndexMap(
initial_indices=forward_vars, # The original iteration variables
Expand Down
2 changes: 1 addition & 1 deletion tilelang/tools/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

from .plot_layout import plot_layout
from .plot_layout import plot_layout # noqa: F401
35 changes: 25 additions & 10 deletions tilelang/tools/plot_layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,12 @@

import tilelang.language as T

def plot_layout(layout: T.Layout, save_directory="./tmp", name: str = "layout", colormap: str = "RdPu", verbose: bool = False) -> None:

def plot_layout(layout: T.Layout,
save_directory="./tmp",
name: str = "layout",
colormap: str = "RdPu",
verbose: bool = False) -> None:
"""
Plot the layout of a buffer.
Expand Down Expand Up @@ -89,26 +94,36 @@ def plot_layout(layout: T.Layout, save_directory="./tmp", name: str = "layout",

color = colors[thread_id] # Select color based on thread ID
# Create a rectangle patch for visualization
rect = patches.Rectangle((j, i), 1, 1, linewidth=0.5,
edgecolor='black', facecolor=color)
rect = patches.Rectangle((j, i),
1,
1,
linewidth=0.5,
edgecolor='black',
facecolor=color)
ax.add_patch(rect) # Add the rectangle to the plot

# Add text annotations inside the rectangles
text = f"T{thread_id}\nL{local_id}"
ax.text(j + 0.5, i + 0.5, text, ha='center', va='center',
color='black', fontsize=font_size)
ax.text(
j + 0.5, i + 0.5, text, ha='center', va='center', color='black', fontsize=font_size)

# Add row labels to the left side of the plot
for i in range(nrows):
text = f"row {i}"
ax.text(-0.75, i + 0.5, text, ha='center', va='center',
color='black', fontsize=font_size)
ax.text(-0.75, i + 0.5, text, ha='center', va='center', color='black', fontsize=font_size)

# Add column labels at the top of the plot
for j in range(ncols):
text = f"col {j}"
ax.text(j + 0.5, -0.5, text, ha='center', va='center',
color='black', fontsize=font_size, rotation=45)
ax.text(
j + 0.5,
-0.5,
text,
ha='center',
va='center',
color='black',
fontsize=font_size,
rotation=45)

# Set the plot limits
ax.set_xlim(0, ncols)
Expand All @@ -124,7 +139,7 @@ def plot_layout(layout: T.Layout, save_directory="./tmp", name: str = "layout",

# Save the figure in multiple formats
plt.tight_layout()

# Save as PDF
pdf_path = tmp_directory / f"{name}.pdf"
plt.savefig(pdf_path, bbox_inches="tight")
Expand Down

0 comments on commit bf1fdf7

Please sign in to comment.