Skip to content

Commit

Permalink
Update mlir tests to reflect new namespace change
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 638391003
  • Loading branch information
sagyakwa authored and copybara-github committed May 29, 2024
1 parent edb0d2c commit 4662f75
Show file tree
Hide file tree
Showing 6 changed files with 28 additions and 31 deletions.
8 changes: 4 additions & 4 deletions backends/cpu/mlir_tests/mnist/btf_kernels.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -28,15 +28,15 @@ func.func @tensor_io() {
%one = tfrt.constant.i32 1
%two = tfrt.constant.i32 2

%t0 = "btf.read_dense_tensor.i32.2"(%path, %zero) : (!tfrt.string, i32) -> (!t.tensor)
%t0 = "btf.read_dense_tensor.i32.2"(%path, %zero) : (!tfrt.string, i32) -> (!tfrt_tensor.tensor)
// CHECK-NEXT: shape = [2, 2], values = [1, 2, 3, 4]
%c1 = tfrt_dht.print_tensor %t0, %c0

%t1 = "btf.read_dense_tensor.i32.1"(%path, %one) : (!tfrt.string, i32) -> (!t.tensor)
%t1 = "btf.read_dense_tensor.i32.1"(%path, %one) : (!tfrt.string, i32) -> (!tfrt_tensor.tensor)
// CHECK-NEXT: shape = [5], values = [0, 1, 2, 3, 4]
%c2 = tfrt_dht.print_tensor %t1, %c1

%t2 = "btf.read_dense_tensor.i32.1"(%path, %two) : (!tfrt.string, i32) -> (!t.tensor)
%t2 = "btf.read_dense_tensor.i32.1"(%path, %two) : (!tfrt.string, i32) -> (!tfrt_tensor.tensor)
// CHECK-NEXT: shape = [0], values = []
%c3 = tfrt_dht.print_tensor %t2, %c2

Expand All @@ -49,7 +49,7 @@ func.func @tensor_io_invalid_path() {
%path = "tfrt_test.get_string"() { value = "/tmp/invalid_path" } : () -> !tfrt.string
%zero = tfrt.constant.i32 0
// expected-error @+1 {{failed to open file /tmp/invalid_path for reading}}
%t0 = "btf.read_dense_tensor.i32.2"(%path, %zero) : (!tfrt.string, i32) -> (!t.tensor)
%t0 = "btf.read_dense_tensor.i32.2"(%path, %zero) : (!tfrt.string, i32) -> (!tfrt_tensor.tensor)

tfrt.return
}
20 changes: 10 additions & 10 deletions backends/cpu/mlir_tests/resnet/max_pool.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -26,20 +26,20 @@ func.func @test_max_pool_2d_f32_0() {

%input_index = tfrt.constant.i32 0
%input = "btf.read_dense_tensor.f32.4"(%path, %input_index)
: (!tfrt.string, i32) -> (!t.tensor)
: (!tfrt.string, i32) -> (!tfrt_tensor.tensor)

%expected_index = tfrt.constant.i32 1
%expected = "btf.read_dense_tensor.f32.4"(%path, %expected_index)
: (!tfrt.string, i32) -> (!t.tensor)
: (!tfrt.string, i32) -> (!tfrt_tensor.tensor)

%output = "tfrt_dht.create_uninitialized_tensor.f32.4"() { shape = [2 : i64, 1 : i64, 1 : i64, 6 : i64] }
: () -> !t.tensor
: () -> !tfrt_tensor.tensor
%ch1 = "tfrt_test.max_pooling_2d.f32"(%input, %output, %ch0)
{ padding = "valid", pool_size = [3 : i32, 3 : i32], strides = [3 : i32, 3 : i32] }
: (!t.tensor, !t.tensor, !tfrt.chain) -> !tfrt.chain
: (!tfrt_tensor.tensor, !tfrt_tensor.tensor, !tfrt.chain) -> !tfrt.chain

%cmp, %ch2 = "tfrt_dht.tensor_allclose.f32"(%expected, %output, %ch1)
: (!t.tensor, !t.tensor, !tfrt.chain) -> (i1, !tfrt.chain)
: (!tfrt_tensor.tensor, !tfrt_tensor.tensor, !tfrt.chain) -> (i1, !tfrt.chain)

// CHECK: int1 = 1
tfrt.print.i1 %cmp, %ch2
Expand All @@ -57,20 +57,20 @@ func.func @test_max_pool_2d_f32_1() {

%input_index = tfrt.constant.i32 2
%input = "btf.read_dense_tensor.f32.4"(%path, %input_index)
: (!tfrt.string, i32) -> (!t.tensor)
: (!tfrt.string, i32) -> (!tfrt_tensor.tensor)

%expected_index = tfrt.constant.i32 3
%expected = "btf.read_dense_tensor.f32.4"(%path, %expected_index)
: (!tfrt.string, i32) -> (!t.tensor)
: (!tfrt.string, i32) -> (!tfrt_tensor.tensor)

%output = "tfrt_dht.create_uninitialized_tensor.f32.4"() { shape = [2 : i64, 3 : i64, 3 : i64, 6 : i64] }
: () -> !t.tensor
: () -> !tfrt_tensor.tensor
%ch1 = "tfrt_test.max_pooling_2d.f32"(%input, %output, %ch0)
{ padding = "same", pool_size = [3 : i32, 3 : i32], strides = [2 : i32, 2 : i32] }
: (!t.tensor, !t.tensor, !tfrt.chain) -> !tfrt.chain
: (!tfrt_tensor.tensor, !tfrt_tensor.tensor, !tfrt.chain) -> !tfrt.chain

%cmp, %ch2 = "tfrt_dht.tensor_allclose.f32"(%expected, %output, %ch1)
: (!t.tensor, !t.tensor, !tfrt.chain) -> (i1, !tfrt.chain)
: (!tfrt_tensor.tensor, !tfrt_tensor.tensor, !tfrt.chain) -> (i1, !tfrt.chain)

// CHECK: int1 = 1
tfrt.print.i1 %cmp, %ch2
Expand Down
20 changes: 10 additions & 10 deletions backends/cpu/mlir_tests/resnet/resnet_tensor_kernels.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,16 @@ func.func @test_flatten_f32() {

%t1 = "tfrt_dht.create_uninitialized_tensor.f32.4"()
{ shape = [1 : i64, 2 : i64, 2 : i64, 1 : i64] }
: () -> !t.tensor
: () -> !tfrt_tensor.tensor
%ch1 = "tfrt_dht.set_tensor_with_constant_values.f32"(%t1, %ch0)
{ values = [1.0 : f32, 2.0 : f32, 3.0 : f32, 4.0 : f32] }
: (!t.tensor, !tfrt.chain) -> !tfrt.chain
: (!tfrt_tensor.tensor, !tfrt.chain) -> !tfrt.chain

%t2 = "tfrt_dht.create_uninitialized_tensor.f32.2"()
{ shape = [1 : i64, 4 : i64] }
: () -> !t.tensor
: () -> !tfrt_tensor.tensor
%ch2 = "tfrt_test.flatten.f32"(%t1, %t2, %ch1)
: (!t.tensor, !t.tensor, !tfrt.chain) -> !tfrt.chain
: (!tfrt_tensor.tensor, !tfrt_tensor.tensor, !tfrt.chain) -> !tfrt.chain

// CHECK: shape = [1, 4], values = [1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]
tfrt_dht.print_tensor %t2, %ch2
Expand All @@ -43,15 +43,15 @@ func.func @test_max_pool_2d_f32_padding_error() {

%input = "tfrt_dht.create_uninitialized_tensor.f32.4"()
{ shape = [2 : i64, 3 : i64, 3 : i64, 6 : i64] }
: () -> !t.tensor
: () -> !tfrt_tensor.tensor
%output = "tfrt_dht.create_uninitialized_tensor.f32.4"()
{ shape = [2 : i64, 1 : i64, 1 : i64, 6 : i64] }
: () -> !t.tensor
: () -> !tfrt_tensor.tensor

// expected-error @+1 {{padding 'invalid' is not recognized}}
"tfrt_test.max_pooling_2d.f32"(%input, %output, %ch0)
{ padding = "invalid", pool_size = [3 : i32, 3 : i32], strides = [3 : i32, 3 : i32] }
: (!t.tensor, !t.tensor, !tfrt.chain) -> !tfrt.chain
: (!tfrt_tensor.tensor, !tfrt_tensor.tensor, !tfrt.chain) -> !tfrt.chain

tfrt.return
}
Expand All @@ -62,15 +62,15 @@ func.func @test_max_pool_2d_f32_shape_error() {

%input = "tfrt_dht.create_uninitialized_tensor.f32.4"()
{ shape = [2 : i64, 3 : i64, 3 : i64, 6 : i64] }
: () -> !t.tensor
: () -> !tfrt_tensor.tensor
%output = "tfrt_dht.create_uninitialized_tensor.f32.4"()
{ shape = [2 : i64, 2 : i64, 1 : i64, 6 : i64] }
: () -> !t.tensor
: () -> !tfrt_tensor.tensor

// expected-error @+1 {{output shape [2, 2, 1, 6] does not match the expected output shape [2, 1, 1, 6]}}
"tfrt_test.max_pooling_2d.f32"(%input, %output, %ch0)
{ padding = "valid", pool_size = [3 : i32, 3 : i32], strides = [3 : i32, 3 : i32] }
: (!t.tensor, !t.tensor, !tfrt.chain) -> !tfrt.chain
: (!tfrt_tensor.tensor, !tfrt_tensor.tensor, !tfrt.chain) -> !tfrt.chain

tfrt.return
}
5 changes: 2 additions & 3 deletions include/tfrt/tensor/opdefs/tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,9 @@ using namespace mlir;
namespace tfrt {
namespace tfrt_tensor {

// TODO (b/341154040): Pass in "tfrt_tensor" instead of "t".
class TensorDialect : public Dialect {
public:
static StringRef getDialectNamespace() { return "t"; }
static StringRef getDialectNamespace() { return "tfrt_tensor"; }
explicit TensorDialect(MLIRContext *context);

Type parseType(DialectAsmParser &parser) const override;
Expand All @@ -43,7 +42,7 @@ class TensorDialect : public Dialect {
class TensorType : public Type::TypeBase<TensorType, Type, TypeStorage> {
public:
using Base::Base;
static constexpr StringLiteral name = "tfrt.t.tensor";
static constexpr StringLiteral name = "tfrt.tfrt_tensor.tensor";
};

} // namespace tfrt_tensor
Expand Down
4 changes: 1 addition & 3 deletions lib/tensor/opdefs/tensor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,8 @@ namespace tfrt_tensor {
// TensorShape Dialect
//===----------------------------------------------------------------------===//

// TODO (b/341154040): Pass in "tfrt_tensor" into the Dialect constructor
// instead of "t".
TensorDialect::TensorDialect(MLIRContext *context)
: Dialect(/*name=*/"t", context, TypeID::get<TensorDialect>()) {
: Dialect(/*name=*/"tfrt_tensor", context, TypeID::get<TensorDialect>()) {
allowUnknownTypes();
addTypes<TensorType>();
addOperations<
Expand Down
2 changes: 1 addition & 1 deletion mlir_tests/tensor/string_host_tensor.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ func.func @basic() {
%c0 = tfrt.new.chain

%a = "tfrt_sht.create_tensor"()
{shape = [2], values = ["string", "tensor"]} : () -> !t.tensor
{shape = [2], values = ["string", "tensor"]} : () -> !tfrt_tensor.tensor

// CHECK: shape = [2], values = ["string", "tensor"]
%c1 = tfrt_dht.print_tensor %a, %c0
Expand Down

0 comments on commit 4662f75

Please sign in to comment.