From 38ed80e8a918cf060704420b93d41beee598346a Mon Sep 17 00:00:00 2001 From: mcuiaws Date: Fri, 20 Dec 2024 14:55:46 -0800 Subject: [PATCH] When modifying IR node, make sure to not lose the read_only bit (#8505) (#8508) --- test/test_input_output_aliases.py | 26 ++++++++++++++++++++++++++ torch_xla/csrc/tensor.cpp | 6 +++++- torch_xla/csrc/xla_graph_executor.cpp | 6 +++++- 3 files changed, 36 insertions(+), 2 deletions(-) diff --git a/test/test_input_output_aliases.py b/test/test_input_output_aliases.py index df21a1ad5e3..6fe28babf72 100644 --- a/test/test_input_output_aliases.py +++ b/test/test_input_output_aliases.py @@ -185,6 +185,32 @@ def test_xm_save_no_aliasing(self): self.assertEqual(t2.item(), 3) + def test_device_data_cache_no_aliasing(self): + """ + Test that device data in DataCache are not aliased. + """ + xla_device = xm.xla_device() + + t0 = torch.tensor(42, device=xla_device) + # drops the read-only bit on t0's device_data + xm.mark_step() + + # cached value of 42 is donated + t0.add_(1) + xm.mark_step() + + # t1 get the cached device_data, which was donated + t1 = torch.tensor(42, device=xla_device) + xm.mark_step() + + t1.add_(1) + # XLA crashes here because parameter is donated buffer... + xm.mark_step() + + # ...if it doesn't crash, the value here would be 44. + self.assertEqual(t1.item(), 43) + + if __name__ == '__main__': test = unittest.main() sys.exit(0 if test.result.wasSuccessful() else 1) diff --git a/torch_xla/csrc/tensor.cpp b/torch_xla/csrc/tensor.cpp index 01306c53d38..324595ca6ca 100644 --- a/torch_xla/csrc/tensor.cpp +++ b/torch_xla/csrc/tensor.cpp @@ -386,7 +386,11 @@ torch::lazy::Value XLATensor::GetIrValue() const { // will still collapse them all into a single XLA parameter op). So call // which wants the XLA data will still find it, w/out having to fetch it // via a computation client from-server call. 
- AssignIrValue(CreateTensorNode(handle, /*read_only=*/false)); + auto* data_info = + static_cast<torch::lazy::LazyGraphExecutor::DeviceDataInfo*>( + handle->info()); + bool read_only = data_info != nullptr && data_info->read_only; + AssignIrValue(CreateTensorNode(handle, read_only)); // CreateTensorNode will set the data info of the tensor to the current // unique_id. Here the alias id needs to be updated so that input output // alias can correctly work on the xla's custom inplace operation. diff --git a/torch_xla/csrc/xla_graph_executor.cpp b/torch_xla/csrc/xla_graph_executor.cpp index 81cf0207029..a0aa6e7150d 100644 --- a/torch_xla/csrc/xla_graph_executor.cpp +++ b/torch_xla/csrc/xla_graph_executor.cpp @@ -659,9 +659,13 @@ XLAGraphExecutor::SyncTensorCollection XLAGraphExecutor::CollectSyncTensors( // XlaData from the DeviceData Node and reset the IR. We also want // to update XlaData's tensorID to make it match with the current // XLATensor. + auto* data_info = + static_cast<torch::lazy::LazyGraphExecutor::DeviceDataInfo*>( + device_data->data()->info()); + bool read_only = data_info != nullptr && data_info->read_only; tensors[i]->GetXlaData()->SetInfo( std::make_shared<torch::lazy::LazyGraphExecutor::DeviceDataInfo>( - tensors[i]->GetUniqueId(), /*=read_only=*/false)); + tensors[i]->GetUniqueId(), read_only)); } else { // Add only tensors which need to be synced. coll.hash = torch::lazy::HashCombine(coll.hash, ir_value.hash());