Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve documentation for virtual functions #551

Merged
merged 20 commits into from
Oct 1, 2024
Merged
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
250 changes: 145 additions & 105 deletions docs/source/ProgrammingGuide/Kokkos-and-Virtual-Functions.md
Original file line number Diff line number Diff line change
@@ -1,184 +1,224 @@
# 14. Kokkos and Virtual Functions

```{warning}
Using virtual functions in parallel regions is not a good idea in general. It often degrades performance, requires specific code for a correct execution on GPU, and is not portable on every backend. We recommend to use a different approach whenever possible.
```

Due to oddities of GPU programming, the use of virtual functions in Kokkos parallel regions can be complicated. This document describes the problems you're likely to face, where they come from, and how to work around them.

Please note that virtual functions can be executed on the device for the following backends:

- Cuda; and
- HIP (with a limitation, as explained at the end).

Especially, SYCL 2020 [cannot handle virtual functions](https://registry.khronos.org/SYCL/specs/sycl-2020/html/sycl-2020.html#_overview).

## The Problem

In GPU programming, you might have run into the bug of calling a host function from the device. A similar thing can happen for subtle reasons in code using virtual functions. Consider the following code
In GPU programming, you might have run into the bug of calling a host function from the device. A similar thing can happen for subtle reasons in code using virtual functions. Consider the following serial code

```c++
class Base {
public:
void Foo() {}

virtual void Bar() {}
};

class Derived : public Base {
/** fields */
public:
KOKKOS_FUNCTION virtual void Bar(){
// TODO: implement all of physics
}
void Bar() override {}
pzehner marked this conversation as resolved.
Show resolved Hide resolved
};

Base* hostClassInstance = new Derived();
Base* deviceClassInstance;
cudaMalloc((void**)&deviceClassInstance, sizeof(Derived));
cudaMemcpy(deviceClassInstance, hostClassInstance, sizeof(Derived), cudaMemcpyHostToDevice);

Kokkos::parallel_for("DeviceKernel", SomeCudaPolicy, KOKKOS_LAMBDA(const int i) {
deviceClassInstance->Bar();
});
int main(int argc, char *argv[])
pzehner marked this conversation as resolved.
Show resolved Hide resolved
// create
Base* instance = new Derived();

// use
for (int i = 0; i < 10; i++) {
instance->Bar();
}

// cleanup
delete instance;
}
```

At a glance this should be fine, we've made a device instance of a class, copied the contents of a host instance into it, and then used it. This code will typically crash, however, because `virtualFunction` will call a host version of the function. To understand why, you'll need to understand a bit about how virtual functions are implemented.
This code is more complex to port on GPU than it looks like.
Using a straightforward approach, we would replace the `for` loop with `parallel_for` and copy `instance` on the GPU memory (not disclosing how for now).
pzehner marked this conversation as resolved.
Show resolved Hide resolved
Then, we would call `Bar()` inside the `parallel_for`.
At a glance this should be fine, but it will typically crash, however, because `instance` will call a host version of `Bar()`.
To understand why, we need to understand a bit about how virtual functions are implemented.

## V-Tables, V-Pointers, V-ery annoying with GPUs
## Vtables, Vpointers, Very annoying with GPUs

Virtual functions allow a program to handle Derived classes through a pointer to their Base class and have things work as they should. To make this work, the compiler needs some way to identify whether a pointer which is nominally to a Base class really is a pointer to the Base, or whether it's really a pointer to any Derived class. This happens through Vpointers and Vtables. For every class with virtual functions, there is one Vtable shared among all instances, this table contains function pointers for all the virtual functions the class implements.
Virtual functions allow a program to handle Derived classes through a pointer to their Base class and have things work as they should. To make this work, the compiler needs some way to identify whether a pointer which is nominally to a Base class really is a pointer to the Base, or whether it's really a pointer to any Derived class. This happens through Vpointers and Vtables. For every class with virtual functions, there is one Vtable shared among all instances, this table contains function pointers for all the virtual functions the class implements.

![VTable](./figures/VirtualFunctions-VTables.png)

Okay, so now we have Vtables, if a class knows what type it is it could call the correct function. But how does it know?
Okay, so now we have Vtables, if a class knows what type it is, it could call the correct function. But how does it know?

Remember that we have one Vtable shared amongst all instances of a type. Each instance, however, has a hidden member called the Vpointer, which the compiler points at construction to the correct Vtable. So a call to a virtual function simply dereferences that pointer, and then indexes into the Vtable to find the precise virtual function called.

![VPointer](./figures/VirtualFunctions-VPointers.png)

Now that we know what the compiler is doing to implement virtual functions, we'll look at why it doesn't work with GPU's
Now that we know what the compiler is doing to implement virtual functions, we'll look at why it doesn't work with GPU's.

Credit: the content of this section is adapted from Pablo Arias [here](https://pabloariasal.github.io/2017/06/10/understanding-virtual-tables/)
Credit: the content of this section is adapted from Pablo Arias [here](https://pabloariasal.github.io/2017/06/10/understanding-virtual-tables/).

## Then why doesn't my code work?
## Then why the straightforward approach doesn't work?
pzehner marked this conversation as resolved.
Show resolved Hide resolved

The reason the intro code might break is that when dealing with GPU-compatible classes with virtual functions, there isn't one Vtable, but two. The first holds the host versions of the virtual functions, while the second holds the device functions.
The reason the straightforward approach described above fails is that when dealing with GPU-compatible classes with virtual functions, there isn't one Vtable, but two. The first holds the host version of the virtual functions, while the second holds the device functions.

![VTableDevice](./figures/VirtualFunctions-VTablesHostDevice.png)

Since we construct the class instance on the host, so it's Vpointer points to the host Vtable.
Since we construct the class instance on the host, its Vpointer points to the host Vtable.

![VPointerToHost](./figures/VirtualFunctions-VPointerToHost.png)

Our cudaMemcpy faithfully copied all of the members of the class, including the Vpointer merrily pointing at host functions, which we then call on the device.
We faithfully copied all of the members of the class on the GPU memory, including the Vpointer happily pointing at host functions, which we then call on the device.

## How to fix this
## Make it work

The problem here is that we are constructing the class on the Host. If we were constructing on the Device, we'd get the correct Vpointer, and thus the correct functions (but only for calls on the device). In pseudocode, we want to move from
The problem here is that we are constructing the instance on the Host.
pzehner marked this conversation as resolved.
Show resolved Hide resolved
If we were constructing it on the Device, we'd get the correct Vpointer, and thus the correct functions.
pzehner marked this conversation as resolved.
Show resolved Hide resolved
Note that this would allow to call virtual functions on the device only, not on the host anymore.

```c++
Base* hostInstance = new Derived(); // allocate and initialize host
Base* deviceInstance; // cudaMalloc'd to allocate
cudaMemcpy(deviceInstance, hostInstance); // to initialize the deivce
Kokkos::parallel_for(... {
// use deviceInstance
});
```
To that aim, we first allocate memory on the device, then construct on the device using a technique called *placement new*
pzehner marked this conversation as resolved.
Show resolved Hide resolved

To one where we construct on the device using a technique called `placement new`
```cpp
#include <Kokkos_Core.hpp>

```c++
Base* deviceInstance; // cudaMalloc'd to allocate it
Kokkos::parallel_for(... {
new((Derived*)deviceInstance) Derived(); // construct an instance in the place, the pointer deviceInstance points to
});
class Base {
public:
void Foo() {}

KOKKOS_FUNCTION
virtual void Bar() {}
};

class Derived : public Base {
public:
KOKKOS_FUNCTION
void Bar() override {}
};

int main(int argc, char *argv[])
{
Kokkos::ScopeGuard kokkos(argc, argv);

// create
auto deviceInstanceMemory = Kokkos::kokkos_malloc(sizeof(Derived)); // allocate memory on device
pzehner marked this conversation as resolved.
Show resolved Hide resolved
Kokkos::parallel_for("initialize", 1, KOKKOS_LAMBDA (const int i) {
new (static_cast<Derived*>(deviceInstanceMemory)) Derived(); // initialize on device
});
Kokkos::fence("wait for initialize");
pzehner marked this conversation as resolved.
Show resolved Hide resolved
Base* deviceInstance = static_cast<Derived*>(deviceInstanceMemory); // declare on this memory

// use
Kokkos::parallel_for("myKernel", 10, KOKKOS_LAMBDA (const int i) {
deviceInstance->Bar();
});
Kokkos::fence();
pzehner marked this conversation as resolved.
Show resolved Hide resolved

// cleanup
Kokkos::parallel_for("destroy", 1, KOKKOS_LAMBDA (const int i) {
deviceInstance->~Base(); // destroy on device
});
Kokkos::fence("wait for destroy");
pzehner marked this conversation as resolved.
Show resolved Hide resolved
Kokkos::kokkos_free(deviceInstanceMemory); // free
}
```

We first use the `KOKKOS_FUNCTION` macro to make the methods callable from a kernel.
When creating the instance, note that we introduce a distinction between the *memory* that the it uses, and the actual instantiated *object*.
The construct is done on the device, within a single-iteration `parallel_for`, using placement new.
pzehner marked this conversation as resolved.
Show resolved Hide resolved
Since the kernel does not have a return type, we use a static cast to associate the object with the memory allocation.
pzehner marked this conversation as resolved.
Show resolved Hide resolved

Like with other uses of `new`, we need to later free the memory.
The destruct is done on the device, again with a single-iteration `parallel_for`.
pzehner marked this conversation as resolved.
Show resolved Hide resolved

This code is extremely ugly, but leads to functional virtual function calls on the device. The Vpointer now points to the device Vtable.
Remember that those virtual functions cannot be called on the host anymore!

![VPointerToDevice](./figures/VirtualFunctions-VPointerToDevice.png)

Note that like with other uses of `new`, you need to later `free` the memory.
For a full working example, see [the example in the repo](https://github.com/kokkos/kokkos/blob/master/example/virtual_functions/main.cpp).

## Complications and Fixes

The first problem people run into with this is that they want to initialize some fields or nested classes based on host data before moving data down to the device

```c++
Base* hostInstance = new Derived(); // allocate and initialize host
hostInstance->setAField(someHostValue);
Base* deviceInstance; // cudaMalloc'd to allocate
cudaMemcpy(deviceInstance, hostInstance); // to initialize the deivce
Kokkos::parallel_for(... {
// use deviceInstance
});
```
## What if I need a setter that works with host values?

We can't translate this easily, the naive translation would be
The first problem people run into with this is that they want to initialize some fields based on host data, with a setter which is *not* a virtual function.
Calling this setter on the device would crash if the host data is not easily copyable on the device (e.g. for a small array).
pzehner marked this conversation as resolved.
Show resolved Hide resolved
The most productive solution we've found in these cases is to allocate the instance in `SharedSpace`, initialize it on the device, and then fill in fields on the host
pzehner marked this conversation as resolved.
Show resolved Hide resolved

```c++
Base* deviceInstance; // cudaMalloc'd to allocate it
Kokkos::parallel_for(... {
new((Derived*)deviceInstance) Derived(); // initialize an instance, and place the result in the pointer deviceInstance
deviceInstance->setAField(someHostValue);
});
// create
auto deviceInstanceMemory = Kokkos::kokkos_malloc<Kokkos::SharedSpace>(sizeof(Derived)); // allocate on shared space
pzehner marked this conversation as resolved.
Show resolved Hide resolved
// ...
deviceInstance->setAField(someHostValue); // set on host
```

Which would crash for accessing the host value `someHostValue` on the device (or this value would need to be copied into the `parallel_for`). The most productive solution we've found in these cases is to allocate the class in `SharedSpace`, initialize it on the device, and then fill in fields on the host. To wit:
The setter is still called on the host.
Beware that this is only valid for backends that support `SharedSpace`.

```c++
Base* deviceInstance = Kokkos::kokkos_malloc<Kokkos::SharedSpace>(sizeof(Derived));
Kokkos::parallel_for(... {
new((Derived*)deviceInstance) Derived(); // construct an instance in the place the the pointer deviceInstance points to
});
deviceInstance->setAField(someHostValue); // set some field on the host
```
## But what if I do not really need the Vtables on the device side?

This is the solution that the code teams we have talked to have said is the most productive way to solve the problem. Nevertheless, it should be kept in mind, that this restricts virtual function calls to the device.
Consider the following example which calls the virtual function `Bar()` on the device from a pointer of derived class type.
One might think this should work because no Vtable lookup on the device is necessary.

## But what if I do not really need the V-Tables on the device side?
Consider the following example which calls the `virtual Bar()` on the device from a pointer of derived class type.
One might think this should work because no V-Table lookup on the device is neccessary.
```c++
#include <Kokkos_Core.hpp>
#include <cstdio>

struct Base
{
KOKKOS_DEFAULTED_FUNCTION
virtual ~Base() = default;

KOKKOS_FUNCTION
virtual void Bar() const = 0;
class Base {
public:
KOKKOS_FUNCTION
virtual void Bar() const = 0;
};

struct Derived : public Base
{
KOKKOS_FUNCTION
void Bar() const override
{ printf("Hello from Derived\n"); }

void apply(){
Kokkos::parallel_for("myLoop",10,
KOKKOS_CLASS_LAMBDA (const size_t i) { this->Bar(); }
);
}
class Derived : Base {
public:
KOKKOS_FUNCTION
void Bar() const override {
Kokkos::printf("Hello from Derived\n");
}

void apply() {
Kokkos::parallel_for("myLoop", 10,
KOKKOS_CLASS_LAMBDA (const size_t i) { this->Bar(); }
);
}
};

int main (int argc, char *argv[])
int main(int argc, char *argv[])
{
Kokkos::initialize(argc,argv);
{
auto derivedPtr = std::make_shared<Derived>();
derivedPtr->apply();
Kokkos::fence();
}
Kokkos::finalize();
Kokkos::ScopeGuard kokkos(argc, argv);
pzehner marked this conversation as resolved.
Show resolved Hide resolved

auto derivedPtr = std::make_shared<Derived>();
pzehner marked this conversation as resolved.
Show resolved Hide resolved
derivedPtr->apply();
Kokkos::fence();
}
```

### Why is this not portable?

Inside the `parallel_for` `Bar()` is called. As `Derived` derives from the pure virtual class `Base`, the 'Bar()' function is marked `override`.
On ROCm 5.2 this results in a memory access violation.
When executing the `this->Bar()` call, the runtime looks into the V-Table and dereferences a host function pointer on the device.
Inside the `parallel_for`, `Bar()` is called. As `Derived` derives from the pure virtual class `Base`, the `Bar()` function is marked `override`.
On ROCm 6.0 this results in a memory access violation.
pzehner marked this conversation as resolved.
Show resolved Hide resolved
When executing the `this->Bar()` call, the runtime looks into the Vtable and dereferences a host function pointer on the device.

### But if that is the case, why does it work with NVCC?

Notice, that the `parallel_for` is called from a pointer of type `Derived` and not a pointer of type `Base` pointing to an `Derived` object.
Thus, no V-Table lookup for the `Bar()` would be necessary as it can be deduced from the context of the call that it will be `Derived::Bar()`.
Thus, no Vtable lookup for the `Bar()` would be necessary as it can be deduced from the context of the call that it will be `Derived::Bar()`.
But here it comes down to how the compiler handles the lookup. NVCC understands that the call is coming from an `Derived` object and thinks: "Oh, I see, that you are calling from an `Derived` object, I know it will be the `Bar()` in this class scope, I will do this for you".
ROCm, on the other hand, sees your call and thinks Oh, this is a call to a virtual method, I will look that up for you” - failing to dereference the host function pointer in the host virtual function table.
ROCm, on the other hand, sees the call and thinks "Oh, this is a call to a virtual method, I will look that up for you", failing to dereference the host function pointer in the host virtual function table.

### How to solve this?
Strictly speaking, the observed behavior on NVCC is an optimization that uses the context information to avoid the V-Table lookup.
Strictly speaking, the observed behavior on NVCC is an optimization that uses the context information to avoid the Vtable lookup.
If the compiler does not apply this optimization, you can help in different ways by providing additional information. Unfortunately, none of these strategies is fully portable to all backends.

- Tell the compiler not to look up any function name in the V-Table when calling `Bar()` by using [qualified name lookup](https://en.cppreference.com/w/cpp/language/qualified_lookup). For this, you tell the compiler which function you want by spelling out the class scope in which the function should be found e.g. `this->Derived::Bar();`. This behavior is specified in the C++ Standard. Nevertheless, some backends are not fully compliant to the Standard.
- Changing the `override` to `final` on the `Bar()` in the `Derived` class. This tells the compiler `Bar()` is not changing in derived objects. Many compilers do use this in optimization and deduce which function to call without the V-Table. Nevertheless, this might only work with certain compilers, as this effect of adding `final` is not specified in the C++ Standard.
- Tell the compiler not to look up any function name in the Vtable when calling `Bar()` by using [qualified name lookup](https://en.cppreference.com/w/cpp/language/qualified_lookup). For this, you tell the compiler which function you want by spelling out the class scope in which the function should be found e.g. `this->Derived::Bar();`. This behavior is specified in the C++ Standard. Nevertheless, some backends are not fully compliant to the Standard.
- Changing the `override` to `final` on the `Bar()` in the `Derived` class. This tells the compiler `Bar()` is not changing in derived objects. Many compilers do use this in optimization and deduce which function to call without the Vtable. Nevertheless, this might only work with certain compilers, as this effect of adding `final` is not specified in the C++ Standard.
- Similarly, the entire derived class `Implementation` can be marked `final`. This is compiler dependent too, for the same reasons.

## Questions/Follow-up
Expand Down
Loading