Change op names to linear_8bit_act_xbit_weight
Differential Revision: D63347192

Pull Request resolved: #978
metascroy authored Sep 30, 2024
1 parent ae49375 commit b983f7d
Showing 28 changed files with 176 additions and 193 deletions.
31 changes: 24 additions & 7 deletions torchao/experimental/CMakeLists.txt
@@ -25,14 +25,14 @@ if(NOT TORCHAO_INCLUDE_DIRS)
endif()

if (NOT TORCHAO_OP_TARGET)
message(FATAL_ERROR "TORCHAO_OP_TARGET is not set. Set it to ATEN or EXECUTORCH.")
message(FATAL_ERROR "TORCHAO_OP_TARGET is not set. Set it to aten or executorch.")
endif()

if (NOT TORCHAO_PARALLEL_BACKEND)
if (TORCHAO_OP_TARGET STREQUAL "ATEN")
set(TORCHAO_PARALLEL_BACKEND "ATEN_OPENMP")
elseif(TORCHAO_OP_TARGET STREQUAL "EXECUTORCH")
set(TORCHAO_PARALLEL_BACKEND "EXECUTORCH")
if (TORCHAO_OP_TARGET STREQUAL "aten")
set(TORCHAO_PARALLEL_BACKEND "aten_openmp")
elseif(TORCHAO_OP_TARGET STREQUAL "executorch")
set(TORCHAO_PARALLEL_BACKEND "executorch")
else()
message(TORCHAO_PARALLEL_BACKEND "TORCHAO_PARALLEL_BACKEND is not set. Please set it directly or set TORCHAO_OP_TARGET to get a default.")
endif()
@@ -46,9 +46,26 @@ include(CMakePrintHelpers)
message("TORCHAO_INCLUDE_DIRS: ${TORCHAO_INCLUDE_DIRS}")
include_directories(${TORCHAO_INCLUDE_DIRS})

if(TORCHAO_OP_TARGET STREQUAL "aten")
add_library(torchao_ops_${TORCHAO_OP_TARGET} SHARED)
elseif(TORCHAO_OP_TARGET STREQUAL "executorch")
add_library(torchao_ops_${TORCHAO_OP_TARGET} STATIC)
else()
message(FATAL_ERROR "Unknown TORCHAO_OP_TARGET: ${TORCHAO_OP_TARGET}. Please choose one of: aten, executorch.")
endif()

if (CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64")
# Defines target torchao_kernels_aarch64
add_subdirectory(${TORCHAO_ROOT}/kernels/cpu/aarch64)
add_subdirectory(${TORCHAO_ROOT}/ops/linear)
add_subdirectory(${TORCHAO_ROOT}/ops/linear/linear_a8wxdq_op)
add_subdirectory(${TORCHAO_ROOT}/ops/linear_8bit_act_xbit_weight)

target_link_libraries(
torchao_ops_${TORCHAO_OP_TARGET} PRIVATE
torchao_ops_linear_8bit_act_xbit_weight_${TORCHAO_OP_TARGET}
)
endif()

install(
TARGETS torchao_ops_${TORCHAO_OP_TARGET}
DESTINATION lib
)
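
Since the aten flavor is now built as a single shared library (torchao_ops_aten) installed under lib, a host process can pick up its op registrations simply by loading that library at runtime. The sketch below is a minimal illustration using plain dlopen; the file name lib/libtorchao_ops_aten.so is an assumption derived from the target name and install rule above and may differ by platform or install prefix.

#include <cstdio>
#include <dlfcn.h>

int main() {
  // Assumed path, derived from add_library(torchao_ops_aten SHARED) plus
  // install(... DESTINATION lib) above; adjust for your platform and prefix.
  const char* path = "lib/libtorchao_ops_aten.so";
  void* handle = dlopen(path, RTLD_NOW | RTLD_GLOBAL);
  if (handle == nullptr) {
    std::fprintf(stderr, "dlopen failed: %s\n", dlerror());
    return 1;
  }
  // Static registrars inside the library run at load time, so any custom
  // ATen ops it defines become visible to the dispatcher in this process.
  return 0;
}
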
@@ -11,8 +11,8 @@ set(CMAKE_CXX_STANDARD 17)
set(CMAKE_BUILD_TYPE Release)
add_compile_options("-Wall" "-Werror")

set(TORCHAO_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../../..)
set(TORCHAO_INCLUDE_DIRS ${CMAKE_CURRENT_SOURCE_DIR}/../../../../..)
set(TORCHAO_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../..)
set(TORCHAO_INCLUDE_DIRS ${CMAKE_CURRENT_SOURCE_DIR}/../../../..)

include(FetchContent)
FetchContent_Declare(googlebenchmark
@@ -25,16 +25,20 @@ FetchContent_MakeAvailable(

include_directories(${TORCHAO_INCLUDE_DIRS})

set(TORCHAO_PARALLEL_BACKEND "OPENMP")
add_subdirectory(${TORCHAO_ROOT}/ops/linear ${CMAKE_CURRENT_BINARY_DIR}/torchao_ops_linear_${TORCHAO_PARALLEL_BACKEND})
set(TORCHAO_PARALLEL_BACKEND "openmp")

include(${TORCHAO_ROOT}/Utils.cmake)

add_subdirectory(${TORCHAO_ROOT}/kernels/cpu/aarch64 ${CMAKE_CURRENT_BINARY_DIR}/torchao_kernels_aarch64)

add_executable(benchmark_linear_operator benchmark_linear_operator.cpp)
add_executable(benchmark_linear_8bit_act_xbit_weight
benchmark_linear_8bit_act_xbit_weight.cpp
${TORCHAO_ROOT}/ops/linear_8bit_act_xbit_weight/linear_8bit_act_xbit_weight.cpp
)
target_link_torchao_parallel_backend(benchmark_linear_8bit_act_xbit_weight "${TORCHAO_PARALLEL_BACKEND}")
target_link_libraries(
benchmark_linear_operator
benchmark_linear_8bit_act_xbit_weight
PRIVATE
benchmark::benchmark
torchao_kernels_aarch64
torchao_ops_linear_${TORCHAO_PARALLEL_BACKEND}
)
target_link_torchao_parallel_backend(benchmark_linear_operator "${TORCHAO_PARALLEL_BACKEND}")
@@ -7,13 +7,12 @@
#include <benchmark/benchmark.h>
#include <torchao/experimental/kernels/cpu/aarch64/linear/linear.h>
#include <torchao/experimental/kernels/cpu/aarch64/tests/test_utils.h>
#include <torchao/experimental/ops/linear/channelwise_8bit_activation_groupwise_lowbit_weight.h>
#include <torchao/experimental/ops/linear_8bit_act_xbit_weight/linear_8bit_act_xbit_weight.h>
#include <torchao/experimental/ops/memory.h>
#include <torchao/experimental/ops/parallel.h>
#include <vector>

using namespace torchao::ops::linear::
channelwise_8bit_activation_groupwise_lowbit_weight;
using namespace torchao::ops::linear_8bit_act_xbit_weight;

template <int weight_nbit, bool has_weight_zeros, bool has_bias, bool has_clamp>
UKernelConfig get_ukernel_config() {
@@ -40,8 +40,7 @@ UKernelConfig get_ukernel_config() {
}

template <int weight_nbit, bool has_weight_zeros, bool has_bias, bool has_clamp>
static void channelwise_8bit_activation_groupwise_lowbit_weight(
benchmark::State& state) {
static void linear_8bit_act_xbit_weight(benchmark::State& state) {
int m = state.range(0);
int n = state.range(1);
int k = state.range(2);
@@ -150,19 +148,20 @@ static void channelwise_8bit_activation_groupwise_lowbit_weight(
} \
}

#define BENCHMARK_CHANNELWISE_8BIT_ACTIVATION_GROUPWISE_LOWBIT_WEIGHT( \
weight_nbit) \
BENCHMARK(channelwise_8bit_activation_groupwise_lowbit_weight< \
weight_nbit, \
false /*has_weight_zeros*/, \
false /*has_bias*/, \
false /*has_clamp*/>) \
->ArgsProduct(BENCHMARK_PARAMS) \
->ArgNames( \
#define BENCHMARK_LINEAR_8BIT_ACT_XBIT_WEIGHT(weight_nbit) \
BENCHMARK(linear_8bit_act_xbit_weight< \
weight_nbit, \
false /*has_weight_zeros*/, \
false /*has_bias*/, \
false /*has_clamp*/>) \
->ArgsProduct(BENCHMARK_PARAMS) \
->ArgNames( \
{"m", "n", "k", "group_size", "num_threads", "num_test_cases"});

BENCHMARK_CHANNELWISE_8BIT_ACTIVATION_GROUPWISE_LOWBIT_WEIGHT(3);
BENCHMARK_CHANNELWISE_8BIT_ACTIVATION_GROUPWISE_LOWBIT_WEIGHT(4);
BENCHMARK_LINEAR_8BIT_ACT_XBIT_WEIGHT(2);
BENCHMARK_LINEAR_8BIT_ACT_XBIT_WEIGHT(3);
BENCHMARK_LINEAR_8BIT_ACT_XBIT_WEIGHT(4);
BENCHMARK_LINEAR_8BIT_ACT_XBIT_WEIGHT(5);

// Run the benchmark
BENCHMARK_MAIN();
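
For readers unfamiliar with the registration pattern above: each BENCHMARK_LINEAR_8BIT_ACT_XBIT_WEIGHT(n) call instantiates the templated benchmark for one weight bit-width and registers it over a cartesian product of sizes. The standalone sketch below shows the same google-benchmark mechanics (templated benchmark function, registration macro, ArgsProduct/ArgNames) with a trivial payload; the names and sizes are illustrative and not part of the torchao API.

#include <benchmark/benchmark.h>
#include <vector>

// Toy benchmark body; the real benchmark runs the packed low-bit linear kernel here.
template <int weight_nbit>
static void example_linear_benchmark(benchmark::State& state) {
  const int m = static_cast<int>(state.range(0));
  const int n = static_cast<int>(state.range(1));
  std::vector<float> output(static_cast<size_t>(m) * n, 0.0f);
  for (auto _ : state) {
    benchmark::DoNotOptimize(output.data());  // stand-in for the kernel call
  }
}

// One registration per bit-width, mirroring the macro pattern above.
#define REGISTER_EXAMPLE_BENCHMARK(weight_nbit)     \
  BENCHMARK(example_linear_benchmark<weight_nbit>)  \
      ->ArgsProduct({{1, 8}, {4096}})               \
      ->ArgNames({"m", "n"});

REGISTER_EXAMPLE_BENCHMARK(3);
REGISTER_EXAMPLE_BENCHMARK(4);

BENCHMARK_MAIN();
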
@@ -17,7 +17,4 @@ cmake -DTORCHAO_LIBRARIES=${TORCHAO_LIBRARIES} \
cmake --build ${CMAKE_OUT}

# Run
case "$1" in
linear_operator) ${CMAKE_OUT}/benchmark_linear_operator; ;;
*) echo "Unknown benchmark: $1. Please specify one of: linear_operator."; exit 1; ;;
esac
${CMAKE_OUT}/benchmark_linear_8bit_act_xbit_weight
17 changes: 0 additions & 17 deletions torchao/experimental/ops/linear/CMakeLists.txt

This file was deleted.

45 changes: 0 additions & 45 deletions torchao/experimental/ops/linear/linear_a8wxdq_op/CMakeLists.txt

This file was deleted.

@@ -0,0 +1,44 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

cmake_minimum_required(VERSION 3.19)

include(${TORCHAO_ROOT}/Utils.cmake)


if(TORCHAO_OP_TARGET STREQUAL "aten")
message(STATUS "Building with TORCHAO_OP_TARGET=aten")
find_package(Torch REQUIRED)
add_library(torchao_ops_linear_8bit_act_xbit_weight_${TORCHAO_OP_TARGET} OBJECT
linear_8bit_act_xbit_weight.cpp
op_linear_8bit_act_xbit_weight_aten.cpp
)
target_link_torchao_parallel_backend(torchao_ops_linear_8bit_act_xbit_weight_${TORCHAO_OP_TARGET} "${TORCHAO_PARALLEL_BACKEND}")
target_link_libraries(torchao_ops_linear_8bit_act_xbit_weight_${TORCHAO_OP_TARGET} PRIVATE torchao_kernels_aarch64)
target_include_directories(torchao_ops_linear_8bit_act_xbit_weight_${TORCHAO_OP_TARGET} PRIVATE "${TORCH_INCLUDE_DIRS}")
target_link_libraries(torchao_ops_linear_8bit_act_xbit_weight_${TORCHAO_OP_TARGET} PRIVATE "${TORCH_LIBRARIES}")
target_compile_definitions(torchao_ops_linear_8bit_act_xbit_weight_${TORCHAO_OP_TARGET} PRIVATE USE_ATEN=1)
elseif(TORCHAO_OP_TARGET STREQUAL "executorch")
message(STATUS "Building with TORCHAO_OP_TARGET=executorch")
add_library(torchao_ops_linear_8bit_act_xbit_weight_${TORCHAO_OP_TARGET} OBJECT
linear_8bit_act_xbit_weight.cpp
op_linear_8bit_act_xbit_weight_executorch/w2s.cpp
op_linear_8bit_act_xbit_weight_executorch/w2sz.cpp
op_linear_8bit_act_xbit_weight_executorch/w3s.cpp
op_linear_8bit_act_xbit_weight_executorch/w3sz.cpp
op_linear_8bit_act_xbit_weight_executorch/w4s.cpp
op_linear_8bit_act_xbit_weight_executorch/w4sz.cpp
op_linear_8bit_act_xbit_weight_executorch/w5s.cpp
op_linear_8bit_act_xbit_weight_executorch/w5sz.cpp
)
target_link_torchao_parallel_backend(torchao_ops_linear_8bit_act_xbit_weight_${TORCHAO_OP_TARGET} "${TORCHAO_PARALLEL_BACKEND}")
target_include_directories(torchao_ops_linear_8bit_act_xbit_weight_${TORCHAO_OP_TARGET} PRIVATE "${EXECUTORCH_INCLUDE_DIRS}")
target_compile_definitions(torchao_ops_linear_8bit_act_xbit_weight_${TORCHAO_OP_TARGET} PRIVATE USE_EXECUTORCH=1)
target_link_libraries(torchao_ops_linear_8bit_act_xbit_weight_${TORCHAO_OP_TARGET} PRIVATE "${EXECUTORCH_LIBRARIES}")
target_link_libraries(torchao_ops_linear_8bit_act_xbit_weight_${TORCHAO_OP_TARGET} PRIVATE torchao_kernels_aarch64)
else()
message(FATAL_ERROR "Unknown TORCHAO_OP_TARGET: ${TORCHAO_OP_TARGET}. Please choose one of: aten, executorch.")
endif()
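
The USE_ATEN=1 and USE_EXECUTORCH=1 compile definitions set above let the op sources select a registration path at compile time. As a rough, generic illustration of the ATen side only (the namespace, schema, and kernel below are placeholders, not the torchao operators), registering a custom op with the PyTorch dispatcher looks roughly like this:

#include <ATen/ATen.h>
#include <torch/library.h>

#if defined(USE_ATEN)

// Placeholder kernel; the real sources register the low-bit linear ops instead.
at::Tensor example_identity(const at::Tensor& x) {
  return x.clone();
}

TORCH_LIBRARY(example_ns, m) {
  m.def("example_identity(Tensor x) -> Tensor");
  m.impl("example_identity", &example_identity);
}

#endif  // USE_ATEN
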
@@ -17,26 +17,29 @@ set(TORCHAO_INCLUDE_DIRS ${CMAKE_CURRENT_SOURCE_DIR}/../../../../..)

include_directories(${TORCHAO_INCLUDE_DIRS})

set(TORCHAO_PARALLEL_BACKEND "OPENMP")
add_subdirectory(${TORCHAO_ROOT}/ops/linear ${CMAKE_CURRENT_BINARY_DIR}/torchao_ops_linear_${TORCHAO_PARALLEL_BACKEND})
set(TORCHAO_PARALLEL_BACKEND "openmp")
add_subdirectory(${TORCHAO_ROOT}/kernels/cpu/aarch64 ${CMAKE_CURRENT_BINARY_DIR}/torchao_kernels_aarch64)

include(${TORCHAO_ROOT}/Utils.cmake)

add_executable(separate_function_wrappers separate_function_wrappers.cpp)
add_executable(separate_function_wrappers
separate_function_wrappers.cpp
${TORCHAO_ROOT}/ops/linear_8bit_act_xbit_weight/linear_8bit_act_xbit_weight.cpp
)
target_link_libraries(
separate_function_wrappers
PRIVATE
torchao_kernels_aarch64
torchao_ops_linear_${TORCHAO_PARALLEL_BACKEND}
)
target_link_torchao_parallel_backend(separate_function_wrappers "${TORCHAO_PARALLEL_BACKEND}")

add_executable(stateful_class_wrapper stateful_class_wrapper.cpp)
add_executable(stateful_class_wrapper
stateful_class_wrapper.cpp
${TORCHAO_ROOT}/ops/linear_8bit_act_xbit_weight/linear_8bit_act_xbit_weight.cpp
)
target_link_libraries(
stateful_class_wrapper
PRIVATE
torchao_kernels_aarch64
torchao_ops_linear_${TORCHAO_PARALLEL_BACKEND}
)
target_link_torchao_parallel_backend(stateful_class_wrapper "${TORCHAO_PARALLEL_BACKEND}")
@@ -5,16 +5,15 @@
// LICENSE file in the root directory of this source tree.

#pragma once
#include <torchao/experimental/ops/linear/channelwise_8bit_activation_groupwise_lowbit_weight.h>
#include <torchao/experimental/ops/linear_8bit_act_xbit_weight/linear_8bit_act_xbit_weight.h>
#include <torchao/experimental/ops/macro.h>
#include <torchao/experimental/ops/memory.h>
#include <cassert>
#include <optional>

namespace torchao::ops::linear::
channelwise_8bit_activation_groupwise_lowbit_weight {
namespace torchao::ops::linear_8bit_act_xbit_weight {

class Channelwise8BitActivationGroupwiseLowbitWeightLinearOperator {
class Linear8BitActXBitWeightOperator {
private:
torchao::aligned_byte_ptr packed_weight_data_{nullptr, nullptr};
int packed_weight_data_size_{0};
@@ -40,7 +39,7 @@ class Channelwise8BitActivationGroupwiseLowbitWeightLinearOperator {
LinearTileSchedulingPolicy linear_scheduling_policy_;

public:
Channelwise8BitActivationGroupwiseLowbitWeightLinearOperator(
Linear8BitActXBitWeightOperator(
UKernelConfig ukernel_config,
int n,
int k,
@@ -195,4 +194,4 @@ class Channelwise8BitActivationGroupwiseLowbitWeightLinearOperator {
}
};
} // namespace
// torchao::ops::linear::channelwise_8bit_activation_groupwise_lowbit_weight
// torchao::ops::linear_8bit_act_xbit_weight
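
The renamed Linear8BitActXBitWeightOperator appears to pack the weights once up front and hold the packed buffer (plus a preallocated activation workspace) across calls. The toy sketch below illustrates that stateful-wrapper pattern in isolation; every name, and the trivial "packing", is hypothetical and stands in for the real torchao kernels.

#include <cstdio>
#include <utility>
#include <vector>

// Toy stand-in for a packed-weight linear operator: pack once, run many times.
class ToyPackedLinear {
 public:
  ToyPackedLinear(std::vector<float> weights, int n, int k)
      : packed_(std::move(weights)), n_(n), k_(k) {}  // "packing" is just a move here

  // Reuses the packed weights; only the activations change per call.
  void run(const std::vector<float>& activations, std::vector<float>& out, int m) const {
    for (int i = 0; i < m; ++i) {
      for (int j = 0; j < n_; ++j) {
        float acc = 0.0f;
        for (int p = 0; p < k_; ++p) {
          acc += activations[i * k_ + p] * packed_[j * k_ + p];
        }
        out[i * n_ + j] = acc;
      }
    }
  }

 private:
  std::vector<float> packed_;
  int n_;
  int k_;
};

int main() {
  const int m = 2, n = 3, k = 4;
  ToyPackedLinear op(std::vector<float>(n * k, 1.0f), n, k);  // pack once at setup
  std::vector<float> activations(m * k, 1.0f), out(m * n, 0.0f);
  op.run(activations, out, m);                                // reuse per inference
  std::printf("out[0] = %f\n", out[0]);                       // prints 4.000000
  return 0;
}
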
@@ -6,7 +6,7 @@

#include <torchao/experimental/kernels/cpu/aarch64/linear/linear.h>
#include <torchao/experimental/kernels/cpu/aarch64/tests/test_utils.h>
#include <torchao/experimental/ops/linear/channelwise_8bit_activation_groupwise_lowbit_weight.h>
#include <torchao/experimental/ops/linear_8bit_act_xbit_weight/linear_8bit_act_xbit_weight.h>
#include <torchao/experimental/ops/memory.h>
#include <torchao/experimental/ops/parallel.h>
#include <iostream>
@@ -22,8 +22,7 @@
// one stateful class, but not all surfaces support this (see
// examples/stateful_class_wrapper.cpp for an example of this).

namespace torchao::ops::linear::
channelwise_8bit_activation_groupwise_lowbit_weight {
namespace torchao::ops::linear_8bit_act_xbit_weight {

template <int weight_nbit, bool has_weight_zeros, bool has_bias, bool has_clamp>
UKernelConfig get_ukernel_config() {
@@ -141,11 +140,10 @@ void linear_operator(
}

} // namespace
// torchao::ops::linear::channelwise_8bit_activation_groupwise_lowbit_weight
// torchao::ops::linear_8bit_act_xbit_weight

int main() {
using namespace torchao::ops::linear::
channelwise_8bit_activation_groupwise_lowbit_weight;
using namespace torchao::ops::linear_8bit_act_xbit_weight;

torchao::set_num_threads(8);
std::cout << "Using " << torchao::get_num_threads() << " threads."
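
This example file splits the operation into two steps, packing the weights first and running the linear kernel afterwards, for surfaces that cannot keep a stateful operator alive. A compact toy sketch of that two-step shape, with hypothetical names unrelated to the real torchao functions:

#include <cstdio>
#include <vector>

// Step 1: pack the weights ahead of time (a plain copy here; the real code
// rearranges them into a kernel-friendly layout).
std::vector<float> toy_pack_weights(const std::vector<float>& weights) {
  return weights;
}

// Step 2: run the linear op against already-packed weights (single row here).
float toy_linear(const std::vector<float>& packed, const std::vector<float>& activations) {
  float acc = 0.0f;
  for (size_t i = 0; i < activations.size(); ++i) {
    acc += activations[i] * packed[i];
  }
  return acc;
}

int main() {
  auto packed = toy_pack_weights({2.0f, 2.0f, 2.0f});   // e.g. at export time
  float out = toy_linear(packed, {1.0f, 1.0f, 1.0f});   // e.g. at inference time
  std::printf("out = %f\n", out);                       // prints 6.000000
  return 0;
}
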
@@ -6,7 +6,7 @@

#include <torchao/experimental/kernels/cpu/aarch64/linear/linear.h>
#include <torchao/experimental/kernels/cpu/aarch64/tests/test_utils.h>
#include <torchao/experimental/ops/linear/examples/Channelwise8BitActivationGroupwiseLowbitWeightLinearOperator.h>
#include <torchao/experimental/ops/linear_8bit_act_xbit_weight/examples/Linear8BitActXBitWeightOperator.h>
#include <torchao/experimental/ops/parallel.h>
#include <iostream>
#include <vector>
@@ -22,8 +22,7 @@
// examples/separate_function_wrappers.cpp for an example of how to split the
// operations into two steps.

using namespace torchao::ops::linear::
channelwise_8bit_activation_groupwise_lowbit_weight;
using namespace torchao::ops::linear_8bit_act_xbit_weight;

template <int weight_nbit, bool has_weight_zeros, bool has_bias, bool has_clamp>
UKernelConfig get_ukernel_config() {
@@ -81,7 +80,7 @@ int main() {
get_ukernel_config<weight_nbit, has_weight_zeros, has_bias, has_clamp>();

auto linear_operator =
Channelwise8BitActivationGroupwiseLowbitWeightLinearOperator(
Linear8BitActXBitWeightOperator(
ukernel_config,
n,
k,
