Skip to content

Commit

Permalink
Move getter implementations into the header file
Browse files Browse the repository at this point in the history
  • Loading branch information
lnotspotl committed Apr 29, 2024
1 parent 1958a91 commit 4637c6a
Show file tree
Hide file tree
Showing 6 changed files with 30 additions and 94 deletions.
2 changes: 2 additions & 0 deletions CPPLINT.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,6 @@ filter=-runtime/explicit
filter=-whitespace/line_length
filter=-build/include_order
filter=-runtime/references
filter=-build/header_guard
filter=-build/c++11
exclude_files=dependencies
2 changes: 1 addition & 1 deletion tbai_bindings/include/tbai_bindings/Asserts.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
#define TBAI_BINDINGS_ASSERT(condition, message)
#else

#define TBAI_BINDINGS_ASSERT(condition, message) \
#define TBAI_BINDINGS_ASSERT(condition, message) \
do { \
if (!(condition)) { \
std::cerr << "\n" \
Expand Down
54 changes: 24 additions & 30 deletions tbai_bindings/include/tbai_bindings/TbaiIsaacGymInterface.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
#include <Eigen/Dense>
#include <ocs2_centroidal_model/CentroidalModelPinocchioMapping.h>
#include <ocs2_centroidal_model/PinocchioCentroidalDynamics.h>
#include <tbai_bindings/Types.hpp>
#include <ocs2_core/misc/LinearInterpolation.h>
#include <ocs2_core/reference/TargetTrajectories.h>
#include <ocs2_legged_robot/LeggedRobotInterface.h>
Expand All @@ -19,6 +18,7 @@
#include <ocs2_pinocchio_interface/PinocchioEndEffectorKinematics.h>
#include <ocs2_pinocchio_interface/PinocchioInterface.h>
#include <ocs2_sqp/SqpSolver.h>
#include <tbai_bindings/Types.hpp>
#include <torch/extension.h>
#include <torch/torch.h>

Expand All @@ -36,60 +36,54 @@ class TbaiIsaacGymInterface {
const std::string &gaitFile, const std::string &gaitName, int numEnvs, int numThreads,
torch::Device device = torch::kCPU);

/** Public interface **/
void resetAllSolvers(scalar_t time);
void resetSolvers(scalar_t time, const torch::Tensor &envIds);

void updateCurrentStates(const torch::Tensor &newStates);
void updateCurrentStates(const torch::Tensor &newStates, const torch::Tensor &envIds);

void updateOptimizedStates(scalar_t time);
void updateOptimizedStates(scalar_t time, const torch::Tensor &envIds);

void optimizeTrajectories(scalar_t time);
void optimizeTrajectories(scalar_t time, const torch::Tensor &envIds);

void setCurrentCommand(const torch::Tensor &command, const torch::Tensor &envIds);


torch::Tensor &getOptimizedStates();
torch::Tensor &getUpdatedInSeconds();
torch::Tensor &getConsistencyReward();

torch::Tensor &getPlanarFootHolds();
torch::Tensor &getDesiredJointPositions();
torch::Tensor &getDesiredContacts();
torch::Tensor &getTimeLeftInPhase();
torch::Tensor &getCurrentDesiredJointPositions();

torch::Tensor &getDesiredBasePositions();
torch::Tensor &getDesiredBaseOrientations();
torch::Tensor &getDesiredBaseLinearVelocities();
torch::Tensor &getDesiredBaseAngularVelocities();
torch::Tensor &getDesiredBaseLinearAccelerations();
torch::Tensor &getDesiredBaseAngularAccelerations();

const LeggedRobotInterface &getInterface(int i) const;
void updateDesiredContacts(scalar_t time, const torch::Tensor &envIds);
void updateTimeLeftInPhase(scalar_t time, const torch::Tensor &envIds);
void updateDesiredJointAngles(scalar_t time, const torch::Tensor &envIds);
void updateCurrentDesiredJointAngles(scalar_t time, const torch::Tensor &envIds);
void updateNextOptimizationTime(scalar_t time, const torch::Tensor &envIds);

void updateDesiredBase(scalar_t time, const torch::Tensor &envIds);
void moveDesiredBaseToGpu();

/** Getters **/
torch::Tensor &getOptimizedStates() { return optimizedStates_; }
torch::Tensor &getUpdatedInSeconds() { return updateInSeconds_; }
torch::Tensor &getConsistencyReward() { return consistencyRewards_; }
torch::Tensor &getPlanarFootHolds() { return desiredFootholds_; }
torch::Tensor &getDesiredJointPositions() { return desiredJointAngles_; }
torch::Tensor &getDesiredContacts() { return desiredContacts_; }
torch::Tensor &getTimeLeftInPhase() { return timeLeftInPhase_; }
torch::Tensor &getCurrentDesiredJointPositions() { return currentDesiredJointAngles_; }
torch::Tensor &getDesiredBasePositions() { return desiredBasePositions_; }
torch::Tensor &getDesiredBaseOrientations() { return desiredBaseOrientations_; }
torch::Tensor &getDesiredBaseLinearVelocities() { return desiredBaseLinearVelocities_; }
torch::Tensor &getDesiredBaseAngularVelocities() { return desiredBaseAngularVelocities_; }
torch::Tensor &getDesiredBaseLinearAccelerations() { return desiredBaseLinearAccelerations_; }
torch::Tensor &getDesiredBaseAngularAccelerations() { return desiredBaseAngularAccelerations_; }

PrimalSolution getCurrentOptimalTrajectory(int envId) const;
SystemObservation getCurrentObservation(scalar_t time, int envId) const;

private:
void loadModeSequenceTemplates(const std::string &gaitFile, const std::string &gaitName);

void createInterfaces(const std::string &taskFile, const std::string &urdfFile, const std::string &referenceFile);

void allocateInterfaceBuffers();
void allocateEigenBuffers();
void allocateTorchBuffers();

const LeggedRobotInterface &getInterface(int i) const;

void loadModeSequenceTemplates(const std::string &gaitFile, const std::string &gaitName);

void createInterfaces(const std::string &taskFile, const std::string &urdfFile, const std::string &referenceFile);

void updateNextOptimizationTimeImpl(scalar_t time, int envId);
scalar_t computeConsistencyReward(const PrimalSolution &previousSolution, const PrimalSolution &currentSolution);

Expand Down
63 changes: 2 additions & 61 deletions tbai_bindings/src/TbaiIsaacGymInterface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -337,49 +337,9 @@ void TbaiIsaacGymInterface::updateDesiredJointAngles(scalar_t time, const torch:
threadPool_.submit_loop(0, static_cast<int>(envIds.numel()), impl).wait();
}

torch::Tensor &TbaiIsaacGymInterface::getOptimizedStates() {
return optimizedStates_;
}

torch::Tensor &TbaiIsaacGymInterface::getUpdatedInSeconds() {
return updateInSeconds_;
}

torch::Tensor &TbaiIsaacGymInterface::getConsistencyReward() {
return consistencyRewards_;
}

torch::Tensor &TbaiIsaacGymInterface::getPlanarFootHolds() {
return desiredFootholds_;
}

torch::Tensor &TbaiIsaacGymInterface::getDesiredJointPositions() {
return desiredJointAngles_;
}

torch::Tensor &TbaiIsaacGymInterface::getDesiredContacts() {
return desiredContacts_;
}

torch::Tensor &TbaiIsaacGymInterface::getTimeLeftInPhase() {
return timeLeftInPhase_;
}

torch::Tensor &TbaiIsaacGymInterface::getCurrentDesiredJointPositions() {
return currentDesiredJointAngles_;
}

const LeggedRobotInterface &TbaiIsaacGymInterface::getInterface(int i) const {
// Check that index is within bounds
if (i < 0 || i >= numEnvs_) {
throw std::runtime_error("Index out of bounds");
}

// Check that interface is initialized
if (interfacePtrs_[i] == nullptr) {
throw std::runtime_error("Interface not initialized");
}

TBAI_BINDINGS_ASSERT(i >= 0 && i < numEnvs_, "Index out of bounds");
TBAI_BINDINGS_ASSERT(interfacePtrs_[i] != nullptr, "Interface not initialized");
return *interfacePtrs_[i];
}

Expand Down Expand Up @@ -585,24 +545,5 @@ void TbaiIsaacGymInterface::moveDesiredBaseToGpu() {
desiredBaseAngularAccelerations_ = tbai::bindings::matrix2torch(desiredBaseAngularAccelerationsCpu_).to(device_);
}

torch::Tensor &TbaiIsaacGymInterface::getDesiredBasePositions() {
return desiredBasePositions_;
}
torch::Tensor &TbaiIsaacGymInterface::getDesiredBaseOrientations() {
return desiredBaseOrientations_;
}
torch::Tensor &TbaiIsaacGymInterface::getDesiredBaseLinearVelocities() {
return desiredBaseLinearVelocities_;
}
torch::Tensor &TbaiIsaacGymInterface::getDesiredBaseAngularVelocities() {
return desiredBaseAngularVelocities_;
}
torch::Tensor &TbaiIsaacGymInterface::getDesiredBaseLinearAccelerations() {
return desiredBaseLinearAccelerations_;
}
torch::Tensor &TbaiIsaacGymInterface::getDesiredBaseAngularAccelerations() {
return desiredBaseAngularAccelerations_;
}

} // namespace bindings
} // namespace tbai
1 change: 0 additions & 1 deletion tbai_bindings/src/TbaiIsaacGymInterfaceBindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@ PYBIND11_MODULE(ig_interface, m) {
.def("get_time_left_in_phase", &TbaiIsaacGymInterface::getTimeLeftInPhase)
.def("get_current_observation", &TbaiIsaacGymInterface::getCurrentObservation)
.def("get_current_optimal_trajectory", &TbaiIsaacGymInterface::getCurrentOptimalTrajectory)
.def("get_interface", &TbaiIsaacGymInterface::getInterface, py::return_value_policy::reference)
.def("update_current_desired_joint_angles", &TbaiIsaacGymInterface::updateCurrentDesiredJointAngles)
.def("get_current_desired_joint_positions", &TbaiIsaacGymInterface::getCurrentDesiredJointPositions)
.def("get_desired_base_positions", &TbaiIsaacGymInterface::getDesiredBasePositions)
Expand Down
2 changes: 1 addition & 1 deletion tbai_bindings/test/testUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ TEST(Combined, RoundTripEig2Torch2Eig) {

for (int i = 0; i < m.rows(); i++) {
for (int j = 0; j < m.cols(); j++) {
EXPECT_NEAR(m(i, j), m2(i,j), 1e-6);
EXPECT_NEAR(m(i, j), m2(i, j), 1e-6);
}
}
}
Expand Down

0 comments on commit 4637c6a

Please sign in to comment.