Skip to content

Commit

Permalink
Improvements, debugging, visualization
Browse files Browse the repository at this point in the history
  • Loading branch information
lnotspotl committed May 4, 2024
1 parent 4637c6a commit b4fc891
Show file tree
Hide file tree
Showing 13 changed files with 864 additions and 58 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
.vscode
__pycache__/
out/
out/
libtorch
35 changes: 30 additions & 5 deletions tbai_bindings/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
cmake_minimum_required(VERSION 3.0.2)
project(tbai_bindings)

set(Torch_DIR /home/kuba/Downloads/libtorch/share/cmake/Torch)
set(LIBTORCH_DIR ${CMAKE_CURRENT_SOURCE_DIR}/../dependencies/libtorch)
set(LIBTORCH_INCLUDE_DIR ${LIBTORCH_DIR}/include)
set(LIBTORCH_LIB_DIR ${LIBTORCH_DIR}/lib)
set(LIBTORCH_CMAKE_DIR ${LIBTORCH_DIR}/share/cmake/Torch)

find_package(catkin REQUIRED COMPONENTS
ocs2_oc
Expand All @@ -11,6 +14,9 @@ find_package(catkin REQUIRED COMPONENTS
ocs2_legged_robot
ocs2_legged_robot_ros
pybind11_catkin
ocs2_ros_interfaces
ocs2_msgs
message_generation
)
set(CMAKE_CXX_STANDARD 17)

Expand All @@ -23,12 +29,27 @@ find_package(pinocchio REQUIRED)
find_package (Python3 3.8.18 COMPONENTS Interpreter Development)
find_package(pybind11 REQUIRED)

find_package(Torch REQUIRED)
set(Torch_DIR ${LIBTORCH_CMAKE_DIR})
find_package(Torch)
if (NOT Torch_FOUND)
message(FATAL_ERROR "PyTorch Not Found! Move your libtorch to the ${LIBTORCH_DIR} folder.")
endif()
set(OTHER_LIBRARIES ${LIBTORCH_LIB_DIR}/libtorch_python.so)

add_message_files(
FILES
bindings_visualize.msg
)

generate_messages(
DEPENDENCIES
ocs2_msgs
)

catkin_package(
INCLUDE_DIRS include
LIBRARIES ${PROJECT_NAME}
CATKIN_DEPENDS ocs2_oc ocs2_mpc ocs2_robotic_tools ocs2_centroidal_model ocs2_legged_robot ocs2_legged_robot_ros
CATKIN_DEPENDS ocs2_oc ocs2_mpc ocs2_robotic_tools ocs2_centroidal_model ocs2_legged_robot ocs2_legged_robot_ros ocs2_msgs message_runtime
DEPENDS
)

Expand All @@ -42,8 +63,6 @@ include_directories(
${pinocchio_INCLUDE_DIRS}
)

# Apparently, this is every important!!!
set(OTHER_LIBRARIES /home/kuba/Downloads/libtorch/lib/libtorch_python.so)

add_library(${PROJECT_NAME}
src/Utils.cpp
Expand All @@ -60,6 +79,12 @@ target_link_libraries(ig_interface PRIVATE
${OTHER_LIBRARIES}
)

add_executable(main src/main.cpp)
target_link_libraries(main ${PROJECT_NAME})

add_executable(visualizer src/visualizer.cpp)
target_link_libraries(visualizer ${PROJECT_NAME})

# Install the python module
install(TARGETS ig_interface
ARCHIVE DESTINATION ${CMAKE_SOURCE_DIR}/out
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
#pragma once

#ifdef TBAI_BINDINGS_DISABLE_ASSERTS
#define TBAI_BINDINGS_ASSERT(condition, message)
#else
Expand All @@ -13,4 +15,23 @@
} \
} while (0)

#endif
#endif

#ifdef TBAI_BINDINGS_DISABLE_PRINTS
#define TBAI_BINDINGS_PRINT(message)
#else
#define TBAI_BINDINGS_PRINT(message) std::cout << "[Tbai bindings] | " << message << std::endl
#endif

#define TBAI_BINDINGS_STD_THROW(message) \
do { \
std::cerr << "\n" \
<< "Exception thrown in file " << __FILE__ << " at line " << __LINE__ << ": " << message << "\n" \
<< std::endl; \
throw std::runtime_error(message); \
} while (0)

#define TBAI_BINDINGS_STD_THROW_IF(condition, message) \
if (condition) { \
TBAI_BINDINGS_STD_THROW(message); \
}
36 changes: 29 additions & 7 deletions tbai_bindings/include/tbai_bindings/TbaiIsaacGymInterface.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
#pragma once

// clang-format off
#include <pinocchio/fwd.hpp>
// clang-format on

#include <chrono>
#include <memory>
#include <string>
Expand All @@ -22,6 +26,9 @@
#include <torch/extension.h>
#include <torch/torch.h>

#include <ros/ros.h>
#include <tbai_bindings/bindings_visualize.h>

namespace tbai {
namespace bindings {

Expand All @@ -34,7 +41,7 @@ class TbaiIsaacGymInterface {
public:
TbaiIsaacGymInterface(const std::string &taskFile, const std::string &urdfFile, const std::string &referenceFile,
const std::string &gaitFile, const std::string &gaitName, int numEnvs, int numThreads,
torch::Device device = torch::kCPU);
torch::Device device = torch::kCPU, bool visualize = false);

/** Public interface **/
void resetAllSolvers(scalar_t time);
Expand Down Expand Up @@ -70,9 +77,17 @@ class TbaiIsaacGymInterface {
torch::Tensor &getDesiredBaseLinearAccelerations() { return desiredBaseLinearAccelerations_; }
torch::Tensor &getDesiredBaseAngularAccelerations() { return desiredBaseAngularAccelerations_; }

void visualize(scalar_t time, torch::Tensor &state, int envId, torch::Tensor &obs);

PrimalSolution getCurrentOptimalTrajectory(int envId) const;
SystemObservation getCurrentObservation(scalar_t time, int envId) const;

// Move relevant tensors to CPU and convert them to Eigen data types
void toCpu();

// Move relevant tensors to GPU
void toGpu();

private:
void allocateInterfaceBuffers();
void allocateEigenBuffers();
Expand All @@ -92,17 +107,20 @@ class TbaiIsaacGymInterface {
int numEnvs_;
int numThreads_;

matrix_t currentStates_;
matrix_t currentCommands_;
matrix_t currentStatesCpu_;
matrix_t currentCommandsCpu_;
torch::Tensor optimizedStates_;
torch::Tensor consistencyRewards_;
std::vector<PrimalSolution> solutions_;

std::vector<std::unique_ptr<LeggedRobotInterface>> interfacePtrs_;
std::vector<std::unique_ptr<SqpSolver>> solvers_;
std::vector<std::unique_ptr<PinocchioInterface>> pinocchioInterfaces_;
std::vector<std::unique_ptr<PinocchioEndEffectorKinematics>> endEffectorKinematics_;
std::vector<std::unique_ptr<CentroidalModelPinocchioMapping>> centroidalModelMappings_;
std::vector<std::unique_ptr<SqpSolver>> solverPtrs_;
std::vector<std::unique_ptr<PinocchioInterface>> pinocchioInterfacePtrs_;
std::vector<std::unique_ptr<PinocchioEndEffectorKinematics>> endEffectorKinematicsPtrs_;
std::vector<std::unique_ptr<CentroidalModelPinocchioMapping>> centroidalModelMappingPtrs_;

torch::Tensor currentStates_;
torch::Tensor currentCommands_;

torch::Tensor desiredContacts_;
torch::Tensor timeLeftInPhase_;
Expand Down Expand Up @@ -131,10 +149,14 @@ class TbaiIsaacGymInterface {

vector_t initialState_;

bool visualize_;

// MPC horizon in seconds
scalar_t horizon_;

std::unique_ptr<ocs2::legged_robot::ModeSequenceTemplate> modeSequenceTemplate_;

ros::Publisher pub_;
};

} // namespace bindings
Expand Down
Loading

0 comments on commit b4fc891

Please sign in to comment.