diff --git a/BUILD b/BUILD index 058d0218bc2..0fef3e59df7 100644 --- a/BUILD +++ b/BUILD @@ -1072,6 +1072,7 @@ tfrt_cc_library( ":tensor_shape_sync_opdefs_inc_gen", "@llvm-project//mlir:IR", "@llvm-project//mlir:SideEffectInterfaces", + "@llvm-project//mlir:Support", ], ) diff --git a/include/tfrt/tensor/opdefs/tensor.h b/include/tfrt/tensor/opdefs/tensor.h index c6940c60e57..e45c1862c02 100644 --- a/include/tfrt/tensor/opdefs/tensor.h +++ b/include/tfrt/tensor/opdefs/tensor.h @@ -26,8 +26,9 @@ using namespace mlir; namespace tfrt { -namespace t { +namespace tfrt_tensor { +// TODO (b/341154040): Pass in "tfrt_tensor" instead of "t". class TensorDialect : public Dialect { public: static StringRef getDialectNamespace() { return "t"; } @@ -45,10 +46,6 @@ class TensorType : public Type::TypeBase { static constexpr StringLiteral name = "tfrt.t.tensor"; }; -} // namespace t -namespace tfrt_tensor { -using TensorType = tfrt::t::TensorType; -using TensorDialect = tfrt::t::TensorDialect; } // namespace tfrt_tensor } // namespace tfrt diff --git a/include/tfrt/tensor/opdefs/tensor.td b/include/tfrt/tensor/opdefs/tensor.td index f3cdd0281e5..56a7a20005c 100644 --- a/include/tfrt/tensor/opdefs/tensor.td +++ b/include/tfrt/tensor/opdefs/tensor.td @@ -27,7 +27,7 @@ include "mlir/IR/OpBase.td" // TFRT tensor dialect. // TODO(b/170246041): Move `TensorType` under the TFRT dialect. def Tensor_Dialect : Dialect { - let name = "t"; + let name = "tfrt_tensor"; let description = [{ The TFRT tensor dialect. @@ -42,10 +42,10 @@ def Tensor_Dialect : Dialect { // Type definitions //===----------------------------------------------------------------------===// def TensorType : DialectType()">, "!t.tensor type">, + CPred<"$_self.isa<::tfrt::tfrt_tensor::TensorType>()">, "!tfrt_tensor.tensor type">, BuildableType<"$_builder.getType<::tfrt::tfrt_tensor::TensorType>()"> { let description = [{ - `!t.tensor type` represents a generic tfrt tensor. + `!trft_tensor.tensor type` represents a generic tfrt tensor. }]; } diff --git a/lib/tensor/opdefs/tensor.cc b/lib/tensor/opdefs/tensor.cc index 8ffa840eb8c..568152ededc 100644 --- a/lib/tensor/opdefs/tensor.cc +++ b/lib/tensor/opdefs/tensor.cc @@ -17,18 +17,23 @@ #include "tfrt/tensor/opdefs/tensor.h" #include "mlir/IR/Builders.h" -#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Dialect.h" #include "mlir/IR/DialectImplementation.h" +#include "mlir/IR/MLIRContext.h" #include "mlir/IR/OpImplementation.h" #include "mlir/IR/TypeUtilities.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/TypeID.h" namespace tfrt { -namespace t { +namespace tfrt_tensor { //===----------------------------------------------------------------------===// // TensorShape Dialect //===----------------------------------------------------------------------===// +// TODO (b/341154040): Pass in "tfrt_tensor" into the Dialect constructor +// instead of "t". TensorDialect::TensorDialect(MLIRContext *context) : Dialect(/*name=*/"t", context, TypeID::get()) { allowUnknownTypes(); @@ -60,10 +65,6 @@ void TensorDialect::printType(Type type, DialectAsmPrinter &os) const { llvm_unreachable("unexpected 'tensor' type kind"); } -} // namespace t -namespace tfrt_tensor { -using TensorType = tfrt::t::TensorType; -using TensorDialect = tfrt::t::TensorDialect; } // namespace tfrt_tensor } // namespace tfrt