Skip to content

Commit

Permalink
[environ] move environ overwrite to create_asyncio
Browse files Browse the repository at this point in the history
  • Loading branch information
botbw committed Oct 9, 2024
1 parent e28a80a commit c2f1349
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 26 deletions.
16 changes: 15 additions & 1 deletion csrc/backend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -119,11 +119,25 @@ bool probe_backend(const std::string &backend)
}
}

AsyncIO *create_asyncio(unsigned int n_entries, const std::string &backend)
std::string get_default_backend() {
const char* env = getenv("TENSORNVME_BACKEND");
if (env == nullptr) {
return std::string("");
}
return std::string(env);
}

AsyncIO *create_asyncio(unsigned int n_entries, std::string backend)
{
std::unordered_set<std::string> backends = get_backends();
if (backends.empty())
throw std::runtime_error("No asyncio backend is installed");

std::string default_backend = get_default_backend();
if (default_backend.size() > 0) {
backend = default_backend;
}

if (backends.find(backend) == backends.end())
throw std::runtime_error("Unsupported backend: " + backend);
if (!probe_backend(backend))
Expand Down
24 changes: 2 additions & 22 deletions csrc/offload.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,29 +26,9 @@ iovec *tensors_to_iovec(const std::vector<at::Tensor> &tensors)
return iovs;
}

std::string Offloader::get_default_backend() {
const char* env = getenv("TENSORNVME_BACKEND");
if (env == nullptr) {
return std::string("");
}
return std::string(env);
}

Offloader::Offloader(const std::string &filename, unsigned int n_entries, const std::string &backend) : filename(filename), space_mgr(SpaceManager(0))
{
std::string default_backend = get_default_backend();
if (default_backend.size() > 0) {
if (get_backends().count(default_backend) == 0) {
throw std::runtime_error("Cannot find backend: " + default_backend + ", please check if TENSORNVME_BACKEND is set correctly");
}
this->aio = create_asyncio(n_entries, default_backend);
} else {
if (get_backends().count(backend) == 0) {
throw std::runtime_error("Cannot find backend: " + backend + ", please check the passed backend is set correctly");
}
this->aio = create_asyncio(n_entries, backend);
}

{
this->aio = create_asyncio(n_entries, backend);
this->fd = open(filename.c_str(), O_RDWR | O_CREAT, S_IRUSR | S_IWUSR);
this->aio->register_file(fd);
}
Expand Down
5 changes: 4 additions & 1 deletion include/backend.h
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
#include "asyncio.h"
#include <string>
#include <unordered_set>
#include <cstdlib>

std::unordered_set<std::string> get_backends();

bool probe_backend(const std::string &backend);

AsyncIO *create_asyncio(unsigned int n_entries, const std::string &backend);
std::string get_default_backend();

AsyncIO *create_asyncio(unsigned int n_entries, std::string backend);
2 changes: 0 additions & 2 deletions include/offload.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
#include "aio.h"
#endif

#include <cstdlib>
class Offloader
{
public:
Expand All @@ -32,7 +31,6 @@ class Offloader
void async_readv(const std::vector<at::Tensor> &tensors, const std::string &key, callback_t callback = nullptr);
void sync_writev(const std::vector<at::Tensor> &tensors, const std::string &key);
void sync_readv(const std::vector<at::Tensor> &tensors, const std::string &key);
static std::string get_default_backend();
private:
const std::string filename;
int fd;
Expand Down

0 comments on commit c2f1349

Please sign in to comment.