diff --git a/CMakeLists.txt b/CMakeLists.txt index 3dbcacd606..50c2adf1bf 100755 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -94,6 +94,11 @@ include(metis) include(gmp) include(probabilistic_quadrics) +# Dependencies for images +include(stb) +include(tinyexr) + + include(lagrange) lagrange_include_modules(bvh) @@ -134,7 +139,12 @@ wmtk_target_link_system_libraries(wildmeshing_toolkit PUBLIC probabilistic_quadrics::probabilistic_quadrics paraviewo::paraviewo gmp::gmp -) + + + miniz # MTAO: I had a build issue with windows not finding miniz at linktime - adding here to make sure it's there? + tinyexr + stb::image + ) include(nlohmann_json) diff --git a/components/wmtk_components/isotropic_remeshing/internal/IsotropicRemeshing.cpp b/components/wmtk_components/isotropic_remeshing/internal/IsotropicRemeshing.cpp index ab62facb06..5740f47e1a 100644 --- a/components/wmtk_components/isotropic_remeshing/internal/IsotropicRemeshing.cpp +++ b/components/wmtk_components/isotropic_remeshing/internal/IsotropicRemeshing.cpp @@ -57,6 +57,8 @@ IsotropicRemeshing::IsotropicRemeshing(TriMesh& mesh, const double length, const op_settings.smooth_settings.position = m_position_handle; op_settings.smooth_settings.smooth_boundary = !m_lock_boundary; + op_settings.smooth_settings.initialize_invariants(m_mesh); + m_scheduler.add_operation_type( "smooth", op_settings); diff --git a/src/wmtk/CMakeLists.txt b/src/wmtk/CMakeLists.txt index 4391ef83f8..57e481bc6d 100644 --- a/src/wmtk/CMakeLists.txt +++ b/src/wmtk/CMakeLists.txt @@ -32,12 +32,20 @@ set(SRC_FILES ) target_sources(wildmeshing_toolkit PRIVATE ${SRC_FILES}) + + + add_subdirectory(io) add_subdirectory(utils) add_subdirectory(attribute) add_subdirectory(simplex) add_subdirectory(operations) +add_subdirectory(optimization) add_subdirectory(autogen) + add_subdirectory(invariants) add_subdirectory(multimesh) +add_subdirectory(function) +add_subdirectory(image) + diff --git a/src/wmtk/Scheduler.cpp b/src/wmtk/Scheduler.cpp index b79ad48fa6..f0b01cff12 100644 --- a/src/wmtk/Scheduler.cpp +++ b/src/wmtk/Scheduler.cpp @@ -27,6 +27,7 @@ void Scheduler::run_operation_on_all(PrimitiveType type, const std::string& name // enqueue_operations(ops); // TODO: pick some strategy for running these operations // tbb::parallel_for(ops, [&](const auto& ops) { (*op)(); }); + spdlog::info("Ran {} [{}] ops, {} succeeded, {} failed", number_of_performed_operations(), name, number_of_successful_operations(), number_of_failed_operations()); } void Scheduler::enqueue_operations(std::vector>&& ops) diff --git a/src/wmtk/Scheduler.hpp b/src/wmtk/Scheduler.hpp index 8ea4503c90..cb23fe3a93 100644 --- a/src/wmtk/Scheduler.hpp +++ b/src/wmtk/Scheduler.hpp @@ -5,7 +5,6 @@ #include "operations/Operation.hpp" #include "operations/OperationFactory.hpp" - namespace wmtk { // Scheduler scheduler; @@ -31,20 +30,37 @@ class Scheduler // primitive_type, // std::forward(args)...); //} + + const operations::OperationFactoryBase& add_operation_factory( + const std::string& name, + std::unique_ptr&& ptr) + { + return *(m_factories[name] = std::move(ptr)); + } template - void add_operation_type(const std::string& name) + const operations::OperationFactory& add_operation_type( + const std::string& name, + const operations::OperationSettings& settings) { - m_factories[name] = std::make_unique>(); + return static_cast&>( + add_operation_factory( + name, + std::make_unique>(settings))); } template - void add_operation_type( + const operations::OperationFactory& add_operation_type( const std::string& name, - const operations::OperationSettings& settings) + operations::OperationSettings&& settings) { - m_factories[name] = std::make_unique>(settings); + return static_cast&>( + add_operation_factory( + name, + std::make_unique>( + std::move(settings)))); } + void enqueue_operations(std::vector>&& ops); diff --git a/src/wmtk/Types.hpp b/src/wmtk/Types.hpp index d6b3b39f01..53ff4ff99d 100644 --- a/src/wmtk/Types.hpp +++ b/src/wmtk/Types.hpp @@ -6,11 +6,21 @@ namespace wmtk { template using RowVectors = Eigen::Matrix; +template +using SquareMatrix = Eigen::Matrix; + template using Vector = Eigen::Matrix; template using VectorX = Vector; +template +using Vector2 = Vector; +template +using Vector3 = Vector; +template +using Vector4 = Vector; + template using RowVector = Eigen::Matrix; template diff --git a/src/wmtk/attribute/AttributeHandle.hpp b/src/wmtk/attribute/AttributeHandle.hpp index 0c91ca501a..0938edfbaa 100644 --- a/src/wmtk/attribute/AttributeHandle.hpp +++ b/src/wmtk/attribute/AttributeHandle.hpp @@ -2,7 +2,7 @@ #include #include "wmtk/Primitive.hpp" namespace wmtk { - class Mesh; +class Mesh; namespace attribute { template class MeshAttributes; @@ -36,6 +36,8 @@ class AttributeHandle bool operator==(const AttributeHandle& other) const { return index == other.index; } + + bool is_valid() const { return index != -1; } }; template @@ -71,6 +73,8 @@ class MeshAttributeHandle return std::is_same_v && m_base_handle == o.m_base_handle && m_primitive_type == o.m_primitive_type; } + bool is_valid() const { return m_base_handle.is_valid(); } + PrimitiveType primitive_type() const { return m_primitive_type; } }; } // namespace attribute using AttributeHandle = attribute::AttributeHandle; diff --git a/src/wmtk/attribute/MutableAccessor.hpp b/src/wmtk/attribute/MutableAccessor.hpp index 881a6d4b6d..0179d82884 100644 --- a/src/wmtk/attribute/MutableAccessor.hpp +++ b/src/wmtk/attribute/MutableAccessor.hpp @@ -35,6 +35,8 @@ class MutableAccessor : public ConstAccessor using CachingBaseType::mesh; using CachingBaseType::stack_depth; + using ConstAccessorType::mesh; + protected: using ConstAccessorType::base_type; using ConstAccessorType::caching_base_type; diff --git a/src/wmtk/energy/AMIPS.cpp b/src/wmtk/energy/AMIPS.cpp deleted file mode 100644 index bb16fff54d..0000000000 --- a/src/wmtk/energy/AMIPS.cpp +++ /dev/null @@ -1,106 +0,0 @@ -#include "AMIPS.hpp" - -double AMIPS_2D::energy_eval(const Tuple& tuple) const override -{ - // get the uv coordinates of the triangle - ConstAccessor pos = m_mesh.create_accessor(m_position_handle); - - Eigen::Vector2d uv1 = pos.vector_attribute(tuple); - Eigen::Vector2d uv2 = pos.vector_attribute(switch_edge(switch_vertex(tuple))); - Eigen::Vector2d uv3 = pos.vector_attribute(switch_vertex(switch_edge(tuple))); - - // return the energy - return energy_eval(uv1, uv2, uv3); -} -DScalar AMIPS_2D::energy_eval_autodiff(const Tuple& tuple) const override -{ - // get the uv coordinates of the triangle - ConstAccessor pos = m_mesh.create_accessor(m_position_handle); - - Eigen::Vector2d uv1 = pos.vector_attribute(tuple); - Eigen::Vector2d uv2 = pos.vector_attribute(switch_edge(switch_vertex(tuple))); - Eigen::Vector2d uv3 = pos.vector_attribute(switch_vertex(switch_edge(tuple))); - - // return the energy - return energy_eval_autodiff(uv1, uv2, uv3); -} - -static double AMIPS_2D::energy_eval_autodiff( - const Eigen::Vector2d uv1, - const Eigen::Vector2d uv2, - const Eigen::Vector uv3) -{ - // (x0 - x1, y0 - y1, x0 - x2, y0 - y2).transpose - Eigen::Matrix Dm; - - Dm << uv2(0) - uv1(0), uv3(0) - uv1(0), uv2(1) - uv1(1), uv3(1) - uv1(1); - - Eigen::Matrix2d Ds, Dsinv; - Eigen::Vector2d target_A, target_B, target_C; - target_A << 0., 0.; - target_B << 1., 0.; - target_C << 1. / 2., sqrt(3) / 2.; - Ds << target_B.x() - target_A.x(), target_C.x() - target_A.x(), target_B.y() - target_A.y(), - target_C.y() - target_A.y(); - - auto Dsdet = Ds.determinant(); - if (std::abs(Dsdet) < std::numeric_limits::denorm_min()) { - return std::numeric_limits::infinity(); - } - Dsinv = Ds.inverse(); - - // define of transform matrix F = Dm@Ds.inv - Eigen::Matrix F; - F << (Dm(0, 0) * Dsinv(0, 0) + Dm(0, 1) * Dsinv(1, 0)), - (Dm(0, 0) * Dsinv(0, 1) + Dm(0, 1) * Dsinv(1, 1)), - (Dm(1, 0) * Dsinv(0, 0) + Dm(1, 1) * Dsinv(1, 0)), - (Dm(1, 0) * Dsinv(0, 1) + Dm(1, 1) * Dsinv(1, 1)); - - auto Fdet = F.determinant(); - if (std::abs(Fdet) < std::numeric_limits::denorm_min()) { - return std::numeric_limits::infinity(); - } - return (F.transpose() * F).trace() / Fdet; -} - -static DScalar AMIPS_2D::energy_eval_autodiff( - const Eigen::Vector2d uv1, - const Eigen::Vector2d uv2, - const Eigen::Vector2d uv3) -{ - AMIPS::DScalar x0(0, input_triangle[state.idx * 2]), y0(1, input_triangle[state.idx * 2 + 1]); - - // (x0 - x1, y0 - y1, x0 - x2, y0 - y2).transpose - Eigen::Matrix Dm; - - Dm << input_triangle[(i * 2 + 2) % 6] - x0, input_triangle[(i * 2 + 4) % 6] - x0, - input_triangle[(i * 2 + 3) % 6] - y0, input_triangle[(i * 2 + 5) % 6] - y0; - - // define of transform matrix F = Ds@Dm.inv - Eigen::Matrix F; - - Eigen::Matrix2d Ds, Dsinv; - Ds << target_triangle[(i * 2 + 2) % 6] - target_triangle[(i * 2 + 0) % 6], - target_triangle[(i * 2 + 4) % 6] - target_triangle[(i * 2 + 0) % 6], - target_triangle[(i * 2 + 3) % 6] - target_triangle[(i * 2 + 1) % 6], - target_triangle[(i * 2 + 5) % 6] - target_triangle[(i * 2 + 1) % 6]; - - auto Dsdet = Ds.determinant(); - if (std::abs(Dsdet) < std::numeric_limits::denorm_min()) { - state.value = std::numeric_limits::infinity(); - return; - } - Dsinv = Ds.inverse(); - - F << (Dm(0, 0) * Dsinv(0, 0) + Dm(0, 1) * Dsinv(1, 0)), - (Dm(0, 0) * Dsinv(0, 1) + Dm(0, 1) * Dsinv(1, 1)), - (Dm(1, 0) * Dsinv(0, 0) + Dm(1, 1) * Dsinv(1, 0)), - (Dm(1, 0) * Dsinv(0, 1) + Dm(1, 1) * Dsinv(1, 1)); - - auto Fdet = F.determinant(); - if (std::abs(Fdet.getValue()) < std::numeric_limits::denorm_min()) { - state.value = std::numeric_limits::infinity(); - return; - } - AMIPS::DScalar AMIPS_function = (F.transpose() * F).trace() / Fdet; -} \ No newline at end of file diff --git a/src/wmtk/energy/AMIPS.hpp b/src/wmtk/energy/AMIPS.hpp deleted file mode 100644 index 217c6f1ee2..0000000000 --- a/src/wmtk/energy/AMIPS.hpp +++ /dev/null @@ -1,44 +0,0 @@ -#include "DifferentiableEnergy.hpp" - -class AMIPS_2D : public wmtk::Energy -{ - using DScalar = DScalar2; - -public: - double energy_eval(const Tuple& tuple) const override{}; - DScalar energy_eval_autodiff(const Tuple& tuple) const override{}; - - /** - * @brief gradient defined wrt the first vertex - * - * @param uv1 - * @param uv2 - * @param uv3 - * @return double energy value - */ - static double - energy_eval(const Eigen::Vector2d uv1, const Eigen::Vector2d uv2, const Eigen::Vector2d uv3){}; - - static DScalar energy_eval_autodiff( - const Eigen::Vector2d uv1, - const Eigen::Vector2d uv2, - const Eigen::Vector2d uv3){}; -}; - -/** - * @brief TODO 3D AMIPS uses uv and displacement map to get the 3d cooridnates then evaluate - * - */ -class AMIPS_3D : public wmtk::Energy -{ -public: - double energy_eval(const Tuple& tuple) const override{}; - DScalar energy_eval_autodiff(const Tuple& tuple) const override{}; - - static double - energy_eval(const Eigen::Vector3d p1, const Eigen::Vector3d p2, const Eigen::Vector3d p3){}; - static DScalar energy_eval_autodiff( - const Eigen::Vector3d p1, - const Eigen::Vector3d p2, - const Eigen::Vector3d p3){}; -}; \ No newline at end of file diff --git a/src/wmtk/energy/DifferentiableEnergy.hpp b/src/wmtk/energy/DifferentiableEnergy.hpp deleted file mode 100644 index 0aa10b62be..0000000000 --- a/src/wmtk/energy/DifferentiableEnergy.hpp +++ /dev/null @@ -1,9 +0,0 @@ -#include -#include "Energy.hpp" -class DifferentiableEnergy : public Energy -{ - using DScalar = DScalar2; - -public: - virtual DScalar energy_eval_autodiff(const Tuple& tuple) const = 0; -} \ No newline at end of file diff --git a/src/wmtk/energy/Energy.cpp b/src/wmtk/energy/Energy.cpp deleted file mode 100644 index b4580d4c6e..0000000000 --- a/src/wmtk/energy/Energy.cpp +++ /dev/null @@ -1,9 +0,0 @@ -#include "Energy.hpp" - -Energy::Energy(const Mesh& mesh) - : m_mesh(mesh) - , m_position_handle(m_mesh.get_attribute_handle("position", PrimitiveType::Vertex)){}; - -Energy::Energy(const Mesh& mesh, const MeshAttributeHandle& position_handle) - : m_mesh(mesh) - , m_position_handle(position_handle){}; diff --git a/src/wmtk/energy/Energy.hpp b/src/wmtk/energy/Energy.hpp deleted file mode 100644 index f08318f774..0000000000 --- a/src/wmtk/energy/Energy.hpp +++ /dev/null @@ -1,17 +0,0 @@ -#include -#include - -class Energy -{ -private: - const Mesh& m_mesh; - const MeshAttributeHandle m_position_handle; - - -public: - Energy(const Mesh& mesh, const MeshAttributeHandle& position_handle); - Enegry(const Mesh& mesh); - -public: - virtual double energy_eval(const Tuple& tuple) const = 0; -} \ No newline at end of file diff --git a/src/wmtk/function/AMIPS.cpp b/src/wmtk/function/AMIPS.cpp new file mode 100644 index 0000000000..92658755a0 --- /dev/null +++ b/src/wmtk/function/AMIPS.cpp @@ -0,0 +1,9 @@ +#include "AMIPS.hpp" +#include + +namespace wmtk::function { +AMIPS::AMIPS(const TriMesh& mesh, const MeshAttributeHandle& vertex_attribute_handle) + : AutodiffFunction(mesh, PrimitiveType::Face, vertex_attribute_handle) +{} + +} // namespace wmtk::function diff --git a/src/wmtk/function/AMIPS.hpp b/src/wmtk/function/AMIPS.hpp new file mode 100644 index 0000000000..e2fa9692c3 --- /dev/null +++ b/src/wmtk/function/AMIPS.hpp @@ -0,0 +1,10 @@ +#pragma once +#include "AutodiffFunction.hpp" +namespace wmtk::function { +class AMIPS : public AutodiffFunction +{ +public: + AMIPS(const TriMesh& mesh, const MeshAttributeHandle& vertex_attribute_handle); +}; + +} // namespace wmtk::function diff --git a/src/wmtk/function/AMIPS2D.cpp b/src/wmtk/function/AMIPS2D.cpp new file mode 100644 index 0000000000..35dbcdf3e5 --- /dev/null +++ b/src/wmtk/function/AMIPS2D.cpp @@ -0,0 +1,46 @@ +#include "AMIPS2D.hpp" +#include +#include +#include + + +namespace wmtk::function { +AMIPS2D::AMIPS2D(const TriMesh& mesh, const MeshAttributeHandle& vertex_attribute_handle) + : AMIPS(mesh, vertex_attribute_handle) +{ + assert(get_coordinate_attribute_handle().is_valid()); + // check the dimension of the position + assert(embedded_dimension() == 2); +} + + +auto AMIPS2D::get_value_autodiff(const Tuple& simplex) const -> DScalar +{ + return function_eval(simplex); +} + +template +T AMIPS2D::function_eval(const Tuple& tuple) const +{ + // get_autodiff_value sets the autodiff size if necessary + // get the uv coordinates of the triangle + ConstAccessor pos = mesh().create_const_accessor(get_coordinate_attribute_handle()); + + auto tuple_value = pos.const_vector_attribute(tuple); + Vector2 uv0; + if constexpr (std::is_same_v) { + uv0 = utils::as_DScalar(tuple_value); + } else { + uv0 = tuple_value; + } + constexpr static PrimitiveType PV = PrimitiveType::Vertex; + constexpr static PrimitiveType PE = PrimitiveType::Edge; + + Eigen::Vector2d uv2 = pos.const_vector_attribute(mesh().switch_tuples(tuple, {PE, PV})); + Eigen::Vector2d uv1 = pos.const_vector_attribute(mesh().switch_tuples(tuple, {PV, PE})); + + // return the energy + return utils::amips(uv0, uv1, uv2); +} + +} // namespace wmtk::function diff --git a/src/wmtk/function/AMIPS2D.hpp b/src/wmtk/function/AMIPS2D.hpp new file mode 100644 index 0000000000..ae161d3938 --- /dev/null +++ b/src/wmtk/function/AMIPS2D.hpp @@ -0,0 +1,18 @@ +#pragma once +#include "AMIPS.hpp" +namespace wmtk::function { +class AMIPS2D : public AMIPS +{ +public: + AMIPS2D(const TriMesh& mesh, const MeshAttributeHandle& vertex_attribute_handle); + +protected: + DScalar get_value_autodiff(const Tuple& simplex) const override; + + + template + T function_eval(const Tuple& tuple) const; + +private: +}; +} // namespace wmtk::function diff --git a/src/wmtk/function/AMIPS3D.cpp b/src/wmtk/function/AMIPS3D.cpp new file mode 100644 index 0000000000..7426371e62 --- /dev/null +++ b/src/wmtk/function/AMIPS3D.cpp @@ -0,0 +1,47 @@ +#include "AMIPS3D.hpp" +#include +#include +#include +#include + +namespace wmtk::function { +AMIPS3D::AMIPS3D(const TriMesh& mesh, const MeshAttributeHandle& vertex_attribute_handle) + : AMIPS(mesh, vertex_attribute_handle) +{ + assert(get_coordinate_attribute_handle().is_valid()); + // check the dimension of the position + assert(embedded_dimension() == 3); +} + + +auto AMIPS3D::get_value_autodiff(const Tuple& simplex) const -> DScalar +{ + return function_eval(simplex); +} + +template +T AMIPS3D::function_eval(const Tuple& tuple) const +{ + // get_autodiff_value sets the autodiff size if necessary + // get the pos coordinates of the triangle + ConstAccessor pos = mesh().create_const_accessor(get_coordinate_attribute_handle()); + + auto tuple_value = pos.const_vector_attribute(tuple); + Vector3 pos0; + if constexpr (std::is_same_v) { + pos0 = utils::as_DScalar(tuple_value); + } else { + pos0 = tuple_value; + } + constexpr static PrimitiveType PV = PrimitiveType::Vertex; + constexpr static PrimitiveType PE = PrimitiveType::Edge; + + Eigen::Vector3d pos2 = pos.const_vector_attribute(mesh().switch_tuples(tuple, {PE, PV})); + Eigen::Vector3d pos1 = pos.const_vector_attribute(mesh().switch_tuples(tuple, {PV, PE})); + + // return the energy + return utils::amips(pos0, pos1, pos2); +} + + +} // namespace wmtk::function diff --git a/src/wmtk/function/AMIPS3D.hpp b/src/wmtk/function/AMIPS3D.hpp new file mode 100644 index 0000000000..93f87285ef --- /dev/null +++ b/src/wmtk/function/AMIPS3D.hpp @@ -0,0 +1,19 @@ +#pragma once +#include +#include "AMIPS.hpp" +namespace wmtk::function { +class AMIPS3D : public AMIPS +{ +public: + AMIPS3D(const TriMesh& mesh, const MeshAttributeHandle& vertex_attribute_handle); + +protected: + DScalar get_value_autodiff(const Tuple& simplex) const override; + + template + T function_eval(const Tuple& tuple) const; + + +private: +}; +} // namespace wmtk::function diff --git a/src/wmtk/function/AutodiffFunction.cpp b/src/wmtk/function/AutodiffFunction.cpp new file mode 100644 index 0000000000..ab5048191f --- /dev/null +++ b/src/wmtk/function/AutodiffFunction.cpp @@ -0,0 +1,39 @@ +#include "AutodiffFunction.hpp" +#include + +namespace wmtk::function { + +AutodiffFunction::AutodiffFunction( + const Mesh& mesh, + const PrimitiveType& simplex_type, + const MeshAttributeHandle& variable_attribute_handle) + : PerSimplexDifferentiableFunction(mesh, simplex_type, variable_attribute_handle) +{} + +AutodiffFunction::~AutodiffFunction() = default; + +auto AutodiffFunction::get_value_autodiff(const Simplex& simplex) const -> DScalar +{ + assert(simplex.primitive_type() == get_simplex_type()); + return get_value_autodiff(simplex.tuple()); +} +double AutodiffFunction::get_value(const Tuple& simplex) const +{ + auto scope = utils::AutoDiffRAII(embedded_dimension()); + auto v = get_value_autodiff(simplex); + return v.getValue(); +} + +Eigen::VectorXd AutodiffFunction::get_gradient(const Tuple& simplex) const +{ + auto scope = utils::AutoDiffRAII(embedded_dimension()); + auto v = get_value_autodiff(simplex); + return v.getGradient(); +} +Eigen::MatrixXd AutodiffFunction::get_hessian(const Tuple& simplex) const +{ + auto scope = utils::AutoDiffRAII(embedded_dimension()); + auto v = get_value_autodiff(simplex); + return v.getHessian(); +} +} // namespace wmtk::function diff --git a/src/wmtk/function/AutodiffFunction.hpp b/src/wmtk/function/AutodiffFunction.hpp new file mode 100644 index 0000000000..34dd8916df --- /dev/null +++ b/src/wmtk/function/AutodiffFunction.hpp @@ -0,0 +1,33 @@ +#pragma once +#include +#include "PerSimplexDifferentiableFunction.hpp" +namespace wmtk::function { + +class AutodiffFunction : public PerSimplexDifferentiableFunction +{ +public: + using DScalar = DScalar2, Eigen::Matrix>; + using Scalar = typename DScalar::Scalar; + static_assert( + std::is_same_v); // MTAO: i'm leaving scalar here but is it ever not double? + AutodiffFunction( + const Mesh& mesh, + const PrimitiveType& simplex_type, + const attribute::MeshAttributeHandle& variable_attribute_handle); + + virtual ~AutodiffFunction(); + +public: + using PerSimplexFunction::get_value; + using PerSimplexDifferentiableFunction::get_hessian; + using PerSimplexDifferentiableFunction::get_gradient; + double get_value(const Tuple& tuple) const final override; + Eigen::VectorXd get_gradient(const Tuple& tuple) const final override; + Eigen::MatrixXd get_hessian(const Tuple& tuple) const final override; + + +protected: + virtual DScalar get_value_autodiff(const Tuple& simplex) const = 0; + DScalar get_value_autodiff(const Simplex& simplex) const; +}; +} // namespace wmtk::function diff --git a/src/wmtk/function/CMakeLists.txt b/src/wmtk/function/CMakeLists.txt new file mode 100644 index 0000000000..ca44244947 --- /dev/null +++ b/src/wmtk/function/CMakeLists.txt @@ -0,0 +1,32 @@ + +set(SRC_FILES + Function.cpp + Function.hpp + DifferentiableFunction.cpp + DifferentiableFunction.hpp + + LocalFunction.hpp + LocalFunction.cpp + LocalDifferentiableFunction.hpp + LocalDifferentiableFunction.cpp + PerSimplexFunction.hpp + PerSimplexFunction.cpp + PerSimplexDifferentiableFunction.hpp + PerSimplexDifferentiableFunction.cpp + + AMIPS.hpp + AMIPS.cpp + AMIPS2D.hpp + AMIPS2D.cpp + AMIPS3D.hpp + AMIPS3D.cpp + PositionMapAMIPS2D.hpp + PositionMapAMIPS2D.cpp + ValenceEnergyPerEdge.hpp + ValenceEnergyPerEdge.cpp + + AutodiffFunction.hpp + AutodiffFunction.cpp +) +target_sources(wildmeshing_toolkit PRIVATE ${SRC_FILES}) +add_subdirectory(utils) diff --git a/src/wmtk/function/DifferentiableFunction.cpp b/src/wmtk/function/DifferentiableFunction.cpp new file mode 100644 index 0000000000..a3da334250 --- /dev/null +++ b/src/wmtk/function/DifferentiableFunction.cpp @@ -0,0 +1,19 @@ +#include "DifferentiableFunction.hpp" +#include +namespace wmtk::function { + +//DifferentiableFunction::DifferentiableFunction( +// const MeshAttributeHandle& attribute_handle) +// : m_attribute_handle(attribute_handle) +//{} +// +// +//const MeshAttributeHandle& DifferentiableFunction::get_coordinate_attribute_handle() const +//{ +// return m_coordinate_attribute_handle; +//} +long DifferentiableFunction::embedded_dimension() const +{ + return mesh().get_attribute_dimension(get_coordinate_attribute_handle()); +} +} // namespace wmtk::function diff --git a/src/wmtk/function/DifferentiableFunction.hpp b/src/wmtk/function/DifferentiableFunction.hpp new file mode 100644 index 0000000000..b01ae34f7e --- /dev/null +++ b/src/wmtk/function/DifferentiableFunction.hpp @@ -0,0 +1,22 @@ +#pragma once +#include +#include +#include +#include "Function.hpp" +namespace wmtk::function { + +class DifferentiableFunction : public virtual Function +{ +public: + // evaluates the gradient of the tuple + + virtual Eigen::VectorXd get_gradient(const simplex::Simplex& tuple) const = 0; + + // TODO: should differentiable function be required to be twice differentiable? + virtual Eigen::MatrixXd get_hessian(const simplex::Simplex& tuple) const = 0; + + long embedded_dimension() const; + virtual MeshAttributeHandle get_coordinate_attribute_handle() const = 0; + +}; +} // namespace wmtk::function diff --git a/src/wmtk/function/Function.cpp b/src/wmtk/function/Function.cpp new file mode 100644 index 0000000000..12473a4152 --- /dev/null +++ b/src/wmtk/function/Function.cpp @@ -0,0 +1,7 @@ +#include "Function.hpp" +namespace wmtk::function { + +// Function::Function() {} + +Function::~Function() = default; +} // namespace wmtk::function diff --git a/src/wmtk/function/Function.hpp b/src/wmtk/function/Function.hpp new file mode 100644 index 0000000000..71d65ee4f1 --- /dev/null +++ b/src/wmtk/function/Function.hpp @@ -0,0 +1,20 @@ +#pragma once +#include +#include +namespace wmtk { +class Mesh; +namespace simplex { +class Simplex; +} + +} // namespace wmtk +namespace wmtk::function { +class Function +{ +public: + virtual ~Function(); + // evaluate the function on the top level simplex of the tuple + virtual double get_value(const simplex::Simplex& simplex) const = 0; + virtual const Mesh& mesh() const = 0; +}; +} // namespace wmtk::function diff --git a/src/wmtk/function/LocalDifferentiableFunction.cpp b/src/wmtk/function/LocalDifferentiableFunction.cpp new file mode 100644 index 0000000000..2cebd3a62f --- /dev/null +++ b/src/wmtk/function/LocalDifferentiableFunction.cpp @@ -0,0 +1,57 @@ +#include "LocalDifferentiableFunction.hpp" +#include +#include +#include +#include "PerSimplexDifferentiableFunction.hpp" +namespace wmtk::function { +LocalDifferentiableFunction::LocalDifferentiableFunction( + std::shared_ptr function, + const PrimitiveType& simplex_type) + : LocalFunction(std::move(function)) + , m_simplex_type(simplex_type) +{ +} + +LocalDifferentiableFunction::~LocalDifferentiableFunction() = default; + +Eigen::VectorXd LocalDifferentiableFunction::get_gradient(const Tuple& tuple) const +{ + return get_gradient(Simplex(m_simplex_type, tuple)); +} + +Eigen::MatrixXd LocalDifferentiableFunction::get_hessian(const Tuple& tuple) const +{ + return get_hessian(Simplex(m_simplex_type, tuple)); +} + +Eigen::VectorXd LocalDifferentiableFunction::get_gradient(const Simplex& simplex) const +{ + return per_simplex_function().get_gradient_sum(get_local_neighborhood_tuples(simplex)); +} +Eigen::MatrixXd LocalDifferentiableFunction::get_hessian(const Simplex& simplex) const +{ + + return per_simplex_function().get_hessian_sum(get_local_neighborhood_tuples(simplex)); +} + +const PerSimplexDifferentiableFunction& LocalDifferentiableFunction::per_simplex_function() const +{ + //return m_function; + return static_cast( + LocalFunction::per_simplex_function()); +} +std::shared_ptr +LocalDifferentiableFunction::per_simplex_function_ptr() const +{ + return std::static_pointer_cast( + LocalFunction::per_simplex_function_ptr()); +} + + + +MeshAttributeHandle +LocalDifferentiableFunction::get_coordinate_attribute_handle() const +{ + return per_simplex_function().get_coordinate_attribute_handle(); +} +} // namespace wmtk::function diff --git a/src/wmtk/function/LocalDifferentiableFunction.hpp b/src/wmtk/function/LocalDifferentiableFunction.hpp new file mode 100644 index 0000000000..bdb2eee851 --- /dev/null +++ b/src/wmtk/function/LocalDifferentiableFunction.hpp @@ -0,0 +1,36 @@ +#pragma once +#include +#include +#include "DifferentiableFunction.hpp" +#include "LocalFunction.hpp" +namespace wmtk::function { +class PerSimplexDifferentiableFunction; +; +class LocalDifferentiableFunction : public LocalFunction, public DifferentiableFunction +{ +public: + LocalDifferentiableFunction( + std::shared_ptr function, + const PrimitiveType& simplex_type); + virtual ~LocalDifferentiableFunction(); + +public: + //Eigen::VectorXd get_local_gradient(const Simplex& simplex) const override; + //Eigen::MatrixXd get_local_hessian(const Simplex& simplex) const override; + //Eigen::VectorXd get_local_gradient(const Tuple& tuple) const; + //Eigen::MatrixXd get_local_hessian(const Tuple& tuple) const; + + Eigen::VectorXd get_gradient(const Simplex& simplex) const override; + Eigen::MatrixXd get_hessian(const Simplex& simplex) const override; + Eigen::VectorXd get_gradient(const Tuple& tuple) const; + Eigen::MatrixXd get_hessian(const Tuple& tuple) const; + + const PerSimplexDifferentiableFunction& per_simplex_function() const; + std::shared_ptr per_simplex_function_ptr() const; + + attribute::MeshAttributeHandle get_coordinate_attribute_handle() const final override; + +private: + const PrimitiveType m_simplex_type; +}; +} // namespace wmtk::function diff --git a/src/wmtk/function/LocalFunction.cpp b/src/wmtk/function/LocalFunction.cpp new file mode 100644 index 0000000000..e951e8fa8e --- /dev/null +++ b/src/wmtk/function/LocalFunction.cpp @@ -0,0 +1,50 @@ + +#include "LocalFunction.hpp" +#include + +namespace wmtk::function { + +LocalFunction::LocalFunction(std::shared_ptr function) + : m_function(std::move(function)) +{} + +LocalFunction::~LocalFunction() = default; + +const PerSimplexFunction& LocalFunction::per_simplex_function() const +{ + return *per_simplex_function_ptr(); +} + +std::shared_ptr LocalFunction::per_simplex_function_ptr() const +{ + return m_function; +} + +const Mesh& LocalFunction::mesh() const +{ + return per_simplex_function().mesh(); +} +std::vector LocalFunction::get_local_neighborhood_tuples(const Simplex& simplex) const +{ + return wmtk::simplex::cofaces_single_dimension_tuples( + m_function->mesh(), + simplex, + per_simplex_function().get_simplex_type()); +} + +double LocalFunction::get_value(const Simplex& simplex) const +{ + return per_simplex_function().get_value_sum(get_local_neighborhood_tuples(simplex)); +} + +PrimitiveType LocalFunction::get_simplex_type() const +{ + return per_simplex_function().get_simplex_type(); +} + +double LocalFunction::get_value(const Tuple& simplex) const +{ + return get_value(Simplex(get_simplex_type(), simplex)); +} + +} // namespace wmtk::function diff --git a/src/wmtk/function/LocalFunction.hpp b/src/wmtk/function/LocalFunction.hpp new file mode 100644 index 0000000000..a3798ad32d --- /dev/null +++ b/src/wmtk/function/LocalFunction.hpp @@ -0,0 +1,37 @@ +#pragma once +#include +#include +#include "PerSimplexFunction.hpp" +namespace wmtk::function { + +// a function that invokes a function on a local neighborhood of an input simplex +// Typically we will want to compute something like the gradient of a function defined on the triangles/tetrahedra with respect to a vertex and the choice of basis functions means we only need to compute a one-ring neighborhood. +// This class lets us select a per-simplex function (tet or tri) and evaluate that function in a +// (one-ring) neighborhood of an input simplex(vertex). This class will allow for us to evaluate how +// a function changes, which might be useful fo ra line-search, but the real purpose is to define an +// interface for a differentiable variant that enables the use of gradient descent. + +class LocalFunction : public virtual Function +{ +public: + LocalFunction(std::shared_ptr function); + virtual ~LocalFunction(); + +public: + // evaluate the function on the top level simplex of the tuple + double get_value(const Simplex& simplex) const override; + const Mesh& mesh() const final override; + double get_value(const Tuple& simplex) const; + const PerSimplexFunction& per_simplex_function() const; + std::shared_ptr per_simplex_function_ptr() const; + + PrimitiveType get_simplex_type() const; + +protected: + std::vector get_local_neighborhood_tuples(const Simplex& simplex) const; + + +private: + std::shared_ptr m_function; +}; +} // namespace wmtk::function diff --git a/src/wmtk/function/PerSimplexDifferentiableFunction.cpp b/src/wmtk/function/PerSimplexDifferentiableFunction.cpp new file mode 100644 index 0000000000..b9b22903a8 --- /dev/null +++ b/src/wmtk/function/PerSimplexDifferentiableFunction.cpp @@ -0,0 +1,73 @@ +#include "PerSimplexDifferentiableFunction.hpp" + +namespace wmtk::function { +PerSimplexDifferentiableFunction::PerSimplexDifferentiableFunction( + const Mesh& mesh, + PrimitiveType simplex_type, + const attribute::MeshAttributeHandle& variable_attribute_handle) + : PerSimplexFunction(mesh, simplex_type) + , m_coordinate_attribute_handle(variable_attribute_handle) +{} + +PerSimplexDifferentiableFunction::~PerSimplexDifferentiableFunction() = default; + +MeshAttributeHandle +PerSimplexDifferentiableFunction::get_coordinate_attribute_handle() const +{ + return m_coordinate_attribute_handle; +} + +long PerSimplexDifferentiableFunction::embedded_dimension() const +{ + return DifferentiableFunction::embedded_dimension(); +} + +Eigen::VectorXd PerSimplexDifferentiableFunction::get_gradient(const Simplex& s) const +{ + assert(get_simplex_type() == s.primitive_type()); + return get_gradient(s.tuple()); +} +Eigen::MatrixXd PerSimplexDifferentiableFunction::get_hessian(const Simplex& s) const +{ + assert(get_simplex_type() == s.primitive_type()); + return get_hessian(s.tuple()); +} + +Eigen::VectorXd PerSimplexDifferentiableFunction::get_gradient_sum( + const std::vector& simplices) const +{ + Eigen::VectorXd g = Eigen::VectorXd::Zero(embedded_dimension()); + for (const Simplex& cell : simplices) { + g += get_gradient(cell); + } + return g; +} +Eigen::MatrixXd PerSimplexDifferentiableFunction::get_hessian_sum( + const std::vector& simplices) const +{ + Eigen::MatrixXd h = Eigen::MatrixXd::Zero(embedded_dimension(), embedded_dimension()); + for (const Simplex& cell : simplices) { + h += get_hessian(cell); + } + return h; +} + +Eigen::VectorXd PerSimplexDifferentiableFunction::get_gradient_sum( + const std::vector& tuples) const +{ + Eigen::VectorXd g = Eigen::VectorXd::Zero(embedded_dimension()); + for (const Tuple& t : tuples) { + g += get_gradient(t); + } + return g; +} +Eigen::MatrixXd PerSimplexDifferentiableFunction::get_hessian_sum( + const std::vector& tuples) const +{ + Eigen::MatrixXd h = Eigen::MatrixXd::Zero(embedded_dimension(), embedded_dimension()); + for (const Tuple& t : tuples) { + h += get_hessian(t); + } + return h; +} +} // namespace wmtk::function diff --git a/src/wmtk/function/PerSimplexDifferentiableFunction.hpp b/src/wmtk/function/PerSimplexDifferentiableFunction.hpp new file mode 100644 index 0000000000..00e0845627 --- /dev/null +++ b/src/wmtk/function/PerSimplexDifferentiableFunction.hpp @@ -0,0 +1,50 @@ +#pragma once +#include +#include +#include +#include "DifferentiableFunction.hpp" +#include "PerSimplexFunction.hpp" +namespace wmtk { +namespace function { +class PerSimplexDifferentiableFunction : public PerSimplexFunction, DifferentiableFunction +{ +public: + /** + * @brief Construct a new PerSimplexDifferentiableFunction object where the function is defined + * over simplices of simplex_type. And the differentiation is taken wrt the + * attribute_handle.primitive_type() + * + * @param mesh + * @param simplex_type + * @param attribute_handle, the attribute that differentiation is with respect to + */ + PerSimplexDifferentiableFunction( + const Mesh& mesh, + PrimitiveType simplex_type, + const attribute::MeshAttributeHandle& attribute_handle); + virtual ~PerSimplexDifferentiableFunction(); + +public: + virtual Eigen::VectorXd get_gradient(const Tuple& s) const = 0; + virtual Eigen::MatrixXd get_hessian(const Tuple& s) const = 0; + Eigen::VectorXd get_gradient(const Simplex& s) const final override; + Eigen::MatrixXd get_hessian(const Simplex& s) const final override; + + attribute::MeshAttributeHandle get_coordinate_attribute_handle() const final override; + + + long embedded_dimension() const; + + // computes the sum over a set of simplices - assumes each simplex has the same dimension as the + // function's simplex type + Eigen::VectorXd get_gradient_sum(const std::vector& simplices) const; + Eigen::MatrixXd get_hessian_sum(const std::vector& simplices) const; + + Eigen::VectorXd get_gradient_sum(const std::vector& simplices) const; + Eigen::MatrixXd get_hessian_sum(const std::vector& simplices) const; + +private: + const MeshAttributeHandle m_coordinate_attribute_handle; +}; +} // namespace function +} // namespace wmtk diff --git a/src/wmtk/function/PerSimplexFunction.cpp b/src/wmtk/function/PerSimplexFunction.cpp new file mode 100644 index 0000000000..0500ae698c --- /dev/null +++ b/src/wmtk/function/PerSimplexFunction.cpp @@ -0,0 +1,45 @@ +#include "PerSimplexFunction.hpp" + +namespace wmtk::function { + +PerSimplexFunction::PerSimplexFunction(const Mesh& mesh, const PrimitiveType& simplex_type) + : m_mesh(mesh) + , m_simplex_type(simplex_type) +{} + +PerSimplexFunction::~PerSimplexFunction() = default; + +const Mesh& PerSimplexFunction::mesh() const +{ + return m_mesh; +} + +PrimitiveType PerSimplexFunction::get_simplex_type() const +{ + return m_simplex_type; +} + + +double PerSimplexFunction::get_value(const Simplex& s) const +{ + assert(get_simplex_type() == s.primitive_type()); + return get_value(s.tuple()); +} +double PerSimplexFunction::get_value_sum(const std::vector& simplices) const +{ + double v = 0; + for (const Simplex& cell : simplices) { + v += get_value(cell); + } + return v; +} +double PerSimplexFunction::get_value_sum(const std::vector& tuples) const +{ + double v = 0; + const PrimitiveType pt = get_simplex_type(); + for (const Tuple& tuple : tuples) { + v += get_value(tuple); + } + return v; +} +}; // namespace wmtk::function diff --git a/src/wmtk/function/PerSimplexFunction.hpp b/src/wmtk/function/PerSimplexFunction.hpp new file mode 100644 index 0000000000..406d55bfea --- /dev/null +++ b/src/wmtk/function/PerSimplexFunction.hpp @@ -0,0 +1,32 @@ +#pragma once +#include +#include +#include +#include +#include "Function.hpp" +namespace wmtk::function { +class PerSimplexFunction : public virtual Function +{ +public: + PerSimplexFunction(const Mesh& mesh, const PrimitiveType& simplex_type); + virtual ~PerSimplexFunction(); + +public: + using Function::get_value; + const Mesh& mesh() const final override; + virtual double get_value(const Tuple& s) const = 0; + double get_value(const simplex::Simplex& s) const final override; + // the type of simplex that this function operates on + PrimitiveType get_simplex_type() const; + + // helper because in many cases we want to compute the value of multiple simplices at once + double get_value_sum(const std::vector& simplices) const; + // assumes that the underlying simplices are all of the same as get_simplex_type() + double get_value_sum(const std::vector& tuples) const; + +private: + const Mesh& m_mesh; + const PrimitiveType m_simplex_type; +}; + +} // namespace wmtk::function diff --git a/src/wmtk/function/PositionMapAMIPS2D.cpp b/src/wmtk/function/PositionMapAMIPS2D.cpp new file mode 100644 index 0000000000..0b04fd2483 --- /dev/null +++ b/src/wmtk/function/PositionMapAMIPS2D.cpp @@ -0,0 +1,52 @@ +#include "PositionMapAMIPS2D.hpp" +#include +#include +#include + +namespace wmtk::function { +PositionMapAMIPS2D::PositionMapAMIPS2D( + const TriMesh& mesh, + const MeshAttributeHandle& vertex_uv_handle, + const image::Image& image) + : AMIPS(mesh, vertex_uv_handle) + , m_pos_evaluator(image) +{} + +PositionMapAMIPS2D::PositionMapAMIPS2D( + const TriMesh& mesh, + const MeshAttributeHandle& vertex_uv_handle, + const wmtk::image::SamplingAnalyticFunction::FunctionType type, + const double a, + const double b, + const double c) + : AMIPS(mesh, vertex_uv_handle) + , m_pos_evaluator(type, a, b, c) +{} + + +auto PositionMapAMIPS2D::get_value_autodiff(const Tuple& simplex) const -> DScalar +{ + // get_autodiff_value sets the autodiff size if necessary + // get the uv coordinates of the triangle + ConstAccessor pos = mesh().create_const_accessor(get_coordinate_attribute_handle()); + + const Tuple& tuple = simplex; + auto tuple_value = pos.const_vector_attribute(tuple); + + Vector2 uv0; + uv0 = utils::as_DScalar(tuple_value); + + constexpr static PrimitiveType PV = PrimitiveType::Vertex; + constexpr static PrimitiveType PE = PrimitiveType::Edge; + + Eigen::Vector2d uv2 = pos.const_vector_attribute(mesh().switch_tuples(tuple, {PE, PV})); + Eigen::Vector2d uv1 = pos.const_vector_attribute(mesh().switch_tuples(tuple, {PV, PE})); + + Vector3 pos0 = m_pos_evaluator.uv_to_pos(uv0); + Eigen::Vector3d pos1 = m_pos_evaluator.uv_to_pos(uv1); + Eigen::Vector3d pos2 = m_pos_evaluator.uv_to_pos(uv2); + + return utils::amips(pos0, pos1, pos2); +} + +} // namespace wmtk::function diff --git a/src/wmtk/function/PositionMapAMIPS2D.hpp b/src/wmtk/function/PositionMapAMIPS2D.hpp new file mode 100644 index 0000000000..6a6e838643 --- /dev/null +++ b/src/wmtk/function/PositionMapAMIPS2D.hpp @@ -0,0 +1,31 @@ +#pragma once +#include +#include "AMIPS.hpp" + +namespace wmtk::function { +/** + * @brief 2D AMIPS uses uv and position map to get the 3d cooridnates then evaluate + * + */ +class PositionMapAMIPS2D : public AMIPS +{ +public: + PositionMapAMIPS2D( + const TriMesh& mesh, + const MeshAttributeHandle& vertex_uv_handle, + const image::Image& image); + PositionMapAMIPS2D( + const TriMesh& mesh, + const MeshAttributeHandle& vertex_uv_handle, + const wmtk::image::SamplingAnalyticFunction::FunctionType type, + const double a, + const double b, + const double c); + +public: + DScalar get_value_autodiff(const Tuple& simplex) const override; + +protected: + utils::PositionMapEvaluator m_pos_evaluator; +}; +} // namespace wmtk::function diff --git a/src/wmtk/function/ValenceEnergyPerEdge.cpp b/src/wmtk/function/ValenceEnergyPerEdge.cpp new file mode 100644 index 0000000000..1f5553d634 --- /dev/null +++ b/src/wmtk/function/ValenceEnergyPerEdge.cpp @@ -0,0 +1,62 @@ +#include "ValenceEnergyPerEdge.hpp" +#include +#include +#include +namespace wmtk::function { +ValenceEnergyPerEdge::ValenceEnergyPerEdge(const TriMesh& mesh) + : PerSimplexFunction(mesh, PrimitiveType::Edge) +{} + +double ValenceEnergyPerEdge::get_value(const Tuple& tuple) const +{ + // assume tuple is not a boundary edge + const Tuple& current_v = tuple; + const Tuple other_v = tri_mesh().switch_vertex(current_v); + long val0 = static_cast(SimplicialComplex::vertex_one_ring(tri_mesh(), current_v).size()); + long val1 = static_cast(SimplicialComplex::vertex_one_ring(tri_mesh(), other_v).size()); + if (tri_mesh().is_boundary_vertex(current_v)) { + val0 += 2; + } + if (tri_mesh().is_boundary_vertex(other_v)) { + val1 += 2; + } + if (val0 < 4 || val1 < 4) { + return -1; + } + + /* top_v + // / \ + // / \ + // current_v-----other_v + // \ / + // \ / + // bottom_v + */ + const Tuple top_v = tri_mesh().switch_vertex(tri_mesh().switch_edge(current_v)); + const Tuple bottom_v = + tri_mesh().switch_vertex(tri_mesh().switch_edge(tri_mesh().switch_face(current_v))); + long val2 = static_cast(SimplicialComplex::vertex_one_ring(tri_mesh(), top_v).size()); + long val3 = static_cast(SimplicialComplex::vertex_one_ring(tri_mesh(), bottom_v).size()); + + if (tri_mesh().is_boundary_vertex(top_v)) { + val2 += 2; + } + if (tri_mesh().is_boundary_vertex(bottom_v)) { + val3 += 2; + } + // formula from: https://github.com/daniel-zint/hpmeshgen/blob/cdfb9163ed92523fcf41a127c8173097e935c0a3/src/HPMeshGen2/TriRemeshing.cpp#L315 + const long val_energy = std::max(std::abs(val0 - 6), std::abs(val1 - 6)) + + std::max(std::abs(val2 - 6), std::abs(val3 - 6)); + // const long val_after = std::max(std::abs(val0 - 7), std::abs(val1 - 7)) + + // std::max(std::abs(val2 - 5), std::abs(val3 - 5)); + + return static_cast(val_energy); +} + +const TriMesh& ValenceEnergyPerEdge::tri_mesh() const +{ + return static_cast(PerSimplexFunction::mesh()); +} + + +} // namespace wmtk::function diff --git a/src/wmtk/function/ValenceEnergyPerEdge.hpp b/src/wmtk/function/ValenceEnergyPerEdge.hpp new file mode 100644 index 0000000000..d3e464c203 --- /dev/null +++ b/src/wmtk/function/ValenceEnergyPerEdge.hpp @@ -0,0 +1,16 @@ +#pragma once +#include "PerSimplexFunction.hpp" + +namespace wmtk::function { + +class ValenceEnergyPerEdge : public PerSimplexFunction +{ +public: + ValenceEnergyPerEdge(const TriMesh& mesh); + double get_value(const Tuple& simplex) const override; + using PerSimplexFunction::get_value; + +protected: + const TriMesh& tri_mesh() const; +}; +} // namespace wmtk diff --git a/src/wmtk/function/utils/AutoDiffRAII.cpp b/src/wmtk/function/utils/AutoDiffRAII.cpp new file mode 100644 index 0000000000..ad0db2a3ec --- /dev/null +++ b/src/wmtk/function/utils/AutoDiffRAII.cpp @@ -0,0 +1,14 @@ +#include "AutoDiffRAII.hpp" +#include "autodiff.h" + +namespace wmtk::function::utils { +AutoDiffRAII::AutoDiffRAII(size_t size) + : m_previous_variable_count(DiffScalarBase::getVariableCount()) +{ + DiffScalarBase::setVariableCount(size); +} +AutoDiffRAII::~AutoDiffRAII() +{ + DiffScalarBase::setVariableCount(m_previous_variable_count); +} +} // namespace wmtk::function::utils diff --git a/src/wmtk/function/utils/AutoDiffRAII.hpp b/src/wmtk/function/utils/AutoDiffRAII.hpp new file mode 100644 index 0000000000..25792a6db3 --- /dev/null +++ b/src/wmtk/function/utils/AutoDiffRAII.hpp @@ -0,0 +1,17 @@ +#pragma once +#include + + +namespace wmtk::function::utils { + + +class AutoDiffRAII +{ +public: + AutoDiffRAII(size_t size); + ~AutoDiffRAII(); + +private: + size_t m_previous_variable_count; +}; +} // namespace wmtk::function diff --git a/src/wmtk/function/utils/AutoDiffUtils.hpp b/src/wmtk/function/utils/AutoDiffUtils.hpp new file mode 100644 index 0000000000..05ff287432 --- /dev/null +++ b/src/wmtk/function/utils/AutoDiffUtils.hpp @@ -0,0 +1,87 @@ +#pragma once +#include "autodiff.h" + +namespace wmtk::function::utils { + + +template +auto make_DScalar_matrix(int rows = 0, int cols = 0) +{ + if constexpr (Rows != Eigen::Dynamic) { + rows = Rows; + } + if constexpr (Cols != Eigen::Dynamic) { + cols = Cols; + } + assert(rows * cols == DiffScalarBase::getVariableCount()); + + using RetType = Eigen::Matrix; + if constexpr (Rows != Eigen::Dynamic && Cols != Eigen::Dynamic) { + return RetType::NullaryExpr([](int row, int col) { + int index; + if constexpr (RetType::IsRowMajor) { + index = Rows * col + row; + } else { + index = Cols * row + col; + } + return DScalarType(index); + }) + .eval(); + } else { + return RetType::NullaryExpr( + rows, + cols, + [&](int row, int col) { + int index; + if constexpr (RetType::IsRowMajor) { + index = rows * col + row; + } else { + index = cols * row + col; + } + return DScalarType(index); + }) + .eval(); + } +} + +template +auto as_DScalar(const Eigen::MatrixBase& data) +{ + constexpr static int Rows = Derived::RowsAtCompileTime; + constexpr static int Cols = Derived::ColsAtCompileTime; + int rows = data.rows(); + int cols = data.cols(); + + assert(rows * cols == DiffScalarBase::getVariableCount()); + + using RetType = Eigen::Matrix; + if constexpr (Rows != Eigen::Dynamic && Cols != Eigen::Dynamic) { + return RetType::NullaryExpr([&](int row, int col) { + int index; + if constexpr (RetType::IsRowMajor) { + index = Rows * col + row; + } else { + index = Cols * row + col; + } + return DScalarType(index, data(row, col)); + }) + .eval(); + } else { + return RetType::NullaryExpr( + rows, + cols, + [&](int row, int col) { + int index; + if constexpr (RetType::IsRowMajor) { + index = rows * col + row; + } else { + index = cols * row + col; + } + return DScalarType(index, data(row, col)); + }) + .eval(); + } +} + + +} // namespace wmtk::function::utils diff --git a/src/wmtk/function/utils/CMakeLists.txt b/src/wmtk/function/utils/CMakeLists.txt new file mode 100644 index 0000000000..c5dcdc0036 --- /dev/null +++ b/src/wmtk/function/utils/CMakeLists.txt @@ -0,0 +1,17 @@ +set(SRC_FILES + AutoDiffUtils.hpp + autodiff.h + autodiff.cpp + PositionMapEvaluator.hpp + PositionMapEvaluator.cpp + AutoDiffRAII.hpp + AutoDiffRAII.cpp + amips.hpp + amips.cpp + + FunctionEvaluator.hpp + FunctionEvaluator.cpp + DifferentiableFunctionEvaluator.hpp + DifferentiableFunctionEvaluator.cpp +) +target_sources(wildmeshing_toolkit PRIVATE ${SRC_FILES}) diff --git a/src/wmtk/function/utils/DifferentiableFunctionEvaluator.cpp b/src/wmtk/function/utils/DifferentiableFunctionEvaluator.cpp new file mode 100644 index 0000000000..5a9ad41a7a --- /dev/null +++ b/src/wmtk/function/utils/DifferentiableFunctionEvaluator.cpp @@ -0,0 +1,77 @@ +#include "DifferentiableFunctionEvaluator.hpp" +#include +#include +#include + + +namespace wmtk::function::utils { +DifferentiableFunctionEvaluator::DifferentiableFunctionEvaluator( + const function::DifferentiableFunction& function, + Accessor& accessor, + const Simplex& simplex) + : FunctionEvaluator(function, accessor, simplex) + , m_function(function) +{ + // m_cofaces_single_dimension = compute_cofaces_single_dimension(); +} + +auto DifferentiableFunctionEvaluator::function() const -> const function::DifferentiableFunction& +{ + return m_function; + // return static_cast(FunctionEvaluator::function()); +} + + +// auto DifferentiableFunctionEvaluator::get_value() const -> double +//{ +// return get_value(simplex()); +// // return function().get_value_sum( +// // wmtk::simplex::utils::tuple_vector_to_homogeneous_simplex_vector( +// // cofaces_single_dimension(), +// // function_simplex_type())); +//} + +auto DifferentiableFunctionEvaluator::get_gradient() const -> Vector +{ + return function().get_hessian(simplex()); + // return function().get_gradient_sum( + // wmtk::simplex::utils::tuple_vector_to_homogeneous_simplex_vector( + // cofaces_single_dimension(), + // function_simplex_type())); +} + +auto DifferentiableFunctionEvaluator::get_hessian() const -> Matrix +{ + return function().get_hessian(simplex()); + // return function().get_hessian_sum( + // wmtk::simplex::utils::tuple_vector_to_homogeneous_simplex_vector( + // cofaces_single_dimension(), + // function_simplex_type())); +} + +auto DifferentiableFunctionEvaluator::get_gradient(double v) -> Vector +{ + store(v); + return get_gradient(); +} + +auto DifferentiableFunctionEvaluator::get_hessian(double v) -> Matrix +{ + store(v); + return get_hessian(); +} + +// const std::vector& DifferentiableFunctionEvaluator::cofaces_single_dimension() const +//{ +// return m_cofaces_single_dimension; +// } +// std::vector DifferentiableFunctionEvaluator::compute_cofaces_single_dimension() const +//{ +// return simplex::cofaces_single_dimension_tuples(mesh(), simplex(), function_simplex_type()); +// } +// +// std::vector DifferentiableFunctionEvaluator::compute_top_level_cofaces() const +//{ +// return simplex::top_level_cofaces_tuples(mesh(), simplex()); +// } +} // namespace wmtk::function::utils diff --git a/src/wmtk/function/utils/DifferentiableFunctionEvaluator.hpp b/src/wmtk/function/utils/DifferentiableFunctionEvaluator.hpp new file mode 100644 index 0000000000..780ec49f81 --- /dev/null +++ b/src/wmtk/function/utils/DifferentiableFunctionEvaluator.hpp @@ -0,0 +1,62 @@ +#pragma once +#include +#include +#include +#include +#include +#include +#include "FunctionEvaluator.hpp" + +namespace wmtk::function::utils { + + +// Evaluates a function at a particular vertex of a mesh +// NOTE that this modifies attributes in the mesh. +// This should only be called from within a Scope so evaluates can be undone +// +class DifferentiableFunctionEvaluator : public FunctionEvaluator +{ +public: + DifferentiableFunctionEvaluator( + const function::DifferentiableFunction& function, + Accessor& accessor, + const Simplex& simplex); + using Vector = Eigen::VectorXd; + using Matrix = Eigen::MatrixXd; + + Vector get_gradient() const; + Matrix get_hessian() const; + + + template + Vector get_gradient(const Eigen::MatrixBase& v); + template + Matrix get_hessian(const Eigen::MatrixBase& v); + + Vector get_gradient(double v); + Matrix get_hessian(double v); + const function::DifferentiableFunction& function() const; + + const std::vector& cofaces_single_dimension() const; + +private: + // cache the top simplices + const function::DifferentiableFunction& m_function; + // std::vector m_cofaces_single_dimension; + // std::vector compute_cofaces_single_dimension() const; + // std::vector compute_top_level_cofaces() const; +}; + +template +auto DifferentiableFunctionEvaluator::get_gradient(const Eigen::MatrixBase& v) -> Vector +{ + store(v); + return get_gradient(); +} +template +auto DifferentiableFunctionEvaluator::get_hessian(const Eigen::MatrixBase& v) -> Matrix +{ + store(v); + return get_hessian(); +} +} // namespace wmtk::function::utils diff --git a/src/wmtk/function/utils/FunctionEvaluator.cpp b/src/wmtk/function/utils/FunctionEvaluator.cpp new file mode 100644 index 0000000000..6144407f25 --- /dev/null +++ b/src/wmtk/function/utils/FunctionEvaluator.cpp @@ -0,0 +1,30 @@ +#include "FunctionEvaluator.hpp" +#include +#include +namespace wmtk::function::utils { + +FunctionEvaluator::FunctionEvaluator( + const function::Function& function, + Accessor& accessor, + const Simplex& simplex) + : m_function(function) + , m_accessor(accessor) + , m_simplex(simplex) +{} + +void FunctionEvaluator::store(double v) +{ + m_accessor.scalar_attribute(tuple()) = v; +} + + +double FunctionEvaluator::get_value() const +{ + return m_function.get_value(m_simplex); +} +auto FunctionEvaluator::get_value(double v) -> double +{ + store(v); + return get_value(); +} +} // namespace wmtk::function::utils diff --git a/src/wmtk/function/utils/FunctionEvaluator.hpp b/src/wmtk/function/utils/FunctionEvaluator.hpp new file mode 100644 index 0000000000..1cdd88b07f --- /dev/null +++ b/src/wmtk/function/utils/FunctionEvaluator.hpp @@ -0,0 +1,76 @@ +#pragma once +#include +#include +#include +#include +#include +#include +#include +#include + +namespace wmtk::function::utils { + + +// Evaluates a function at a particular vertex of a mesh +// NOTE that this modifies attributes in the mesh. +// This should only be called from within a Scope so evaluates can be undone +// +class FunctionEvaluator +{ +public: + FunctionEvaluator(const Function& function, Accessor& accessor, const Simplex& simplex); + + + auto get_coordinate() { return m_accessor.vector_attribute(m_simplex.tuple()); } + auto get_const_coordinate() const + { + return m_accessor.const_vector_attribute(m_simplex.tuple()); + } + + void store(double v); + template + void store(const Eigen::MatrixBase& v); + + auto get_coordinate() const { return get_const_coordinate(); } + + double get_value() const; + + template + double get_value(const Eigen::MatrixBase& v); + + double get_value(double v); + + + const Tuple& tuple() const { return m_simplex.tuple(); } + const Simplex& simplex() const { return m_simplex; } + Mesh& mesh() { return m_accessor.mesh(); } + const Mesh& mesh() const { return m_accessor.mesh(); } + Accessor& accessor() { return m_accessor; } + + const Function& function() const { return m_function; } + + PrimitiveType my_simplex_type() const + { + PrimitiveType type = m_simplex.primitive_type(); + return type; + } + +private: + const Function& m_function; + Accessor& m_accessor; + const Simplex& m_simplex; +}; + + +template +void FunctionEvaluator::store(const Eigen::MatrixBase& v) +{ + m_accessor.vector_attribute(tuple()) = v; +} +template +double FunctionEvaluator::get_value(const Eigen::MatrixBase& v) +{ + store(v); + return get_value(); +} +} // namespace wmtk::function::utils diff --git a/src/wmtk/function/utils/PositionMapEvaluator.cpp b/src/wmtk/function/utils/PositionMapEvaluator.cpp new file mode 100644 index 0000000000..1ad636ad7f --- /dev/null +++ b/src/wmtk/function/utils/PositionMapEvaluator.cpp @@ -0,0 +1,51 @@ +#include "PositionMapEvaluator.hpp" +#include +#include +#include +#include + + +namespace wmtk::function::utils { + +PositionMapEvaluator::PositionMapEvaluator() = default; +PositionMapEvaluator::~PositionMapEvaluator() = default; +PositionMapEvaluator::PositionMapEvaluator(PositionMapEvaluator&&) = + default; // move assignment operator +PositionMapEvaluator& PositionMapEvaluator::operator=(PositionMapEvaluator&&) = + default; // move assignment operator + +/** + * @brief Construct a new Dofs To Position object using a displacement map (requires a + * sampler) + * + * @param image + */ +PositionMapEvaluator::PositionMapEvaluator(const image::Image& image) +{ + m_sampling = std::make_unique(image); +} + +PositionMapEvaluator::PositionMapEvaluator( + const wmtk::image::SamplingAnalyticFunction::FunctionType type, + const double a, + const double b, + const double c) +{ + m_sampling = std::make_unique(type, a, b, c); +} +/* +template +Vector3 PositionMapEvaluator::uv_to_pos(const Vector2& uv) const +{ + return Vector3(uv.x(), uv.y(), m_sampling->sample(uv.x(), uv.y())); +} + +template <> +Vector3 PositionMapEvaluator::uv_to_pos(const Vector2& uv) const; + + +template <> +auto PositionMapEvaluator::uv_to_pos(const Vector2& uv) + const -> Vector3; + */ +} // namespace wmtk::function::utils diff --git a/src/wmtk/function/utils/PositionMapEvaluator.hpp b/src/wmtk/function/utils/PositionMapEvaluator.hpp new file mode 100644 index 0000000000..25846a6051 --- /dev/null +++ b/src/wmtk/function/utils/PositionMapEvaluator.hpp @@ -0,0 +1,46 @@ +#pragma once +#include +#include +#include + +namespace wmtk::image { +class Image; +class SamplingAnalyticFunction; +} // namespace wmtk::image + +namespace wmtk::function::utils { +class PositionMapEvaluator +{ +protected: + std::unique_ptr m_sampling; + +public: + PositionMapEvaluator(); + ~PositionMapEvaluator(); + PositionMapEvaluator(PositionMapEvaluator&&); // move assignment operator + PositionMapEvaluator& operator=(PositionMapEvaluator&&); // move assignment operator + + /** + * @brief Construct a new Dofs To Position object using a displacement map (requires a + * sampler) + * + * @param image + */ + PositionMapEvaluator(const image::Image& image); + + PositionMapEvaluator( + const wmtk::image::SamplingAnalyticFunction::FunctionType type, + const double a, + const double b, + const double c); + + + // Dont forget to update this if we change autodiff tyeps (add declarations in the cpp) + template + Vector3 uv_to_pos(const Vector2& uv) const +{ + return Vector3(uv.x(), uv.y(), m_sampling->sample(uv.x(), uv.y())); +} +}; + +} // namespace wmtk::function::utils diff --git a/src/wmtk/function/utils/amips.cpp b/src/wmtk/function/utils/amips.cpp new file mode 100644 index 0000000000..f3c3501eb3 --- /dev/null +++ b/src/wmtk/function/utils/amips.cpp @@ -0,0 +1,44 @@ +#include "amips.hpp" +#include +#include +#include +namespace wmtk::function::utils { +namespace detail { +namespace { +auto make_amips_target_triangle() +{ + const static std::array, 3> m_target_triangle{{ + // comments to keep formatting + std::array{{0., 0.}}, // + std::array{{1., 0.}}, // + std::array{{0.5, sqrt(3) / 2.}}, + // + }}; + auto map = Eigen::Matrix::ConstMapType(m_target_triangle[0].data()); + + +#if !defined(NDEBUG) + auto x = map.col(0); + auto y = map.col(1); + auto z = map.col(2); + assert(wmtk::utils::triangle_signed_2d_area(x,y,z) > 0); +#endif + return map; +} + +} // namespace +const Eigen::Matrix amips_target_triangle = make_amips_target_triangle(); + +namespace { +Eigen::Matrix2d make_amips_reference_to_barycentric() +{ + const auto& A = amips_target_triangle; + Eigen::Matrix2d Ds = (A.rightCols<2>().colwise() - A.col(0)); + + return Ds.inverse(); +} + +} // namespace +const Eigen::Matrix2d amips_reference_to_barycentric = make_amips_reference_to_barycentric(); +} // namespace detail +} // namespace wmtk::function::utils diff --git a/src/wmtk/function/utils/amips.hpp b/src/wmtk/function/utils/amips.hpp new file mode 100644 index 0000000000..70a0617d20 --- /dev/null +++ b/src/wmtk/function/utils/amips.hpp @@ -0,0 +1,109 @@ +#pragma once +#include +#include + +namespace wmtk::function::utils { + +namespace detail { +// returns v0,v1,v2 of the target triangle as row vectors +extern const Eigen::Matrix amips_target_triangle; +// maps from the embedding of the reference triangle to the barycentric coordinates +extern const Eigen::Matrix2d amips_reference_to_barycentric; + +} // namespace detail + + +// Given an basis vectors following a "v1-v0,v2-v0" convention for a triangle (v0,v1,v2) +// return the AMIPS energy. +// Input are assumed to be column vectors, opposite of the standard IGL formation +template +auto amips(const Eigen::MatrixBase& B) +{ + using Scalar = typename Derived::Scalar; + constexpr static int Rows = Derived::RowsAtCompileTime; + constexpr static int Cols = Derived::ColsAtCompileTime; + + + // check that these are vectors + static_assert(Cols == 2); + static_assert(Rows == 2); + + + // MTAO: Why can't this work with more than 2 rows? + // define of transform matrix F = Dm@Ds.inv + Eigen::Matrix J; + J = B * detail::amips_reference_to_barycentric.template cast(); + + auto Jdet = J.determinant(); + if (abs(Jdet) < std::numeric_limits::denorm_min()) { + return static_cast(std::numeric_limits::infinity()); + } + return (J * J.transpose()).trace() / Jdet; +} + + +// +template +auto amips( + const Eigen::MatrixBase& v0, + const Eigen::MatrixBase& v1, + const Eigen::MatrixBase& v2) +{ + using Scalar = typename V0Type::Scalar; + constexpr static int Rows = V0Type::RowsAtCompileTime; + constexpr static int Cols0 = V0Type::ColsAtCompileTime; + constexpr static int Cols1 = V1Type::ColsAtCompileTime; + constexpr static int Cols2 = V1Type::ColsAtCompileTime; + + + // check that these are vectors + static_assert(Cols0 == 1); + static_assert(Cols1 == 1); + static_assert(Cols2 == 1); + + // just check that the inputs had the right dimensions + constexpr static int Rows1 = V1Type::RowsAtCompileTime; + constexpr static int Rows2 = V1Type::RowsAtCompileTime; + static_assert(Rows == Rows1); + static_assert(Rows == Rows2); + + Eigen::Matrix Dm; + + + static_assert(Rows == 2 || Rows == 3); + + if constexpr (Rows == 2) { + Dm.col(0) = (v1.template cast() - v0); + Dm.col(1) = (v2.template cast() - v0); + } else if constexpr (Rows == 3) { + // in 3d we compute a basis + // local vectors + Eigen::Matrix V; + V.col(0) = v1.template cast() - v0; + V.col(1) = v2.template cast() - v0; + + // compute a basis plane + Eigen::Matrix B = V; + + auto e0 = B.col(0); + auto e1 = B.col(1); + + // TODO: shouldnt we make sure the normms are over some eps instead of 0? + auto e0norm = e0.norm(); + assert(e0norm > 0); // check norm is not 0 + e0 = e0 / e0norm; + + Vector3 n = e0.cross(e1); + e1 = n.cross(e0); + auto e1norm = e1.norm(); + assert(e1norm > 0); // check norm is not 0 + e1 = e1 / e1norm; + + + Dm = (B.transpose() * V).eval(); + } + + + return amips(Dm); +} +} // namespace wmtk::function::utils diff --git a/src/wmtk/function/utils/autodiff.cpp b/src/wmtk/function/utils/autodiff.cpp new file mode 100644 index 0000000000..f9e233ee0a --- /dev/null +++ b/src/wmtk/function/utils/autodiff.cpp @@ -0,0 +1,5 @@ +#include "autodiff.h" + + + +DECLARE_DIFFSCALAR_BASE(); diff --git a/src/wmtk/function/utils/autodiff.h b/src/wmtk/function/utils/autodiff.h new file mode 100644 index 0000000000..a3645ef60a --- /dev/null +++ b/src/wmtk/function/utils/autodiff.h @@ -0,0 +1,1099 @@ +// clang-format off + +/** + Automatic differentiation data type for C++, depends on the Eigen + linear algebra library. + + Copyright (c) 2012 by Wenzel Jakob. Based on code by Jon Kaldor + and Eitan Grinspun. + + This library is free software; you can redistribute it and/or + modify it under the terms of the GNU Lesser General Public + License as published by the Free Software Foundation; either + version 2.1 of the License, or (at your option) any later version. + + This library is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + Lesser General Public License for more details. + + You should have received a copy of the GNU Lesser General Public + License along with this library; if not, write to the Free Software + Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA +*/ + +#ifndef __AUTODIFF_H +#define __AUTODIFF_H + +#ifndef EIGEN_DONT_PARALLELIZE +#define EIGEN_DONT_PARALLELIZE +#endif + +#include +#include +#include +#include + +/** + * \brief Base class of all automatic differentiation types + * + * This class records the number of independent variables with respect + * to which derivatives are computed. + */ +struct DiffScalarBase +{ + // ====================================================================== + /// @{ \name Configuration + // ====================================================================== + + /** + * \brief Set the independent variable count used by the automatic + * differentiation layer + * + * This function must be called before doing any computations with + * \ref DScalar1 or \ref DScalar2. The value will be recorded in + * thread-local storage. + */ + static inline void setVariableCount(size_t value) + { + m_variableCount = value; + } + + /// Get the variable count used by the automatic differentiation layer + static inline size_t getVariableCount() + { + return m_variableCount; + } + + /// @} + // ====================================================================== + + // #ifdef WIN32 + // static __declspec(thread) size_t m_variableCount; + // #else + // static __thread size_t m_variableCount; + // #endif + static thread_local size_t m_variableCount; +}; + +// #ifdef WIN32 +// #define DECLARE_DIFFSCALAR_BASE() +// __declspec(thread) size_t DiffScalarBase::m_variableCount = 0 +// #else +// #define DECLARE_DIFFSCALAR_BASE() +// __thread size_t DiffScalarBase::m_variableCount = 0 +// #endif + +#define DECLARE_DIFFSCALAR_BASE() \ + thread_local size_t DiffScalarBase::m_variableCount = 0 + +/** + * \brief Automatic differentiation scalar with first-order derivatives + * + * This class provides an instrumented "scalar" value, which may be dependent on + * a number of independent variables. The implementation keeps tracks of + * first -order drivatives with respect to these variables using a set + * of overloaded operations and implementations of special functions (sin, + * tan, exp, ..). + * + * This is extremely useful for numerical zero-finding, particularly when + * analytic derivatives from programs like Maple or Mathematica suffer from + * excessively complicated expressions. + * + * The class relies on templates, which makes it possible to fix the + * number of independent variables at compile-time so that instances can + * be allocated on the stack. Otherwise, they will be placed on the heap. + * + * This is an extended C++ port of Jon Kaldor's implementation, which is + * based on a C version by Eitan Grinspun at Caltech) + * + * \sa DScalar2 + * \author Wenzel Jakob + */ +template > +struct DScalar1 : public DiffScalarBase +{ +public: + typedef _Scalar Scalar; + typedef _Gradient Gradient; + typedef Eigen::Matrix DVector2; + typedef Eigen::Matrix DVector3; + + // ====================================================================== + /// @{ \name Constructors and accessors + // ====================================================================== + + /// Create a new constant automatic differentiation scalar + explicit DScalar1(Scalar value_ = (Scalar)0) : value(value_) + { + size_t variableCount = getVariableCount(); + grad.resize(variableCount); + grad.setZero(); + } + + /// Construct a new scalar with the specified value and one first derivative set to 1 + DScalar1(size_t index, const Scalar &value_) + : value(value_) + { + size_t variableCount = getVariableCount(); + grad.resize(variableCount); + grad.setZero(); + grad(index) = 1; + } + + /// Construct a scalar associated with the given gradient + DScalar1(Scalar value_, const Gradient &grad_) + : value(value_), grad(grad_) {} + + /// Copy constructor + DScalar1(const DScalar1 &s) + : value(s.value), grad(s.grad) {} + + inline const Scalar &getValue() const { return value; } + inline const Gradient &getGradient() const { return grad; } + + // ====================================================================== + /// @{ \name Addition + // ====================================================================== + friend DScalar1 operator+(const DScalar1 &lhs, const DScalar1 &rhs) + { + return DScalar1(lhs.value + rhs.value, lhs.grad + rhs.grad); + } + + friend DScalar1 operator+(const DScalar1 &lhs, const Scalar &rhs) + { + return DScalar1(lhs.value + rhs, lhs.grad); + } + + friend DScalar1 operator+(const Scalar &lhs, const DScalar1 &rhs) + { + return DScalar1(rhs.value + lhs, rhs.grad); + } + + inline DScalar1 &operator+=(const DScalar1 &s) + { + value += s.value; + grad += s.grad; + return *this; + } + + inline DScalar1 &operator+=(const Scalar &v) + { + value += v; + return *this; + } + + /// @} + // ====================================================================== + + // ====================================================================== + /// @{ \name Subtraction + // ====================================================================== + + friend DScalar1 operator-(const DScalar1 &lhs, const DScalar1 &rhs) + { + return DScalar1(lhs.value - rhs.value, lhs.grad - rhs.grad); + } + + friend DScalar1 operator-(const DScalar1 &lhs, const Scalar &rhs) + { + return DScalar1(lhs.value - rhs, lhs.grad); + } + + friend DScalar1 operator-(const Scalar &lhs, const DScalar1 &rhs) + { + return DScalar1(lhs - rhs.value, -rhs.grad); + } + + friend DScalar1 operator-(const DScalar1 &s) + { + return DScalar1(-s.value, -s.grad); + } + + inline DScalar1 &operator-=(const DScalar1 &s) + { + value -= s.value; + grad -= s.grad; + return *this; + } + + inline DScalar1 &operator-=(const Scalar &v) + { + value -= v; + return *this; + } + /// @} + // ====================================================================== + + // ====================================================================== + /// @{ \name Division + // ====================================================================== + friend DScalar1 operator/(const DScalar1 &lhs, const Scalar &rhs) + { + if (rhs == 0) + throw std::runtime_error("DScalar1: Division by zero!"); + Scalar inv = 1.0f / rhs; + return DScalar1(lhs.value * inv, lhs.grad * inv); + } + + friend DScalar1 operator/(const Scalar &lhs, const DScalar1 &rhs) + { + return lhs * inverse(rhs); + } + + friend DScalar1 operator/(const DScalar1 &lhs, const DScalar1 &rhs) + { + return lhs * inverse(rhs); + } + + friend DScalar1 inverse(const DScalar1 &s) + { + Scalar valueSqr = s.value * s.value, + invValueSqr = (Scalar)1 / valueSqr; + + // vn = 1/v, Dvn = -1/(v^2) Dv + return DScalar1((Scalar)1 / s.value, s.grad * -invValueSqr); + } + + inline DScalar1 &operator/=(const Scalar &v) + { + value /= v; + grad /= v; + return *this; + } + + /// @} + // ====================================================================== + + // ====================================================================== + /// @{ \name Multiplication + // ====================================================================== + inline friend DScalar1 operator*(const DScalar1 &lhs, const Scalar &rhs) + { + return DScalar1(lhs.value * rhs, lhs.grad * rhs); + } + + inline friend DScalar1 operator*(const Scalar &lhs, const DScalar1 &rhs) + { + return DScalar1(rhs.value * lhs, rhs.grad * lhs); + } + + inline friend DScalar1 operator*(const DScalar1 &lhs, const DScalar1 &rhs) + { + // Product rule + return DScalar1(lhs.value * rhs.value, + rhs.grad * lhs.value + lhs.grad * rhs.value); + } + + inline DScalar1 &operator*=(const Scalar &v) + { + value *= v; + grad *= v; + return *this; + } + + /// @} + // ====================================================================== + + // ====================================================================== + /// @{ \name Miscellaneous functions + // ====================================================================== + + friend DScalar1 abs(const DScalar1 &s) + { + return s.value < 0? -s: s; + } + + friend DScalar1 sqrt(const DScalar1 &s) + { + Scalar sqrtVal = std::sqrt(s.value), + temp = (Scalar)1 / ((Scalar)2 * sqrtVal); + + // vn = sqrt(v) + // Dvn = 1/(2 sqrt(v)) Dv + return DScalar1(sqrtVal, s.grad * temp); + } + + friend DScalar1 pow(const DScalar1 &s, const Scalar &a) + { + Scalar powVal = std::pow(s.value, a), + temp = a * std::pow(s.value, a - 1); + // vn = v ^ a, Dvn = a*v^(a-1) * Dv + return DScalar1(powVal, s.grad * temp); + } + + friend DScalar1 exp(const DScalar1 &s) + { + Scalar expVal = std::exp(s.value); + + // vn = exp(v), Dvn = exp(v) * Dv + return DScalar1(expVal, s.grad * expVal); + } + + friend DScalar1 log(const DScalar1 &s) + { + Scalar logVal = std::log(s.value); + + // vn = log(v), Dvn = Dv / v + return DScalar1(logVal, s.grad / s.value); + } + + friend DScalar1 sin(const DScalar1 &s) + { + // vn = sin(v), Dvn = cos(v) * Dv + return DScalar1(std::sin(s.value), s.grad * std::cos(s.value)); + } + + friend DScalar1 cos(const DScalar1 &s) + { + // vn = cos(v), Dvn = -sin(v) * Dv + return DScalar1(std::cos(s.value), s.grad * -std::sin(s.value)); + } + + friend DScalar1 acos(const DScalar1 &s) + { + if (std::abs(s.value) >= 1) + throw std::runtime_error("acos: Expected a value in (-1, 1)"); + + Scalar temp = -std::sqrt((Scalar)1 - s.value * s.value); + + // vn = acos(v), Dvn = -1/sqrt(1-v^2) * Dv + return DScalar1(std::acos(s.value), + s.grad * ((Scalar)1 / temp)); + } + + friend DScalar1 asin(const DScalar1 &s) + { + if (std::abs(s.value) >= 1) + throw std::runtime_error("asin: Expected a value in (-1, 1)"); + + Scalar temp = std::sqrt((Scalar)1 - s.value * s.value); + + // vn = asin(v), Dvn = 1/sqrt(1-v^2) * Dv + return DScalar1(std::asin(s.value), + s.grad * ((Scalar)1 / temp)); + } + + friend DScalar1 atan2(const DScalar1 &y, const DScalar1 &x) + { + Scalar denom = x.value * x.value + y.value * y.value; + + // vn = atan2(y, x), Dvn = (x*Dy - y*Dx) / (x^2 + y^2) + return DScalar1(std::atan2(y.value, x.value), + y.grad * (x.value / denom) - x.grad * (y.value / denom)); + } + + /// @} + // ====================================================================== + + // ====================================================================== + /// @{ \name Comparison and assignment + // ====================================================================== + + inline void operator=(const DScalar1 &s) + { + value = s.value; + grad = s.grad; + } + inline void operator=(const Scalar &v) + { + value = v; + grad.setZero(); + } + inline bool operator<(const DScalar1 &s) const { return value < s.value; } + inline bool operator<=(const DScalar1 &s) const { return value <= s.value; } + inline bool operator>(const DScalar1 &s) const { return value > s.value; } + inline bool operator>=(const DScalar1 &s) const { return value >= s.value; } + inline bool operator<(const Scalar &s) const { return value < s; } + inline bool operator<=(const Scalar &s) const { return value <= s; } + inline bool operator>(const Scalar &s) const { return value > s; } + inline bool operator>=(const Scalar &s) const { return value >= s; } + inline bool operator==(const Scalar &s) const { return value == s; } + inline bool operator!=(const Scalar &s) const { return value != s; } + + /// @} + // ====================================================================== + + // ====================================================================== + /// @{ \name Vector helper functions + // ====================================================================== + + /// Initialize a constant two-dimensional vector + static inline DVector2 vector(const Eigen::Matrix &v) + { + return DVector2(DScalar1(v.x()), DScalar1(v.y())); + } + + /// Create a constant three-dimensional vector + static inline DVector3 vector(const Eigen::Matrix &v) + { + return DVector3(DScalar1(v.x()), DScalar1(v.y()), DScalar1(v.z())); + } + +#if defined(__MITSUBA_MITSUBA_H_) /* Mitsuba-specific */ + /// Initialize a constant two-dimensional vector + static inline DVector2 vector(const mitsuba::TVector2 &v) + { + return DVector2(DScalar1(v.x), DScalar1(v.y)); + } + + /// Initialize a constant two-dimensional vector + static inline DVector2 vector(const mitsuba::TPoint2 &p) + { + return DVector2(DScalar1(p.x), DScalar1(p.y)); + } + + /// Create a constant three-dimensional vector + static inline DVector3 vector(const mitsuba::TVector3 &v) + { + return DVector3(DScalar1(v.x), DScalar1(v.y), DScalar1(v.z)); + } + + /// Create a constant three-dimensional vector + static inline DVector3 vector(const mitsuba::TPoint3 &p) + { + return DVector3(DScalar1(p.x), DScalar1(p.y), DScalar1(p.z)); + } +#endif + + /// @} + // ====================================================================== +protected: + Scalar value; + Gradient grad; +}; + +template +std::ostream &operator<<(std::ostream &out, const DScalar1 &s) +{ + out << "[" << s.getValue() + << ", grad=" << s.getGradient().format(Eigen::IOFormat(4, 1, ", ", "; ", "", "", "[", "]")) + << "]"; + return out; +} + +/** + * \brief Automatic differentiation scalar with first- and second-order derivatives + * + * This class provides an instrumented "scalar" value, which may be dependent on + * a number of independent variables. The implementation keeps tracks of first + * and second-order drivatives with respect to these variables using a set + * of overloaded operations and implementations of special functions (sin, + * tan, exp, ..). + * + * This is extremely useful for numerical optimization, particularly when + * analytic derivatives from programs like Maple or Mathematica suffer from + * excessively complicated expressions. + * + * The class relies on templates, which makes it possible to fix the + * number of independent variables at compile-time so that instances can + * be allocated on the stack. Otherwise, they will be placed on the heap. + * + * This is an extended C++ port of Jon Kaldor's implementation, which is + * based on a C version by Eitan Grinspun at Caltech) + * + * \sa DScalar1 + * \author Wenzel Jakob + */ +template , + typename _Hessian = Eigen::Matrix<_Scalar, Eigen::Dynamic, Eigen::Dynamic>> +struct DScalar2 : public DiffScalarBase +{ +public: + typedef _Scalar Scalar; + typedef _Gradient Gradient; + typedef _Hessian Hessian; + typedef Eigen::Matrix DVector2; + typedef Eigen::Matrix DVector3; + + // ====================================================================== + /// @{ \name Constructors and accessors + // ====================================================================== + + /// Create a new constant automatic differentiation scalar + explicit DScalar2(Scalar value_ = (Scalar)0) : value(value_) + { + size_t variableCount = getVariableCount(); + + grad.resize(variableCount); + grad.setZero(); + hess.resize(variableCount, variableCount); + hess.setZero(); + } + + /// Construct a new scalar with the specified value and one first derivative set to 1 + DScalar2(size_t index, const Scalar &value_) + : value(value_) + { + size_t variableCount = getVariableCount(); + + grad.resize(variableCount); + grad.setZero(); + grad(index) = 1; + hess.resize(variableCount, variableCount); + hess.setZero(); + } + + /// Construct a scalar associated with the given gradient and Hessian + DScalar2(Scalar value_, const Gradient &grad_, const Hessian &hess_) + : value(value_), grad(grad_), hess(hess_) {} + + /// Copy constructor + DScalar2(const DScalar2 &s) + : value(s.value), grad(s.grad), hess(s.hess) {} + + inline const Scalar &getValue() const { return value; } + inline const Gradient &getGradient() const { return grad; } + inline const Hessian &getHessian() const { return hess; } + + // ====================================================================== + /// @{ \name Addition + // ====================================================================== + friend DScalar2 operator+(const DScalar2 &lhs, const DScalar2 &rhs) + { + return DScalar2(lhs.value + rhs.value, + lhs.grad + rhs.grad, lhs.hess + rhs.hess); + } + + friend DScalar2 operator+(const DScalar2 &lhs, const Scalar &rhs) + { + return DScalar2(lhs.value + rhs, lhs.grad, lhs.hess); + } + + friend DScalar2 operator+(const Scalar &lhs, const DScalar2 &rhs) + { + return DScalar2(rhs.value + lhs, rhs.grad, rhs.hess); + } + + inline DScalar2 &operator+=(const DScalar2 &s) + { + value += s.value; + grad += s.grad; + hess += s.hess; + return *this; + } + + inline DScalar2 &operator+=(const Scalar &v) + { + value += v; + return *this; + } + + /// @} + // ====================================================================== + + // ====================================================================== + /// @{ \name Subtraction + // ====================================================================== + + friend DScalar2 operator-(const DScalar2 &lhs, const DScalar2 &rhs) + { + return DScalar2(lhs.value - rhs.value, lhs.grad - rhs.grad, lhs.hess - rhs.hess); + } + + friend DScalar2 operator-(const DScalar2 &lhs, const Scalar &rhs) + { + return DScalar2(lhs.value - rhs, lhs.grad, lhs.hess); + } + + friend DScalar2 operator-(const Scalar &lhs, const DScalar2 &rhs) + { + return DScalar2(lhs - rhs.value, -rhs.grad, -rhs.hess); + } + + friend DScalar2 operator-(const DScalar2 &s) + { + return DScalar2(-s.value, -s.grad, -s.hess); + } + + inline DScalar2 &operator-=(const DScalar2 &s) + { + value -= s.value; + grad -= s.grad; + hess -= s.hess; + return *this; + } + + inline DScalar2 &operator-=(const Scalar &v) + { + value -= v; + return *this; + } + /// @} + // ====================================================================== + + // ====================================================================== + /// @{ \name Division + // ====================================================================== + friend DScalar2 operator/(const DScalar2 &lhs, const Scalar &rhs) + { + if (rhs == 0) + throw std::runtime_error("DScalar2: Division by zero!"); + Scalar inv = 1.0f / rhs; + return DScalar2(lhs.value * inv, lhs.grad * inv, lhs.hess * inv); + } + + friend DScalar2 operator/(const Scalar &lhs, const DScalar2 &rhs) + { + return lhs * inverse(rhs); + } + + friend DScalar2 operator/(const DScalar2 &lhs, const DScalar2 &rhs) + { + return lhs * inverse(rhs); + } + + friend DScalar2 inverse(const DScalar2 &s) + { + Scalar valueSqr = s.value * s.value, + valueCub = valueSqr * s.value, + invValueSqr = (Scalar)1 / valueSqr; + + // vn = 1/v + DScalar2 result((Scalar)1 / s.value); + + // Dvn = -1/(v^2) Dv + result.grad = s.grad * -invValueSqr; + + // D^2vn = -1/(v^2) D^2v + 2/(v^3) Dv Dv^T + result.hess = s.hess * -invValueSqr; + result.hess += s.grad * s.grad.transpose() + * ((Scalar)2 / valueCub); + + return result; + } + + inline DScalar2 &operator/=(const Scalar &v) + { + value /= v; + grad /= v; + hess /= v; + return *this; + } + /// @} + // ====================================================================== + + // ====================================================================== + /// @{ \name Multiplication + // ====================================================================== + friend DScalar2 operator*(const DScalar2 &lhs, const Scalar &rhs) + { + return DScalar2(lhs.value * rhs, lhs.grad * rhs, lhs.hess * rhs); + } + + friend DScalar2 operator*(const Scalar &lhs, const DScalar2 &rhs) + { + return DScalar2(rhs.value * lhs, rhs.grad * lhs, rhs.hess * lhs); + } + + friend DScalar2 operator*(const DScalar2 &lhs, const DScalar2 &rhs) + { + DScalar2 result(lhs.value * rhs.value); + + /// Product rule + result.grad = rhs.grad * lhs.value + lhs.grad * rhs.value; + + // (i,j) = g*F_xixj + g*G_xixj + F_xi*G_xj + F_xj*G_xi + result.hess = rhs.hess * lhs.value; + result.hess += lhs.hess * rhs.value; + result.hess += lhs.grad * rhs.grad.transpose(); + result.hess += rhs.grad * lhs.grad.transpose(); + + return result; + } + + inline DScalar2 &operator*=(const Scalar &v) + { + value *= v; + grad *= v; + hess *= v; + return *this; + } + + /// @} + // ====================================================================== + + // ====================================================================== + /// @{ \name Miscellaneous functions + // ====================================================================== + + friend DScalar2 abs(const DScalar2 &s) + { + return s.value < 0? -s: s; + } + + friend DScalar2 sqrt(const DScalar2 &s) + { + Scalar sqrtVal = std::sqrt(s.value), + temp = (Scalar)1 / ((Scalar)2 * sqrtVal); + + // vn = sqrt(v) + DScalar2 result(sqrtVal); + + // Dvn = 1/(2 sqrt(v)) Dv + result.grad = s.grad * temp; + + // D^2vn = 1/(2 sqrt(v)) D^2v - 1/(4 v*sqrt(v)) Dv Dv^T + result.hess = s.hess * temp; + result.hess += s.grad * s.grad.transpose() + * (-(Scalar)1 / ((Scalar)4 * s.value * sqrtVal)); + + return result; + } + + friend DScalar2 pow(const DScalar2 &s, const Scalar &a) + { + Scalar powVal = std::pow(s.value, a), + temp = a * std::pow(s.value, a - 1); + // vn = v ^ a + DScalar2 result(powVal); + + // Dvn = a*v^(a-1) * Dv + result.grad = s.grad * temp; + + // D^2vn = a*v^(a-1) D^2v - 1/(4 v*sqrt(v)) Dv Dv^T + result.hess = s.hess * temp; + result.hess += s.grad * s.grad.transpose() + * (a * (a - 1) * std::pow(s.value, a - 2)); + + return result; + } + + friend DScalar2 exp(const DScalar2 &s) + { + Scalar expVal = std::exp(s.value); + + // vn = exp(v) + DScalar2 result(expVal); + + // Dvn = exp(v) * Dv + result.grad = s.grad * expVal; + + // D^2vn = exp(v) * Dv*Dv^T + exp(v) * D^2v + result.hess = (s.grad * s.grad.transpose() + + s.hess) + * expVal; + + return result; + } + + friend DScalar2 log(const DScalar2 &s) + { + Scalar logVal = std::log(s.value); + + // vn = log(v) + DScalar2 result(logVal); + + // Dvn = Dv / v + result.grad = s.grad / s.value; + + // D^2vn = (v*D^2v - Dv*Dv^T)/(v^2) + result.hess = s.hess / s.value - (s.grad * s.grad.transpose() / (s.value * s.value)); + + return result; + } + + friend DScalar2 sin(const DScalar2 &s) + { + Scalar sinVal = std::sin(s.value), + cosVal = std::cos(s.value); + + // vn = sin(v) + DScalar2 result(sinVal); + + // Dvn = cos(v) * Dv + result.grad = s.grad * cosVal; + + // D^2vn = -sin(v) * Dv*Dv^T + cos(v) * Dv^2 + result.hess = s.hess * cosVal; + result.hess += s.grad * s.grad.transpose() * -sinVal; + + return result; + } + + friend DScalar2 cos(const DScalar2 &s) + { + Scalar sinVal = std::sin(s.value), + cosVal = std::cos(s.value); + // vn = cos(v) + DScalar2 result(cosVal); + + // Dvn = -sin(v) * Dv + result.grad = s.grad * -sinVal; + + // D^2vn = -cos(v) * Dv*Dv^T - sin(v) * Dv^2 + result.hess = s.hess * -sinVal; + result.hess += s.grad * s.grad.transpose() * -cosVal; + + return result; + } + + friend DScalar2 acos(const DScalar2 &s) + { + if (std::abs(s.value) >= 1) + throw std::runtime_error("acos: Expected a value in (-1, 1)"); + + Scalar temp = -std::sqrt((Scalar)1 - s.value * s.value); + + // vn = acos(v) + DScalar2 result(std::acos(s.value)); + + // Dvn = -1/sqrt(1-v^2) * Dv + result.grad = s.grad * ((Scalar)1 / temp); + + // D^2vn = -1/sqrt(1-v^2) * D^2v - v/[(1-v^2)^(3/2)] * Dv*Dv^T + result.hess = s.hess * ((Scalar)1 / temp); + result.hess += s.grad * s.grad.transpose() + * s.value / (temp * temp * temp); + + return result; + } + + friend DScalar2 asin(const DScalar2 &s) + { + if (std::abs(s.value) >= 1) + throw std::runtime_error("asin: Expected a value in (-1, 1)"); + + Scalar temp = std::sqrt((Scalar)1 - s.value * s.value); + + // vn = asin(v) + DScalar2 result(std::asin(s.value)); + + // Dvn = 1/sqrt(1-v^2) * Dv + result.grad = s.grad * ((Scalar)1 / temp); + + // D^2vn = 1/sqrt(1-v*v) * D^2v + v/[(1-v^2)^(3/2)] * Dv*Dv^T + result.hess = s.hess * ((Scalar)1 / temp); + result.hess += s.grad * s.grad.transpose() + * s.value / (temp * temp * temp); + + return result; + } + + friend DScalar2 atan2(const DScalar2 &y, const DScalar2 &x) + { + // vn = atan2(y, x) + DScalar2 result(std::atan2(y.value, x.value)); + + // Dvn = (x*Dy - y*Dx) / (x^2 + y^2) + Scalar denom = x.value * x.value + y.value * y.value, + denomSqr = denom * denom; + result.grad = y.grad * (x.value / denom) + - x.grad * (y.value / denom); + + // D^2vn = (Dy*Dx^T + xD^2y - Dx*Dy^T - yD^2x) / (x^2+y^2) + // - [(x*Dy - y*Dx) * (2*x*Dx + 2*y*Dy)^T] / (x^2+y^2)^2 + result.hess = (y.hess * x.value + + y.grad * x.grad.transpose() + - x.hess * y.value + - x.grad * y.grad.transpose()) + / denom; + + result.hess -= + (y.grad * (x.value / denomSqr) - x.grad * (y.value / denomSqr)) * (x.grad * ((Scalar)2 * x.value) + y.grad * ((Scalar)2 * y.value)).transpose(); + + return result; + } + + /// @} + // ====================================================================== + + // ====================================================================== + /// @{ \name Comparison and assignment + // ====================================================================== + + inline void operator=(const DScalar2 &s) + { + value = s.value; + grad = s.grad; + hess = s.hess; + } + inline void operator=(const Scalar &v) + { + value = v; + grad.setZero(); + hess.setZero(); + } + inline bool operator<(const DScalar2 &s) const { return value < s.value; } + inline bool operator<=(const DScalar2 &s) const { return value <= s.value; } + inline bool operator>(const DScalar2 &s) const { return value > s.value; } + inline bool operator>=(const DScalar2 &s) const { return value >= s.value; } + inline bool operator!=(const DScalar2 &s) const { return value != s.value; } + + inline bool operator<(const Scalar &s) const { return value < s; } + inline bool operator<=(const Scalar &s) const { return value <= s; } + inline bool operator>(const Scalar &s) const { return value > s; } + inline bool operator>=(const Scalar &s) const { return value >= s; } + inline bool operator==(const Scalar &s) const { return value == s; } + inline bool operator!=(const Scalar &s) const { return value != s; } + + /// @} + // ====================================================================== + + // ====================================================================== + /// @{ \name Vector helper functions + // ====================================================================== + +#if defined(__MITSUBA_MITSUBA_H_) /* Mitsuba-specific */ + /// Initialize a constant two-dimensional vector + static inline DVector2 vector(const mitsuba::TVector2 &v) + { + return DVector2(DScalar2(v.x), DScalar2(v.y)); + } + + /// Initialize a constant two-dimensional vector + static inline DVector2 vector(const mitsuba::TPoint2 &p) + { + return DVector2(DScalar2(p.x), DScalar2(p.y)); + } + + /// Create a constant three-dimensional vector + static inline DVector3 vector(const mitsuba::TVector3 &v) + { + return DVector3(DScalar2(v.x), DScalar2(v.y), DScalar2(v.z)); + } + + /// Create a constant three-dimensional vector + static inline DVector3 vector(const mitsuba::TPoint3 &p) + { + return DVector3(DScalar2(p.x), DScalar2(p.y), DScalar2(p.z)); + } +#endif + + /// Initialize a constant two-dimensional vector + static inline DVector2 vector(const Eigen::Matrix &v) + { + return DVector2(DScalar2(v.x()), DScalar2(v.y())); + } + + /// Create a constant three-dimensional vector + static inline DVector3 vector(const Eigen::Matrix &v) + { + return DVector3(DScalar2(v.x()), DScalar2(v.y()), DScalar2(v.z())); + } + + /// @} + // ====================================================================== +protected: + Scalar value; + Gradient grad; + Hessian hess; +}; + +template +std::ostream &operator<<(std::ostream &out, const DScalar2 &s) +{ + out << "[" << s.getValue() + << ", grad=" << s.getGradient().format(Eigen::IOFormat(4, 1, ", ", "; ", "", "", "[", "]")) + << ", hess=" << s.getHessian().format(Eigen::IOFormat(4, 0, ", ", "; ", "", "", "[", "]")) + << "]"; + return out; +} + +// clang-format on + +namespace std { + +template +class numeric_limits> +{ +public: + static const bool is_specialized = std::numeric_limits<_Scalar>::is_specialized; + static const bool is_signed = std::numeric_limits<_Scalar>::is_signed; + static const bool is_integer = std::numeric_limits<_Scalar>::is_integer; + static const bool is_exact = std::numeric_limits<_Scalar>::is_exact; + static const int radix = std::numeric_limits<_Scalar>::radix; + static const bool has_infinity = std::numeric_limits<_Scalar>::has_infinity; + static const bool has_quiet_NaN = std::numeric_limits<_Scalar>::has_quiet_NaN; + static const bool has_signaling_NaN = std::numeric_limits<_Scalar>::has_signaling_NaN; + static const bool is_iec559 = std::numeric_limits<_Scalar>::is_iec559; + static const bool is_bounded = std::numeric_limits<_Scalar>::is_bounded; + static const bool is_modulo = std::numeric_limits<_Scalar>::is_modulo; + static const bool traps = std::numeric_limits<_Scalar>::traps; + static const bool tinyness_before = std::numeric_limits<_Scalar>::tinyness_before; + static const std::float_denorm_style has_denorm = std::numeric_limits<_Scalar>::has_denorm; + static const bool has_denorm_loss = std::numeric_limits<_Scalar>::has_denorm_loss; + static const int min_exponent = std::numeric_limits<_Scalar>::min_exponent; + static const int max_exponent = std::numeric_limits<_Scalar>::max_exponent; + static const int min_exponent10 = std::numeric_limits<_Scalar>::min_exponent10; + static const int max_exponent10 = std::numeric_limits<_Scalar>::max_exponent10; + static const std::float_round_style round_style = std::numeric_limits<_Scalar>::round_style; + static const int digits = std::numeric_limits<_Scalar>::digits; + static const int digits10 = std::numeric_limits<_Scalar>::digits10; + static const int max_digits10 = std::numeric_limits<_Scalar>::max_digits10; + + static constexpr _Scalar min() { return std::numeric_limits<_Scalar>::min(); } + + static constexpr _Scalar max() { return std::numeric_limits<_Scalar>::max(); } + + static constexpr _Scalar lowest() { return std::numeric_limits<_Scalar>::lowest(); } + + static constexpr _Scalar epsilon() { return std::numeric_limits<_Scalar>::epsilon(); } + + static constexpr _Scalar round_error() { return std::numeric_limits<_Scalar>::round_error(); } + + static constexpr _Scalar infinity() { return std::numeric_limits<_Scalar>::infinity(); } + + static constexpr _Scalar quiet_NaN() { return std::numeric_limits<_Scalar>::quiet_NaN(); } + + static constexpr _Scalar signaling_NaN() + { + return std::numeric_limits<_Scalar>::signaling_NaN(); + } + + static constexpr _Scalar denorm_min() { return std::numeric_limits<_Scalar>::denorm_min(); } +}; + +template +class numeric_limits> +{ +public: + static const bool is_specialized = std::numeric_limits<_Scalar>::is_specialized; + static const bool is_signed = std::numeric_limits<_Scalar>::is_signed; + static const bool is_integer = std::numeric_limits<_Scalar>::is_integer; + static const bool is_exact = std::numeric_limits<_Scalar>::is_exact; + static const int radix = std::numeric_limits<_Scalar>::radix; + static const bool has_infinity = std::numeric_limits<_Scalar>::has_infinity; + static const bool has_quiet_NaN = std::numeric_limits<_Scalar>::has_quiet_NaN; + static const bool has_signaling_NaN = std::numeric_limits<_Scalar>::has_signaling_NaN; + static const bool is_iec559 = std::numeric_limits<_Scalar>::is_iec559; + static const bool is_bounded = std::numeric_limits<_Scalar>::is_bounded; + static const bool is_modulo = std::numeric_limits<_Scalar>::is_modulo; + static const bool traps = std::numeric_limits<_Scalar>::traps; + static const bool tinyness_before = std::numeric_limits<_Scalar>::tinyness_before; + static const std::float_denorm_style has_denorm = std::numeric_limits<_Scalar>::has_denorm; + static const bool has_denorm_loss = std::numeric_limits<_Scalar>::has_denorm_loss; + static const int min_exponent = std::numeric_limits<_Scalar>::min_exponent; + static const int max_exponent = std::numeric_limits<_Scalar>::max_exponent; + static const int min_exponent10 = std::numeric_limits<_Scalar>::min_exponent10; + static const int max_exponent10 = std::numeric_limits<_Scalar>::max_exponent10; + static const std::float_round_style round_style = std::numeric_limits<_Scalar>::round_style; + static const int digits = std::numeric_limits<_Scalar>::digits; + static const int digits10 = std::numeric_limits<_Scalar>::digits10; + static const int max_digits10 = std::numeric_limits<_Scalar>::max_digits10; + + static constexpr _Scalar min() { return std::numeric_limits<_Scalar>::min(); } + + static constexpr _Scalar max() { return std::numeric_limits<_Scalar>::max(); } + + static constexpr _Scalar lowest() { return std::numeric_limits<_Scalar>::lowest(); } + + static constexpr _Scalar epsilon() { return std::numeric_limits<_Scalar>::epsilon(); } + + static constexpr _Scalar round_error() { return std::numeric_limits<_Scalar>::round_error(); } + + static constexpr _Scalar infinity() { return std::numeric_limits<_Scalar>::infinity(); } + + static constexpr _Scalar quiet_NaN() { return std::numeric_limits<_Scalar>::quiet_NaN(); } + + static constexpr _Scalar signaling_NaN() + { + return std::numeric_limits<_Scalar>::signaling_NaN(); + } + + static constexpr _Scalar denorm_min() { return std::numeric_limits<_Scalar>::denorm_min(); } +}; + +} // namespace std + +#endif /* __AUTODIFF_H */ diff --git a/src/wmtk/image/CMakeLists.txt b/src/wmtk/image/CMakeLists.txt new file mode 100644 index 0000000000..39824b44f6 --- /dev/null +++ b/src/wmtk/image/CMakeLists.txt @@ -0,0 +1,12 @@ +set(SRC_FILES + bicubic_interpolation.hpp + bicubic_interpolation.cpp + Image.cpp + Image.hpp + load_image_exr.cpp + load_image_exr.h + Sampling.hpp + save_image_exr.cpp + save_image_exr.h +) +target_sources(wildmeshing_toolkit PRIVATE ${SRC_FILES}) \ No newline at end of file diff --git a/src/wmtk/image/Image.cpp b/src/wmtk/image/Image.cpp new file mode 100644 index 0000000000..b0ca07fd3b --- /dev/null +++ b/src/wmtk/image/Image.cpp @@ -0,0 +1,292 @@ +#include +#include +#include +#include +#include +#include + +using namespace wmtk; + +using namespace image; +float modulo(double x, double n) +{ + float y = fmod(x, n); + if (y < 0) { + y += n; + } + return y; +} + +unsigned char double_to_unsignedchar(const double d) +{ + return round(std::max(std::min(1., d), 0.) * 255); +} + +int Image::get_coordinate(const int x, const WrappingMode mode) const +{ + auto size = std::max(width(), height()); + assert(-size < x && x < 2 * size); + switch (mode) { + case WrappingMode::REPEAT: return (x + size) % size; + + case WrappingMode::MIRROR_REPEAT: + if (x < 0) + return -(x % size); + else if (x < size) + return x; + else + return size - 1 - (x - size) % size; + + case WrappingMode::CLAMP_TO_EDGE: return std::clamp(x, 0, size - 1); + default: return (x + size) % size; + } +} + +std::pair Image::get_pixel_index(const double& u, const double& v) const +{ + int w = width(); + int h = height(); + auto size = std::max(w, h); + // x, y are between 0 and 1 + auto x = u * size; + auto y = v * size; + const auto sx = static_cast(std::floor(x - 0.5)); + const auto sy = static_cast(std::floor(y - 0.5)); + + return {sx, sy}; +} + +// set an image to have same value as the analytical function and save it to the file given +bool Image::set( + const std::function& f, + WrappingMode mode_x, + WrappingMode mode_y) +{ + int h = height(); + int w = width(); + + m_image.resize(h, w); + + for (int i = 0; i < h; i++) { + for (int j = 0; j < w; j++) { + double u, v; + u = (static_cast(j) + 0.5) / static_cast(w); + v = (static_cast(i) + 0.5) / static_cast(h); + m_image(i, j) = f(u, v); + } + } + set_wrapping_mode(mode_x, mode_y); + return true; +} + +// save to hdr or exr +bool Image::save(const std::filesystem::path& path) const +{ + wmtk::logger().trace("[save_image_hdr] start \"{}\"", path.string()); + int w = width(); + int h = height(); + std::vector buffer; + buffer.resize(w * h); + + for (auto i = 0; i < h; i++) { + for (auto j = 0; j < w; j++) { + buffer[i * w + j] = m_image(i, j); + } + } + if (path.extension() == ".hdr") { + auto res = stbi_write_hdr(path.string().c_str(), w, h, 1, buffer.data()); + assert(res); + } else if (path.extension() == ".exr") { + auto res = save_image_exr_red_channel(w, h, buffer, path); + } else { + wmtk::logger().trace("[save_image_hdr] format doesn't support \"{}\"", path.string()); + return false; + } + + wmtk::logger().trace("[save_image] done \"{}\"", path.string()); + + return true; +} + +// load from hdr or exr +void Image::load( + const std::filesystem::path& path, + const WrappingMode mode_x, + const WrappingMode mode_y) +{ + int w, h, channels; + channels = 1; + std::vector buffer; + if (path.extension() == ".exr") { + std::tie(w, h, buffer) = load_image_exr_red_channel(path); + assert(!buffer.empty()); + } else if (path.extension() == ".hdr") { + auto res = stbi_loadf(path.string().c_str(), &w, &h, &channels, 1); + buffer.assign(res, res + w * h); + } else { + wmtk::logger().trace("[load_image] format doesn't support \"{}\"", path.string()); + return; + } + + m_image.resize(w, h); + + for (int i = 0, k = 0; i < h; i++) { + for (int j = 0; j < w; j++) { + m_image(i, j) = buffer[k++]; + } + } + m_image.colwise().reverseInPlace(); + set_wrapping_mode(mode_x, mode_y); +} + +// down sample a image to size/2 by size/2 +// used for mipmap construction +Image Image::down_sample() const +{ + auto h = height(); + auto w = width(); + Image low_res_image(h / 2, w / 2); + for (int r = 0; r < h / 2; r++) { + for (int c = 0; c < w / 2; c++) { + low_res_image.set( + r, + c, + (m_image(r * 2, c * 2) + m_image(r * 2, c * 2 + 1) + m_image(r * 2 + 1, c * 2) + + m_image(r * 2 + 1, c * 2 + 1)) / + 4.); + } + } + return low_res_image; +} + +std::array combine_position_normal_texture( + double normalization_scale, + const Eigen::Matrix& offset, + const std::filesystem::path& position_path, + const std::filesystem::path& normal_path, + const std::filesystem::path& height_path, + float min_height, + float max_height) +{ + assert(std::filesystem::exists(position_path)); + auto [w_p, h_p, index_red_p, index_green_p, index_blue_p, buffer_r_p, buffer_g_p, buffer_b_p] = + load_image_exr_split_3channels(position_path); + + auto buffer_size = buffer_r_p.size(); + std::vector buffer_r_d(buffer_size); + std::vector buffer_g_d(buffer_size); + std::vector buffer_b_d(buffer_size); + + if (std::filesystem::exists(normal_path) && std::filesystem::exists(height_path)) { + // Load normal + heightmap and compute displaced positions. + auto + [w_n, + h_n, + index_red_n, + index_green_n, + index_blue_n, + buffer_r_n, + buffer_g_n, + buffer_b_n] = load_image_exr_split_3channels(normal_path); + auto + [w_h, + h_h, + index_red_h, + index_green_h, + index_blue_h, + buffer_r_h, + buffer_g_h, + buffer_b_h] = load_image_exr_split_3channels(height_path); + assert(buffer_r_p.size() == buffer_r_n.size()); + assert(buffer_r_p.size() == buffer_r_h.size()); + assert(buffer_r_p.size() == buffer_g_p.size()); + assert(buffer_r_p.size() == buffer_b_p.size()); + auto scale = [&](float h) { return min_height * (1.f - h) + max_height * h; }; + // displaced = positions * normalization_scale + heights * (2.0 * normals - 1.0) - offset + for (int i = 0; i < buffer_size; i++) { + buffer_r_d[i] = buffer_r_p[i] * normalization_scale + + scale(buffer_r_h[i]) * (2.0 * buffer_r_n[i] - 1.0) - offset[0]; + buffer_g_d[i] = buffer_g_p[i] * normalization_scale + + scale(buffer_g_h[i]) * (2.0 * buffer_g_n[i] - 1.0) - offset[1]; + buffer_b_d[i] = buffer_b_p[i] * normalization_scale + + scale(buffer_b_h[i]) * (2.0 * buffer_b_n[i] - 1.0) - offset[2]; + } + } else { + // Missing heightmap info: we use the position map as our displaced coordinates. + wmtk::logger().info("No heightmap provided: using positions as displaced coordinates."); + // displaced = positions * normalization_scale - offset + for (int i = 0; i < buffer_size; i++) { + buffer_r_d[i] = buffer_r_p[i] * normalization_scale - offset[0]; + buffer_g_d[i] = buffer_g_p[i] * normalization_scale - offset[1]; + buffer_b_d[i] = buffer_b_p[i] * normalization_scale - offset[2]; + } + } + + return { + buffer_to_image(buffer_r_d, w_p, h_p), + buffer_to_image(buffer_g_d, w_p, h_p), + buffer_to_image(buffer_b_d, w_p, h_p), + }; +} + +void split_and_save_3channels(const std::filesystem::path& path) +{ + int w, h, channels, index_red, index_blue, index_green; + channels = 1; + std::vector buffer_r, buffer_g, buffer_b; + if (path.extension() == ".exr") { + std::tie(w, h, index_red, index_green, index_blue, buffer_r, buffer_g, buffer_b) = + load_image_exr_split_3channels(path); + assert(!buffer_r.empty()); + assert(!buffer_g.empty()); + assert(!buffer_b.empty()); + } else { + spdlog::trace("[split_image] format doesn't support \"{}\"", path.string()); + return; + } + const std::filesystem::path directory = path.parent_path(); + const std::string file = path.stem().string(); + const std::filesystem::path path_r = directory / (file + "_r.exr"); + const std::filesystem::path path_g = directory / (file + "_g.exr"); + const std::filesystem::path path_b = directory / (file + "_b.exr"); + // just saves single channel data to red channel + auto res = save_image_exr_red_channel(w, h, buffer_r, path_r); + assert(res); + res = save_image_exr_red_channel(w, h, buffer_g, path_g); + assert(res); + res = save_image_exr_red_channel(w, h, buffer_b, path_b); + assert(res); +} + +Image buffer_to_image(const std::vector& buffer, int w, int h) +{ + Image image(w, h); + for (int i = 0, k = 0; i < h; i++) { + for (int j = 0; j < w; j++) { + image.set(h - i - 1, j, buffer[k++]); + } + } + return image; +} + +std::array load_rgb_image(const std::filesystem::path& path) +{ + int w, h, channels, index_red, index_blue, index_green; + channels = 1; + std::vector buffer_r, buffer_g, buffer_b; + if (path.extension() == ".exr") { + std::tie(w, h, index_red, index_green, index_blue, buffer_r, buffer_g, buffer_b) = + load_image_exr_split_3channels(path); + assert(!buffer_r.empty()); + assert(!buffer_g.empty()); + assert(!buffer_b.empty()); + } else { + wmtk::logger().error("[load_rgb_image] format doesn't support \"{}\"", path.string()); + exit(-1); + } + return { + wmtk::image::buffer_to_image(buffer_r, w, h), + wmtk::image::buffer_to_image(buffer_g, w, h), + wmtk::image::buffer_to_image(buffer_b, w, h), + }; +} diff --git a/src/wmtk/image/Image.hpp b/src/wmtk/image/Image.hpp new file mode 100644 index 0000000000..283aee4cc5 --- /dev/null +++ b/src/wmtk/image/Image.hpp @@ -0,0 +1,109 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "bicubic_interpolation.hpp" +#include "load_image_exr.h" +#include "save_image_exr.h" + +namespace wmtk { +namespace image { +class Image +{ + using DScalar = DScalar2, Eigen::Matrix>; + using ImageMatrixf = + Eigen::Matrix; + +protected: + ImageMatrixf m_image; // saving scanline images + WrappingMode m_mode_x = WrappingMode::CLAMP_TO_EDGE; + WrappingMode m_mode_y = WrappingMode::CLAMP_TO_EDGE; + +public: + Image() = default; + Image(int height_, int width_) { m_image.resize(height_, width_); }; + + ImageMatrixf& ref_raw_image() { return m_image; } + const ImageMatrixf& get_raw_image() const { return m_image; } + +public: + // point coordinates between [0, 1] + int width() const { return static_cast(m_image.cols()); }; + int height() const { return static_cast(m_image.rows()); }; + + template + std::decay_t get(const T& u, const T& v) const; + float get_pixel(const int i, const int j) const { return m_image(i, j); }; + std::pair get_pixel_index(const double& u, const double& v) const; + int get_coordinate(const int x, const WrappingMode mode) const; + WrappingMode get_wrapping_mode_x() const { return m_mode_x; }; + WrappingMode get_wrapping_mode_y() const { return m_mode_y; }; + bool set( + const std::function& f, + const WrappingMode mode_x = WrappingMode::CLAMP_TO_EDGE, + const WrappingMode mode_y = WrappingMode::CLAMP_TO_EDGE); + bool set(const int r, const int c, const float v) + { + m_image(r, c) = v; + return true; + }; + bool save(const std::filesystem::path& path) const; + void load(const std::filesystem::path& path, WrappingMode mode_x, WrappingMode mode_y); + + void set_wrapping_mode(WrappingMode mode_x, WrappingMode mode_y) + { + m_mode_x = mode_x; + m_mode_y = mode_y; + }; + Image down_sample() const; +}; + +/// @brief +/// @param p coordinates between (0,1) +/// @return / +template +std::decay_t Image::get(const T& u, const T& v) const +{ + int w = width(); + int h = height(); + auto size = std::max(w, h); + // x, y are between 0 and 1 + auto x = u * static_cast>(size); + auto y = v * static_cast>(size); + // use bicubic interpolation + + BicubicVector sample_vector = extract_samples( + static_cast(w), + static_cast(h), + m_image.data(), + get_value(x), + get_value(y), + m_mode_x, + m_mode_y); + BicubicVector bicubic_coeff = get_bicubic_matrix() * sample_vector; + return eval_bicubic_coeffs(bicubic_coeff, x, y); +}; + +void split_and_save_3channels(const std::filesystem::path& path); + +Image buffer_to_image(const std::vector& buffer, int w, int h); + +std::array load_rgb_image(const std::filesystem::path& path); + +std::array combine_position_normal_texture( + double normalization_scale, + const Eigen::Matrix& offset, + const std::filesystem::path& position_path, + const std::filesystem::path& normal_path, + const std::filesystem::path& texture_path, + float min_height = 0.f, + float max_height = 1.f); +} // namespace image +} // namespace wmtk \ No newline at end of file diff --git a/src/wmtk/image/Sampling.hpp b/src/wmtk/image/Sampling.hpp new file mode 100644 index 0000000000..4d24943cc4 --- /dev/null +++ b/src/wmtk/image/Sampling.hpp @@ -0,0 +1,133 @@ +#pragma once +#include +#include "Image.hpp" +#include "bicubic_interpolation.hpp" + +namespace wmtk { +namespace image { +enum class SAMPLING_MODE { BICUBIC, SPLINE }; +class Sampling +{ +public: + using DScalar = DScalar2, Eigen::Matrix>; + virtual ~Sampling(){}; + +public: + virtual double sample(const double u, const double v) const = 0; + virtual DScalar sample(const DScalar& u, const DScalar& v) const = 0; +}; + + +enum SamplingAnalyticFunction_FunctionType { Linear, Quadratic }; +class SamplingAnalyticFunction : public Sampling +{ +public: + using FunctionType = SamplingAnalyticFunction_FunctionType; + +protected: + FunctionType m_type = FunctionType::Linear; + double A = 0.0; + double B = 0.0; + double C = 0.0; + + template + auto evaluate(const S& u, const S& v) const + { + if (m_type == Linear) { + return evaluate_linear(u, v); + } else + return static_cast(0.0); + } + + template + auto evaluate_linear(const S& u, const S& v) const + { + return A * u + B * v + C; + } + +public: + // make a contructor + SamplingAnalyticFunction( + const FunctionType type, + const double a, + const double b, + const double c) + : m_type(type) + , A(a) + , B(b) + , C(c) + {} + + + void set_coefficients(double a, const double b, const double c) + { + A = a; + B = b; + C = c; + } + double sample(const double u, const double v) const override { return evaluate(u, v); } + DScalar sample(const DScalar& u, const DScalar& v) const override + { + return evaluate(u, v); + } +}; + +template +class SamplingImage : public Sampling +{ +protected: + const Image& m_image; + +public: + SamplingImage(const Image& img) + : m_image(img) + { + assert(m_image.width() == m_image.height()); + assert(m_image.width() != 0); + } + +public: + double sample(const double u, const double v) const override + { + return static_cast(this)->sample(u, v); + } + DScalar sample(const DScalar& u, const DScalar& v) const override + { + return static_cast(this)->sample(u, v); + } +}; + +class SamplingBicubic : public SamplingImage +{ +public: + using Super = SamplingImage; + using Super::Super; + template + T sample_T(T u, T v) const + { + auto w = m_image.width(); + auto h = m_image.height(); + // x, y are between 0 and 1 + T x = u * static_cast>(w); + T y = v * static_cast>(h); + + // use bicubic interpolation + BicubicVector sample_vector = extract_samples( + static_cast(w), + static_cast(h), + m_image.get_raw_image().data(), + wmtk::image::get_value(x), + wmtk::image::get_value(y), + m_image.get_wrapping_mode_x(), + m_image.get_wrapping_mode_y()); + BicubicVector bicubic_coeff = get_bicubic_matrix() * sample_vector; + return eval_bicubic_coeffs(bicubic_coeff, x, y); + } + double sample(const double u, const double v) const override { return sample_T(u, v); } + DScalar sample(const DScalar& u, const DScalar& v) const override + { + return sample_T(u, v); + } +}; +} // namespace image +} // namespace wmtk diff --git a/src/wmtk/image/bicubic_interpolation.cpp b/src/wmtk/image/bicubic_interpolation.cpp new file mode 100644 index 0000000000..1dd94584f4 --- /dev/null +++ b/src/wmtk/image/bicubic_interpolation.cpp @@ -0,0 +1,123 @@ +#include "bicubic_interpolation.hpp" + +#include +#include +using namespace wmtk; +using namespace wmtk::image; +wmtk::image::BicubicVector wmtk::image::extract_samples( + const size_t width, + const size_t height, + const float* buffer, + const double sx_, + const double sy_, + const WrappingMode mode_x, + const WrappingMode mode_y) +{ + BicubicVector samples; + + const auto get_coordinate = [](const int x, const int size, const WrappingMode mode) -> int { + switch (mode) { + case WrappingMode::REPEAT: return (x + size) % size; + + case WrappingMode::MIRROR_REPEAT: + if (x < 0) + return -(x % size); + else if (x < size) + return x; + else + return size - 1 - (x - size) % size; + case WrappingMode::CLAMP_TO_EDGE: return std::clamp(x, 0, size - 1); + default: return (x + size) % size; + } + }; + const auto get_buffer_value = [&](int xx, int yy) -> float { + xx = get_coordinate(xx, static_cast(width), mode_x); + yy = get_coordinate(yy, static_cast(height), mode_y); + const int index = (yy % height) * width + (xx % width); + return buffer[index]; + }; + + const auto sx = static_cast(std::floor(sx_ - 0.5)); + const auto sy = static_cast(std::floor(sy_ - 0.5)); + + samples(0) = get_buffer_value(sx - 1, sy - 1); + samples(1) = get_buffer_value(sx, sy - 1); + samples(2) = get_buffer_value(sx + 1, sy - 1); + samples(3) = get_buffer_value(sx + 2, sy - 1); + + samples(4) = get_buffer_value(sx - 1, sy); + samples(5) = get_buffer_value(sx, sy); + samples(6) = get_buffer_value(sx + 1, sy); + samples(7) = get_buffer_value(sx + 2, sy); + + samples(8) = get_buffer_value(sx - 1, sy + 1); + samples(9) = get_buffer_value(sx, sy + 1); + samples(10) = get_buffer_value(sx + 1, sy + 1); + samples(11) = get_buffer_value(sx + 2, sy + 1); + + samples(12) = get_buffer_value(sx - 1, sy + 2); + samples(13) = get_buffer_value(sx, sy + 2); + samples(14) = get_buffer_value(sx + 1, sy + 2); + samples(15) = get_buffer_value(sx + 2, sy + 2); + + return samples; +} + +wmtk::image::BicubicMatrix wmtk::image::make_samples_to_bicubic_coeffs_operator() +{ + BicubicMatrix ope; + Eigen::Index row = 0; + for (float yy = -1; yy < 3; yy++) + for (float xx = -1; xx < 3; xx++) { + ope(row, 0) = 1; + ope(row, 1) = xx; + ope(row, 2) = xx * xx; + ope(row, 3) = xx * xx * xx; + + ope(row, 4) = yy; + ope(row, 5) = xx * yy; + ope(row, 6) = xx * xx * yy; + ope(row, 7) = xx * xx * xx * yy; + + ope(row, 8) = yy * yy; + ope(row, 9) = xx * yy * yy; + ope(row, 10) = xx * xx * yy * yy; + ope(row, 11) = xx * xx * xx * yy * yy; + + ope(row, 12) = yy * yy * yy; + ope(row, 13) = xx * yy * yy * yy; + ope(row, 14) = xx * xx * yy * yy * yy; + ope(row, 15) = xx * xx * xx * yy * yy * yy; + + row++; + } + + { + std::stringstream ss; + ss << ope << std::endl; + spdlog::debug("ope det {}\n{}", ope.determinant(), ss.str()); + } + + // invert operator + BicubicMatrix ope_inv = ope.inverse(); + + // prune "zeros" + ope_inv = ope_inv.unaryExpr([](const float& xx) { return fabs(xx) < 1e-5f ? 0 : xx; }); + + { + std::stringstream ss; + ss << ope_inv << std::endl; + spdlog::debug("ope_inv det {}\n{}", ope_inv.determinant(), ss.str()); + } + + // double check inverse property + assert((ope * ope_inv - BicubicMatrix::Identity()).array().abs().maxCoeff() < 1e-5); + + return ope_inv; +} + +const wmtk::image::BicubicMatrix& wmtk::image::get_bicubic_matrix() +{ + static BicubicMatrix mat = wmtk::image::make_samples_to_bicubic_coeffs_operator(); + return mat; +} diff --git a/src/wmtk/image/bicubic_interpolation.hpp b/src/wmtk/image/bicubic_interpolation.hpp new file mode 100644 index 0000000000..aa7c14e138 --- /dev/null +++ b/src/wmtk/image/bicubic_interpolation.hpp @@ -0,0 +1,83 @@ +#pragma once + +#include +#include +#include +#include +enum class WrappingMode { REPEAT, MIRROR_REPEAT, CLAMP_TO_EDGE }; +namespace wmtk { +namespace image { +inline double get_value(float x) +{ + return static_cast(x); +} +inline double get_value(double x) +{ + return x; +} +inline double get_value( + DScalar2, Eigen::Matrix> x) +{ + return x.getValue(); +} + +inline double get_value( + DScalar2, Eigen::Matrix> x) +{ + return x.getValue(); +} + +template +using BicubicVector = Eigen::Matrix; +using BicubicMatrix = Eigen::Matrix; + +BicubicVector extract_samples( + const size_t width, + const size_t height, + const float* buffer, + const double xx, + const double yy, + const WrappingMode mode_x, + const WrappingMode mode_y); + +BicubicMatrix make_samples_to_bicubic_coeffs_operator(); + +const BicubicMatrix& get_bicubic_matrix(); + +template +std::decay_t eval_bicubic_coeffs(const BicubicVector& coeffs, const T& sx, const T& sy) +{ + using ImageScalar = std::decay_t; + + const auto xx = sx - (floor(get_value(sx) - 0.5f) + 0.5f); + const auto yy = sy - (floor(get_value(sy) - 0.5f) + 0.5f); + assert(0 <= get_value(xx) && get_value(xx) < 1); + assert(0 <= get_value(yy) && get_value(yy) < 1); + + BicubicVector vv; + + vv(0) = 1; + vv(1) = xx; + vv(2) = xx * xx; + vv(3) = xx * xx * xx; + + vv(4) = yy; + vv(5) = xx * yy; + vv(6) = xx * xx * yy; + vv(7) = xx * xx * xx * yy; + + vv(8) = yy * yy; + vv(9) = xx * yy * yy; + vv(10) = xx * xx * yy * yy; + vv(11) = xx * xx * xx * yy * yy; + + vv(12) = yy * yy * yy; + vv(13) = xx * yy * yy * yy; + vv(14) = xx * xx * yy * yy * yy; + vv(15) = xx * xx * xx * yy * yy * yy; + + return coeffs.cast().dot(vv); +} +} // namespace image + +} // namespace wmtk diff --git a/src/wmtk/image/load_image_exr.cpp b/src/wmtk/image/load_image_exr.cpp new file mode 100644 index 0000000000..53ebada184 --- /dev/null +++ b/src/wmtk/image/load_image_exr.cpp @@ -0,0 +1,336 @@ +#include "load_image_exr.h" + +#include +#define TINYEXR_USE_MINIZ 0 +#define TINYEXR_USE_STB_ZLIB 1 +#define TINYEXR_IMPLEMENTATION +#include +#include +#include +using namespace wmtk; +using namespace wmtk::image; +auto load_image_exr_red_channel(const std::filesystem::path& path) + -> std::tuple> +{ + using namespace wmtk; + wmtk::logger().debug("[load_image_exr_red_channel] start \"{}\"", path.string()); + assert(std::filesystem::exists(path)); + const std::string filename_ = path.string(); + const char* filename = filename_.c_str(); + + const auto exr_version = [&filename, &path]() -> EXRVersion { // parse version + EXRVersion exr_version_; + + const auto ret = ParseEXRVersionFromFile(&exr_version_, filename); + if (ret != TINYEXR_SUCCESS) { + wmtk::logger().error("failed LoadImageEXR \"{}\" \"version error\"", path.string()); + throw std::runtime_error("LoadImageEXRError"); + } + + if (exr_version_.multipart || exr_version_.non_image) { + wmtk::logger().error( + "failed LoadImageEXR \"{}\" \"multipart or non image\"", + path.string()); + throw std::runtime_error("LoadImageEXRError"); + } + + return exr_version_; + }(); + + auto exr_header_data = + [&filename, &path, &exr_version]() -> std::tuple { // parse header + EXRHeader exr_header_; + InitEXRHeader(&exr_header_); + + [[maybe_unused]] const char* err = nullptr; + const auto ret = ParseEXRHeaderFromFile(&exr_header_, &exr_version, filename, &err); + if (ret != TINYEXR_SUCCESS) { + wmtk::logger().error("failed LoadImageEXR \"{}\" \"header error\"", path.string()); + FreeEXRHeader(&exr_header_); + throw std::runtime_error("LoadImageEXRError"); + } + + // sanity check, only support all channels are the same type + for (int i = 0; i < exr_header_.num_channels; i++) { + if (exr_header_.pixel_types[i] != exr_header_.pixel_types[0] || + exr_header_.requested_pixel_types[i] != exr_header_.pixel_types[i]) { + wmtk::logger().error( + "failed LoadImageEXR \"{}\" \"inconsistent pixel_types\"", + path.string()); + FreeEXRHeader(&exr_header_); + throw std::runtime_error("LoadImageEXRError"); + } + } + + // read HALF channel as FLOAT. + for (int i = 0; i < exr_header_.num_channels; i++) { + if (exr_header_.pixel_types[i] == TINYEXR_PIXELTYPE_HALF) { + exr_header_.requested_pixel_types[i] = TINYEXR_PIXELTYPE_FLOAT; + } + } + + // only FLOAT are supported + if (exr_header_.requested_pixel_types[0] != TINYEXR_PIXELTYPE_FLOAT) { + wmtk::logger().error( + "failed LoadImageEXR \"{}\" \"only float exr are supported\"", + path.string()); + FreeEXRHeader(&exr_header_); + throw std::runtime_error("LoadImageEXRError"); + } + + // only non tiled image are supported + if (exr_header_.tiled) { + wmtk::logger().error( + "failed LoadImageEXR \"{}\" \"only non tiled exr are supported\"", + path.string()); + FreeEXRHeader(&exr_header_); + throw std::runtime_error("LoadImageEXRError"); + } + + + int index_red_ = -1; + for (int i = 0; i < exr_header_.num_channels; i++) { + if (strcmp(exr_header_.channels[i].name, "R") == 0) index_red_ = i; + } + if (index_red_ < 0) { + wmtk::logger().warn("Could not find R channel. Looking for Y channel instead."); + for (int i = 0; i < exr_header_.num_channels; i++) { + if (strcmp(exr_header_.channels[i].name, "Y") == 0) index_red_ = i; + } + } + + if (index_red_ < 0) { + std::vector channels; + for (int i = 0; i < exr_header_.num_channels; i++) { + channels.push_back(exr_header_.channels[i].name); + } + wmtk::logger().error( + "failed LoadImageEXR \"{}\" can't find all expected channels: [{}]", + path.string(), + fmt::join(channels, ",")); + FreeEXRHeader(&exr_header_); + throw std::runtime_error("LoadImageEXRError"); + } + + return {exr_header_, index_red_}; + }(); + auto& exr_header = std::get<0>(exr_header_data); + const auto& index_data = std::get<1>(exr_header_data); + + auto exr_image = [&filename, &path, &exr_header]() -> EXRImage { + EXRImage exr_image_; + InitEXRImage(&exr_image_); + + [[maybe_unused]] const char* err = nullptr; + const auto ret = LoadEXRImageFromFile(&exr_image_, &exr_header, filename, &err); + if (ret != TINYEXR_SUCCESS) { + wmtk::logger().error( + "failed LoadImageEXR \"{}\" \"failed to load image data\"", + path.string()); + FreeEXRHeader(&exr_header); + FreeEXRImage(&exr_image_); + throw std::runtime_error("LoadImageEXRError"); + } + + return exr_image_; + }(); + + wmtk::logger().debug( + "[load_image_exr_red_channel] num_channels {} tiled {}", + exr_header.num_channels, + exr_header.tiled); + wmtk::logger().debug("[load_image_exr_red_channel] index_data {}", index_data); + assert(index_data >= 0); + assert(!exr_header.tiled); + + std::vector data_r; + data_r.reserve(static_cast(exr_image.width) * static_cast(exr_image.height)); + + const auto images = reinterpret_cast(exr_image.images); + for (int i = 0; i < exr_image.width * exr_image.height; i++) + data_r.emplace_back(images[index_data][i]); + + FreeEXRHeader(&exr_header); + FreeEXRImage(&exr_image); + + wmtk::logger().debug("[load_image_exr_red_channel] done \"{}\"", path.string()); + + return { + static_cast(exr_image.width), + static_cast(exr_image.height), + std::move(data_r), + }; +} + +auto load_image_exr_split_3channels(const std::filesystem::path& path) -> std:: + tuple, std::vector, std::vector> +{ + using namespace wmtk; + wmtk::logger().debug("[load_image_exr_red_channel] start \"{}\"", path.string()); + assert(std::filesystem::exists(path)); + const std::string filename_ = path.string(); + const char* filename = filename_.c_str(); + + const auto exr_version = [&filename, &path]() -> EXRVersion { // parse version + EXRVersion exr_version_; + + const auto ret = ParseEXRVersionFromFile(&exr_version_, filename); + if (ret != TINYEXR_SUCCESS) { + wmtk::logger().error("failed LoadImageEXR \"{}\" \"version error\"", path.string()); + throw std::runtime_error("LoadImageEXRError"); + } + + if (exr_version_.multipart || exr_version_.non_image) { + wmtk::logger().error( + "failed LoadImageEXR \"{}\" \"multipart or non image\"", + path.string()); + throw std::runtime_error("LoadImageEXRError"); + } + + return exr_version_; + }(); + + auto exr_header_data = + [&filename, &path, &exr_version]() -> std::tuple { // parse header + EXRHeader exr_header_; + InitEXRHeader(&exr_header_); + + [[maybe_unused]] const char* err = nullptr; + const auto ret = ParseEXRHeaderFromFile(&exr_header_, &exr_version, filename, &err); + if (ret != TINYEXR_SUCCESS) { + wmtk::logger().error("failed LoadImageEXR \"{}\" \"header error\"", path.string()); + FreeEXRHeader(&exr_header_); + throw std::runtime_error("LoadImageEXRError"); + } + + // sanity check, only support all channels are the same type + for (int i = 0; i < exr_header_.num_channels; i++) { + if (exr_header_.pixel_types[i] != exr_header_.pixel_types[0] || + exr_header_.requested_pixel_types[i] != exr_header_.pixel_types[i]) { + wmtk::logger().error( + "failed LoadImageEXR \"{}\" \"inconsistent pixel_types\"", + path.string()); + FreeEXRHeader(&exr_header_); + throw std::runtime_error("LoadImageEXRError"); + } + } + + // read HALF channel as FLOAT. + for (int i = 0; i < exr_header_.num_channels; i++) { + if (exr_header_.pixel_types[i] == TINYEXR_PIXELTYPE_HALF) { + exr_header_.requested_pixel_types[i] = TINYEXR_PIXELTYPE_FLOAT; + } + } + + // only FLOAT are supported + if (exr_header_.requested_pixel_types[0] != TINYEXR_PIXELTYPE_FLOAT) { + wmtk::logger().error( + "failed LoadImageEXR \"{}\" \"only float exr are supported\"", + path.string()); + FreeEXRHeader(&exr_header_); + throw std::runtime_error("LoadImageEXRError"); + } + + // only non tiled image are supported + if (exr_header_.tiled) { + wmtk::logger().error( + "failed LoadImageEXR \"{}\" \"only non tiled exr are supported\"", + path.string()); + FreeEXRHeader(&exr_header_); + throw std::runtime_error("LoadImageEXRError"); + } + + + int index_red_ = -1; + int index_green_ = -1; + int index_blue_ = -1; + if (exr_header_.num_channels == 1) { + wmtk::logger().warn("Treat grayscale image as RGB: {}", path.string()); + index_red_ = 0; + index_green_ = 0; + index_blue_ = 0; + } else { + for (int i = 0; i < exr_header_.num_channels; i++) { + if (strcmp(exr_header_.channels[i].name, "R") == 0) index_red_ = i; + if (strcmp(exr_header_.channels[i].name, "G") == 0) index_green_ = i; + if (strcmp(exr_header_.channels[i].name, "B") == 0) index_blue_ = i; + } + } + + if (index_red_ < 0) { + std::vector channels; + for (int i = 0; i < exr_header_.num_channels; i++) { + channels.push_back(exr_header_.channels[i].name); + } + wmtk::logger().error( + "failed LoadImageEXR \"{}\" can't find all 3 expected channels: [{}]", + path.string(), + fmt::join(channels, ",")); + FreeEXRHeader(&exr_header_); + throw std::runtime_error("LoadImageEXRError"); + } + + return {exr_header_, index_red_, index_green_, index_blue_}; + }(); + auto& exr_header = std::get<0>(exr_header_data); + const auto& index_red = std::get<1>(exr_header_data); + const auto& index_green = std::get<2>(exr_header_data); + const auto& index_blue = std::get<3>(exr_header_data); + + auto exr_image = [&filename, &path, &exr_header]() -> EXRImage { + EXRImage exr_image_; + InitEXRImage(&exr_image_); + + [[maybe_unused]] const char* err = nullptr; + const auto ret = LoadEXRImageFromFile(&exr_image_, &exr_header, filename, &err); + if (ret != TINYEXR_SUCCESS) { + wmtk::logger().error( + "failed LoadImageEXR \"{}\" \"failed to load image data\"", + path.string()); + FreeEXRHeader(&exr_header); + FreeEXRImage(&exr_image_); + throw std::runtime_error("LoadImageEXRError"); + } + + return exr_image_; + }(); + + wmtk::logger().debug( + "[load_image_exr_3channels] num_channels {} tiled {}", + exr_header.num_channels, + exr_header.tiled); + wmtk::logger().debug("[load_image_exr_3channels] index_red {}", index_red); + wmtk::logger().debug("[load_image_exr_3channels] index_green {}", index_green); + wmtk::logger().debug("[load_image_exr_3channels] index_blue {}", index_blue); + assert(index_red >= 0); + assert(index_green >= 0); + assert(index_blue >= 0); + assert(!exr_header.tiled); + + std::vector data_r; + std::vector data_g; + std::vector data_b; + data_r.reserve(static_cast(exr_image.width) * static_cast(exr_image.height)); + data_g.reserve(static_cast(exr_image.width) * static_cast(exr_image.height)); + data_b.reserve(static_cast(exr_image.width) * static_cast(exr_image.height)); + + const auto images = reinterpret_cast(exr_image.images); + for (int i = 0; i < exr_image.width * exr_image.height; i++) { + data_r.emplace_back(images[index_red][i]); + data_g.emplace_back(images[index_green][i]); + data_b.emplace_back(images[index_blue][i]); + } + FreeEXRHeader(&exr_header); + FreeEXRImage(&exr_image); + + wmtk::logger().debug("[load_image_exr_3channels] done \"{}\"", path.string()); + + return {static_cast(exr_image.width), + static_cast(exr_image.height), + std::move(index_red), + std::move(index_green), + std::move(index_blue), + std::move(data_r), + std::move(data_g), + std::move(data_b)}; +} diff --git a/src/wmtk/image/load_image_exr.h b/src/wmtk/image/load_image_exr.h new file mode 100644 index 0000000000..71b8f3f6e2 --- /dev/null +++ b/src/wmtk/image/load_image_exr.h @@ -0,0 +1,22 @@ +#pragma once +#include +#include +#include +#include +namespace wmtk { +namespace image { + +auto load_image_exr_red_channel(const std::filesystem::path& path) + -> std::tuple>; + +auto load_image_exr_split_3channels(const std::filesystem::path& path) -> std::tuple< + size_t, + size_t, + int, + int, + int, + std::vector, + std::vector, + std::vector>; +} // namespace image +} // namespace wmtk \ No newline at end of file diff --git a/src/wmtk/image/save_image_exr.cpp b/src/wmtk/image/save_image_exr.cpp new file mode 100644 index 0000000000..61fa4ab77b --- /dev/null +++ b/src/wmtk/image/save_image_exr.cpp @@ -0,0 +1,147 @@ +#include "save_image_exr.h" +#include +#include +#include +using namespace wmtk; +using namespace wmtk::image; +bool save_image_exr_red_channel( + size_t width, + size_t height, + const std::vector& data, + const std::filesystem::path& path) +{ + EXRHeader header; + InitEXRHeader(&header); + + EXRImage image; + InitEXRImage(&image); + + image.num_channels = 3; + + std::vector images[3]; + images[0].resize(width * height); + images[1].resize(width * height); + images[2].resize(width * height); + + // Split RGBRGBRGB... into R, G and B layer + for (int i = 0; i < width * height; i++) { + images[0][i] = data[i]; + images[1][i] = -1.; + images[2][i] = -1.; + } + + float* image_ptr[3]; + image_ptr[0] = &(images[2].at(0)); // B + image_ptr[1] = &(images[1].at(0)); // G + image_ptr[2] = &(images[0].at(0)); // R + + image.images = (unsigned char**)image_ptr; + image.width = width; + image.height = height; + + header.num_channels = 3; + header.channels = (EXRChannelInfo*)malloc(sizeof(EXRChannelInfo) * header.num_channels); + // Must be (A)BGR order, since most of EXR viewers expect this channel order. + strncpy(header.channels[0].name, "B", 255); + header.channels[0].name[strlen("B")] = '\0'; + strncpy(header.channels[1].name, "G", 255); + header.channels[1].name[strlen("G")] = '\0'; + strncpy(header.channels[2].name, "R", 255); + header.channels[2].name[strlen("R")] = '\0'; + + header.pixel_types = (int*)malloc(sizeof(int) * header.num_channels); + header.requested_pixel_types = (int*)malloc(sizeof(int) * header.num_channels); + for (int i = 0; i < header.num_channels; i++) { + header.pixel_types[i] = TINYEXR_PIXELTYPE_FLOAT; // pixel type of input image + header.requested_pixel_types[i] = + TINYEXR_PIXELTYPE_HALF; // pixel type of output image to be stored in .EXR + } + + const char* err = NULL; // or nullptr in C++11 or later. + int ret = SaveEXRImageToFile(&image, &header, path.string().data(), &err); + if (ret != TINYEXR_SUCCESS) { + fprintf(stderr, "Save EXR err: %s\n", err); + FreeEXRErrorMessage(err); // free's buffer for an error message + return ret; + } + wmtk::logger().debug("Saved exr file. {} ", path); + + free(header.channels); + free(header.pixel_types); + free(header.requested_pixel_types); + return 0; +} + +bool save_image_exr_3channels( + size_t width, + size_t height, + int r, + int g, + int b, + const std::vector& data_r, + const std::vector& data_g, + const std::vector& data_b, + const std::filesystem::path& path) +{ + EXRHeader header; + InitEXRHeader(&header); + + EXRImage image; + InitEXRImage(&image); + + image.num_channels = 3; + + std::vector images[3]; + images[0].resize(width * height); + images[1].resize(width * height); + images[2].resize(width * height); + + // Split RGBRGBRGB... into R, G and B layer + for (int i = 0; i < width * height; i++) { + images[r][i] = data_r[i]; + images[g][i] = data_g[i]; + images[b][i] = data_b[i]; + } + wmtk::logger() + .info("[save r {} {}, g {} {} b {} {}]", r, images[r][0], g, images[g][0], b, images[b][0]); + float* image_ptr[3]; + image_ptr[0] = &(images[2].at(0)); // B + image_ptr[1] = &(images[1].at(0)); // G + image_ptr[2] = &(images[0].at(0)); // R + + image.images = (unsigned char**)image_ptr; + image.width = width; + image.height = height; + + header.num_channels = 3; + header.channels = (EXRChannelInfo*)malloc(sizeof(EXRChannelInfo) * header.num_channels); + // Must be (A)BGR order, since most of EXR viewers expect this channel order. + strncpy(header.channels[0].name, "B", 255); + header.channels[0].name[strlen("B")] = '\0'; + strncpy(header.channels[1].name, "G", 255); + header.channels[1].name[strlen("G")] = '\0'; + strncpy(header.channels[2].name, "R", 255); + header.channels[2].name[strlen("R")] = '\0'; + + header.pixel_types = (int*)malloc(sizeof(int) * header.num_channels); + header.requested_pixel_types = (int*)malloc(sizeof(int) * header.num_channels); + for (int i = 0; i < header.num_channels; i++) { + header.pixel_types[i] = TINYEXR_PIXELTYPE_FLOAT; // pixel type of input image + header.requested_pixel_types[i] = + TINYEXR_PIXELTYPE_HALF; // pixel type of output image to be stored in .EXR + } + + const char* err = NULL; // or nullptr in C++11 or later. + int ret = SaveEXRImageToFile(&image, &header, path.string().data(), &err); + if (ret != TINYEXR_SUCCESS) { + fprintf(stderr, "Save EXR err: %s\n", err); + FreeEXRErrorMessage(err); // free's buffer for an error message + return ret; + } + printf("Saved exr file 3 channels. [ %s ] \n", path.c_str()); + + free(header.channels); + free(header.pixel_types); + free(header.requested_pixel_types); + return 0; +} diff --git a/src/wmtk/image/save_image_exr.h b/src/wmtk/image/save_image_exr.h new file mode 100644 index 0000000000..cfde4a476a --- /dev/null +++ b/src/wmtk/image/save_image_exr.h @@ -0,0 +1,28 @@ +#pragma once +#include +#include +#include +#include +#include +#define TINYEXR_USE_MINIZ 0 +#define TINYEXR_USE_STB_ZLIB 1 +// #define TINYEXR_IMPLEMENTATION +namespace wmtk { +namespace image { +bool save_image_exr_red_channel( + size_t weigth, + size_t height, + const std::vector& data, + const std::filesystem::path& path); +bool save_image_exr_3channels( + size_t width, + size_t height, + int r, + int g, + int b, + const std::vector& data_r, + const std::vector& data_g, + const std::vector& data_b, + const std::filesystem::path& path); +} // namespace image +} // namespace wmtk \ No newline at end of file diff --git a/src/wmtk/invariants/CMakeLists.txt b/src/wmtk/invariants/CMakeLists.txt index 1b769de897..c6fec602a9 100644 --- a/src/wmtk/invariants/CMakeLists.txt +++ b/src/wmtk/invariants/CMakeLists.txt @@ -25,6 +25,8 @@ set(SRC_FILES MaxEdgeLengthInvariant.cpp MinEdgeLengthInvariant.hpp MinEdgeLengthInvariant.cpp + TriangleInversionInvariant.hpp + TriangleInversionInvariant.cpp TodoInvariant.hpp TodoInvariant.cpp ) diff --git a/src/wmtk/invariants/TriangleInversionInvariant.cpp b/src/wmtk/invariants/TriangleInversionInvariant.cpp new file mode 100644 index 0000000000..a252df69ae --- /dev/null +++ b/src/wmtk/invariants/TriangleInversionInvariant.cpp @@ -0,0 +1,28 @@ + +#include "TriangleInversionInvariant.hpp" +#include +#include + +namespace wmtk { +TriangleInversionInvariant::TriangleInversionInvariant( + const Mesh& m, + const MeshAttributeHandle& uv_coordinate) + : MeshInvariant(m) + , m_uv_coordinate_handle(uv_coordinate) +{} +bool TriangleInversionInvariant::after(PrimitiveType type, const std::vector& t) const +{ + if (type != PrimitiveType::Face) return true; + // assume conterclockwise + ConstAccessor accessor = mesh().create_accessor(m_uv_coordinate_handle); + for (auto& tuple : t) { + Eigen::Vector2d p0 = accessor.const_vector_attribute(tuple); + Eigen::Vector2d p1 = accessor.const_vector_attribute(mesh().switch_vertex(tuple)); + Eigen::Vector2d p2 = + accessor.const_vector_attribute(mesh().switch_vertex(mesh().switch_edge(tuple))); + + if (wmtk::utils::triangle_signed_2d_area(p0, p1, p2) < 0) return false; + } + return true; +} +} // namespace wmtk diff --git a/src/wmtk/invariants/TriangleInversionInvariant.hpp b/src/wmtk/invariants/TriangleInversionInvariant.hpp new file mode 100644 index 0000000000..e6cd384444 --- /dev/null +++ b/src/wmtk/invariants/TriangleInversionInvariant.hpp @@ -0,0 +1,18 @@ +#pragma once + +#include +#include "MeshInvariant.hpp" + +namespace wmtk { +class TriangleInversionInvariant : public MeshInvariant +{ +public: + // NOTE: this takes in the threshold squared rather than the threshold itself + TriangleInversionInvariant(const Mesh& m, const MeshAttributeHandle& uv_coordinate); + using MeshInvariant::MeshInvariant; + bool after(PrimitiveType type, const std::vector& t) const override; + +private: + const MeshAttributeHandle m_uv_coordinate_handle; +}; +} // namespace wmtk diff --git a/src/wmtk/operations/CMakeLists.txt b/src/wmtk/operations/CMakeLists.txt index 69e0862a7c..2a79ea4830 100644 --- a/src/wmtk/operations/CMakeLists.txt +++ b/src/wmtk/operations/CMakeLists.txt @@ -13,3 +13,4 @@ add_subdirectory(edge_mesh) add_subdirectory(tri_mesh) add_subdirectory(tet_mesh) add_subdirectory(utils) + diff --git a/src/wmtk/operations/OperationFactory.hpp b/src/wmtk/operations/OperationFactory.hpp index 33fb5a77f3..a2d024df1e 100644 --- a/src/wmtk/operations/OperationFactory.hpp +++ b/src/wmtk/operations/OperationFactory.hpp @@ -1,4 +1,5 @@ #pragma once +#include #include #include #include "Operation.hpp" @@ -24,12 +25,19 @@ class OperationFactory : public OperationFactoryBase : OperationFactoryBase(OperationType::primitive_type()) , m_settings(settings) {} + OperationFactory(OperationSettings&& settings) + : OperationFactoryBase(OperationType::primitive_type()) + , m_settings(std::move(settings)) + {} std::unique_ptr create(wmtk::Mesh& m, const Tuple& t) const override { + spdlog::info("Using default create"); return std::make_unique(m, t, m_settings); } + const OperationSettings& settings() const { return m_settings; } + protected: const OperationSettings m_settings; }; diff --git a/src/wmtk/operations/TupleOperation.hpp b/src/wmtk/operations/TupleOperation.hpp index 50c651a16c..6b694d1b54 100644 --- a/src/wmtk/operations/TupleOperation.hpp +++ b/src/wmtk/operations/TupleOperation.hpp @@ -23,6 +23,9 @@ class TupleOperation : virtual public Operation // Returns the set of tuples, organized by the type virtual std::vector modified_primitives(PrimitiveType) const; + + const InvariantCollection& invariants() const { return m_invariants; } + private: const InvariantCollection& m_invariants; Tuple m_input_tuple; diff --git a/src/wmtk/operations/tri_mesh/CMakeLists.txt b/src/wmtk/operations/tri_mesh/CMakeLists.txt index 4245f4a8d8..2f05821649 100644 --- a/src/wmtk/operations/tri_mesh/CMakeLists.txt +++ b/src/wmtk/operations/tri_mesh/CMakeLists.txt @@ -15,6 +15,8 @@ set(SRC_FILES EdgeCollapseToMidpoint.cpp EdgeSwap.hpp EdgeSwap.cpp + VertexSmoothUsingDifferentiableEnergy.hpp + VertexSmoothUsingDifferentiableEnergy.cpp EdgeSplitWithTag.hpp EdgeSplitWithTag.cpp VertexAttributesUpdateBase.hpp @@ -27,3 +29,4 @@ set(SRC_FILES FaceSplitAtMidPoint.cpp ) target_sources(wildmeshing_toolkit PRIVATE ${SRC_FILES}) +add_subdirectory(internal) diff --git a/src/wmtk/operations/tri_mesh/EdgeSwap.cpp b/src/wmtk/operations/tri_mesh/EdgeSwap.cpp index a4eaf01239..0684c9a1b0 100644 --- a/src/wmtk/operations/tri_mesh/EdgeSwap.cpp +++ b/src/wmtk/operations/tri_mesh/EdgeSwap.cpp @@ -30,10 +30,10 @@ bool EdgeSwap::before() const long val0 = static_cast(SimplicialComplex::vertex_one_ring(mesh(), v0).size()); long val1 = static_cast(SimplicialComplex::vertex_one_ring(mesh(), v1).size()); if (mesh().is_boundary_vertex(v0)) { - ++val0; + val0 += 2; } if (mesh().is_boundary_vertex(v1)) { - ++val1; + val1 += 2; } if (val0 < 4 || val1 < 4) { return false; diff --git a/src/wmtk/operations/tri_mesh/VertexAttributesUpdateBase.cpp b/src/wmtk/operations/tri_mesh/VertexAttributesUpdateBase.cpp index 93fd19a5af..5252835deb 100644 --- a/src/wmtk/operations/tri_mesh/VertexAttributesUpdateBase.cpp +++ b/src/wmtk/operations/tri_mesh/VertexAttributesUpdateBase.cpp @@ -19,7 +19,10 @@ VertexAttributesUpdateBase::VertexAttributesUpdateBase( : TriMeshOperation(m) , TupleOperation(settings.invariants, t) , m_settings{settings} -{} +{ + assert(m.is_valid_slow(t)); + assert(m.is_valid_slow(input_tuple())); +} std::string VertexAttributesUpdateBase::name() const { @@ -31,6 +34,23 @@ const Tuple& VertexAttributesUpdateBase::return_tuple() const return m_output_tuple; } +std::vector VertexAttributesUpdateBase::modified_primitives(PrimitiveType type) const +{ + if (type == PrimitiveType::Face) { + assert(mesh().is_valid_slow(m_output_tuple)); + Simplex v(PrimitiveType::Vertex, m_output_tuple); + auto sc = SimplicialComplex::open_star(mesh(), v); + auto faces = sc.get_simplices(PrimitiveType::Face); + std::vector ret; + for (const auto& face : faces) { + ret.emplace_back(face.tuple()); + } + return ret; + } else { + return {}; + } +} + bool VertexAttributesUpdateBase::execute() { diff --git a/src/wmtk/operations/tri_mesh/VertexAttributesUpdateBase.hpp b/src/wmtk/operations/tri_mesh/VertexAttributesUpdateBase.hpp index 41f2e22c1a..c0e0b30ff6 100644 --- a/src/wmtk/operations/tri_mesh/VertexAttributesUpdateBase.hpp +++ b/src/wmtk/operations/tri_mesh/VertexAttributesUpdateBase.hpp @@ -2,6 +2,7 @@ #include #include #include + #include "TriMeshOperation.hpp" namespace wmtk::operations { @@ -30,11 +31,12 @@ class VertexAttributesUpdateBase : public TriMeshOperation, protected TupleOpera static PrimitiveType primitive_type() { return PrimitiveType::Vertex; } const Tuple& return_tuple() const; + std::vector modified_primitives(PrimitiveType) const override; protected: bool execute() override; -private: +protected: Tuple m_output_tuple; const OperationSettings& m_settings; }; diff --git a/src/wmtk/operations/tri_mesh/VertexLaplacianSmooth.cpp b/src/wmtk/operations/tri_mesh/VertexLaplacianSmooth.cpp index de6fc742bd..b1826d979b 100644 --- a/src/wmtk/operations/tri_mesh/VertexLaplacianSmooth.cpp +++ b/src/wmtk/operations/tri_mesh/VertexLaplacianSmooth.cpp @@ -1,9 +1,19 @@ #include "VertexLaplacianSmooth.hpp" - #include #include +#include + +namespace wmtk::operations { + +void OperationSettings::initialize_invariants(const TriMesh& m) +{ + base_settings.initialize_invariants(m); + if (!smooth_boundary) { + base_settings.invariants.add(std::make_unique(m)); + } +} // namespace wmtk::operations -namespace wmtk::operations::tri_mesh { +namespace tri_mesh { VertexLaplacianSmooth::VertexLaplacianSmooth( Mesh& m, const Tuple& t, @@ -19,33 +29,18 @@ std::string VertexLaplacianSmooth::name() const } -bool VertexLaplacianSmooth::before() const -{ - if (!mesh().is_valid_slow(input_tuple())) { - return false; - } - if (!m_settings.smooth_boundary && mesh().is_boundary_vertex(input_tuple())) { - return false; - } - return true; -} - bool VertexLaplacianSmooth::execute() { - if (!tri_mesh::VertexAttributesUpdateBase::execute()) { - return false; - } - const Tuple tup = tri_mesh::VertexAttributesUpdateBase::return_tuple(); - const std::vector one_ring = SimplicialComplex::vertex_one_ring(mesh(), tup); - auto p_mid = m_pos_accessor.vector_attribute(tup); + const std::vector one_ring = SimplicialComplex::vertex_one_ring(mesh(), input_tuple()); + auto p_mid = m_pos_accessor.vector_attribute(input_tuple()); p_mid = Eigen::Vector3d::Zero(); for (const Simplex& s : one_ring) { p_mid += m_pos_accessor.vector_attribute(s.tuple()); } p_mid /= one_ring.size(); - return true; + return tri_mesh::VertexAttributesUpdateBase::execute(); } - -} // namespace wmtk::operations::tri_mesh +} // namespace tri_mesh +} // namespace wmtk::operations diff --git a/src/wmtk/operations/tri_mesh/VertexLaplacianSmooth.hpp b/src/wmtk/operations/tri_mesh/VertexLaplacianSmooth.hpp index 5d46df4e07..9aea049103 100644 --- a/src/wmtk/operations/tri_mesh/VertexLaplacianSmooth.hpp +++ b/src/wmtk/operations/tri_mesh/VertexLaplacianSmooth.hpp @@ -16,6 +16,7 @@ struct OperationSettings OperationSettings base_settings; MeshAttributeHandle position; bool smooth_boundary = false; + void initialize_invariants(const TriMesh& m); }; namespace tri_mesh { @@ -31,10 +32,7 @@ class VertexLaplacianSmooth : public VertexAttributesUpdateBase static PrimitiveType primitive_type() { return PrimitiveType::Vertex; } - const Tuple& return_tuple() const; - protected: - bool before() const override; bool execute() override; protected: diff --git a/src/wmtk/operations/tri_mesh/VertexSmoothUsingDifferentiableEnergy.cpp b/src/wmtk/operations/tri_mesh/VertexSmoothUsingDifferentiableEnergy.cpp new file mode 100644 index 0000000000..a7ec176213 --- /dev/null +++ b/src/wmtk/operations/tri_mesh/VertexSmoothUsingDifferentiableEnergy.cpp @@ -0,0 +1,48 @@ +#include "VertexSmoothUsingDifferentiableEnergy.hpp" +#include +#include +#include + +namespace wmtk::operations { +void OperationSettings::initialize_invariants( + const TriMesh& m) +{ + base_settings.initialize_invariants(m); + base_settings.invariants.add( + std::make_shared(m, coordinate_handle)); +} +} // namespace wmtk::operations + +namespace wmtk::operations::tri_mesh { +VertexSmoothUsingDifferentiableEnergy::VertexSmoothUsingDifferentiableEnergy( + Mesh& m, + const Tuple& t, + const OperationSettings& settings) + : VertexAttributesUpdateBase(m, t, settings.base_settings) + , m_settings{settings} +{} + +std::string VertexSmoothUsingDifferentiableEnergy::name() const +{ + return "tri_mesh_vertex_smooth_using_differentiable_energy"; +} + +function::utils::DifferentiableFunctionEvaluator +VertexSmoothUsingDifferentiableEnergy::get_function_evaluator(Accessor& accessor) const +{ + return function::utils::DifferentiableFunctionEvaluator( + *m_settings.energy, + accessor, + simplex::Simplex(PrimitiveType::Vertex, input_tuple())); +} + + +Accessor VertexSmoothUsingDifferentiableEnergy::coordinate_accessor() +{ + return mesh().create_accessor(m_settings.coordinate_handle); +} +ConstAccessor VertexSmoothUsingDifferentiableEnergy::const_coordinate_accessor() const +{ + return mesh().create_const_accessor(m_settings.coordinate_handle); +} +} // namespace wmtk::operations::tri_mesh diff --git a/src/wmtk/operations/tri_mesh/VertexSmoothUsingDifferentiableEnergy.hpp b/src/wmtk/operations/tri_mesh/VertexSmoothUsingDifferentiableEnergy.hpp new file mode 100644 index 0000000000..e3fd948b36 --- /dev/null +++ b/src/wmtk/operations/tri_mesh/VertexSmoothUsingDifferentiableEnergy.hpp @@ -0,0 +1,62 @@ +#pragma once +#include +#include +#include +#include +#include "VertexAttributesUpdateBase.hpp" + +namespace wmtk { +namespace function { +class DifferentiableFunction; + +} +} // namespace wmtk +namespace wmtk::operations { +namespace tri_mesh { +class VertexSmoothUsingDifferentiableEnergy; +} + +template <> +struct OperationSettings +{ + OperationSettings base_settings; + std::unique_ptr energy; + // coordinate for teh attribute used to evaluate the energy + MeshAttributeHandle coordinate_handle; + bool smooth_boundary = false; + + bool second_order = true; + bool line_search = false; + void initialize_invariants(const TriMesh& m); + double step_size = 1.0; +}; + +namespace tri_mesh { +class VertexSmoothUsingDifferentiableEnergy : public VertexAttributesUpdateBase +{ +protected: + VertexSmoothUsingDifferentiableEnergy( + Mesh& m, + const Tuple& t, + const OperationSettings& settings); + +public: + std::string name() const override; + + static PrimitiveType primitive_type() { return PrimitiveType::Vertex; } + + +protected: + function::utils::DifferentiableFunctionEvaluator get_function_evaluator( + Accessor& accessor) const; + MeshAttributeHandle coordinate_handle() const { return m_settings.coordinate_handle; } + + Accessor coordinate_accessor(); + ConstAccessor const_coordinate_accessor() const; + const OperationSettings& m_settings; +}; + +} // namespace tri_mesh +} // namespace wmtk::operations +// provides overload for factory +#include diff --git a/src/wmtk/operations/tri_mesh/VertexTangentialLaplacianSmooth.cpp b/src/wmtk/operations/tri_mesh/VertexTangentialLaplacianSmooth.cpp index 09b9b946d8..a317e2da5c 100644 --- a/src/wmtk/operations/tri_mesh/VertexTangentialLaplacianSmooth.cpp +++ b/src/wmtk/operations/tri_mesh/VertexTangentialLaplacianSmooth.cpp @@ -19,21 +19,10 @@ std::string VertexTangentialLaplacianSmooth::name() const return "tri_mesh_vertex_tangential_smooth"; } -bool VertexTangentialLaplacianSmooth::before() const -{ - if (!mesh().is_valid_slow(input_tuple())) { - return false; - } - return true; -} - bool VertexTangentialLaplacianSmooth::execute() { const Eigen::Vector3d p = m_pos_accessor.vector_attribute(input_tuple()); - if (!tri_mesh::VertexLaplacianSmooth::before()) { - return false; - } if (!tri_mesh::VertexLaplacianSmooth::execute()) { return false; } diff --git a/src/wmtk/operations/tri_mesh/VertexTangentialLaplacianSmooth.hpp b/src/wmtk/operations/tri_mesh/VertexTangentialLaplacianSmooth.hpp index 356108defe..20061dfb1a 100644 --- a/src/wmtk/operations/tri_mesh/VertexTangentialLaplacianSmooth.hpp +++ b/src/wmtk/operations/tri_mesh/VertexTangentialLaplacianSmooth.hpp @@ -33,7 +33,6 @@ class VertexTangentialLaplacianSmooth : public VertexLaplacianSmooth static PrimitiveType primitive_type() { return PrimitiveType::Vertex; } protected: - bool before() const override; bool execute() override; private: diff --git a/src/wmtk/operations/tri_mesh/internal/CMakeLists.txt b/src/wmtk/operations/tri_mesh/internal/CMakeLists.txt new file mode 100644 index 0000000000..14eeb5db0d --- /dev/null +++ b/src/wmtk/operations/tri_mesh/internal/CMakeLists.txt @@ -0,0 +1,12 @@ + +set(SRC_FILES + VertexSmoothUsingDifferentiableEnergyFactory.hpp + VertexSmoothUsingDifferentiableEnergyFactory.cpp + VertexSmoothGradientDescent.hpp + VertexSmoothGradientDescent.cpp + VertexSmoothNewtonMethod.hpp + VertexSmoothNewtonMethod.cpp + VertexSmoothNewtonMethodWithLineSearch.hpp + VertexSmoothNewtonMethodWithLineSearch.cpp +) +target_sources(wildmeshing_toolkit PRIVATE ${SRC_FILES}) diff --git a/src/wmtk/operations/tri_mesh/internal/VertexSmoothGradientDescent.cpp b/src/wmtk/operations/tri_mesh/internal/VertexSmoothGradientDescent.cpp new file mode 100644 index 0000000000..122e532ad7 --- /dev/null +++ b/src/wmtk/operations/tri_mesh/internal/VertexSmoothGradientDescent.cpp @@ -0,0 +1,44 @@ +#include "VertexSmoothGradientDescent.hpp" +#include +namespace wmtk::operations::tri_mesh::internal { +VertexSmoothGradientDescent::VertexSmoothGradientDescent( + Mesh& m, + const Tuple& t, + const OperationSettings& settings) + : VertexSmoothUsingDifferentiableEnergy(m, t, settings) +{} +std::string VertexSmoothGradientDescent::name() const +{ + return "tri_mesh_vertex_smooth_newton_method"; +} + + +Eigen::VectorXd VertexSmoothGradientDescent::get_descent_direction( + function::utils::DifferentiableFunctionEvaluator& f) const +{ + return -f.get_gradient(); +} + +bool VertexSmoothGradientDescent::execute() +{ + auto accessor = coordinate_accessor(); + auto evaluator = get_function_evaluator(accessor); + + auto pos = evaluator.get_coordinate(); + Eigen::Vector2d next_pos = pos + m_settings.step_size * get_descent_direction(evaluator); + evaluator.store(next_pos); + + if (!tri_mesh::VertexSmoothUsingDifferentiableEnergy::execute()) { + wmtk::logger().debug("execute failed"); + return false; + } + return true; +} +std::vector VertexSmoothGradientDescent::priority() const +{ + double gradnorm = m_settings.energy->get_gradient(input_tuple()).norm(); + std::vector r; + r.emplace_back(-gradnorm); + return r; +} +} // namespace wmtk::operations::tri_mesh::internal diff --git a/src/wmtk/operations/tri_mesh/internal/VertexSmoothGradientDescent.hpp b/src/wmtk/operations/tri_mesh/internal/VertexSmoothGradientDescent.hpp new file mode 100644 index 0000000000..34ce7203e2 --- /dev/null +++ b/src/wmtk/operations/tri_mesh/internal/VertexSmoothGradientDescent.hpp @@ -0,0 +1,21 @@ + +#pragma once +#include + +namespace wmtk::operations::tri_mesh::internal { +class VertexSmoothGradientDescent : public VertexSmoothUsingDifferentiableEnergy +{ +public: + VertexSmoothGradientDescent( + Mesh& m, + const Tuple& t, + const OperationSettings& settings); + + std::vector priority() const; + +protected: + Eigen::VectorXd get_descent_direction(function::utils::DifferentiableFunctionEvaluator&) const; + bool execute() override; + std::string name() const; +}; +} // namespace wmtk::operations::tri_mesh::internal diff --git a/src/wmtk/operations/tri_mesh/internal/VertexSmoothNewtonMethod.cpp b/src/wmtk/operations/tri_mesh/internal/VertexSmoothNewtonMethod.cpp new file mode 100644 index 0000000000..ff9fac2cef --- /dev/null +++ b/src/wmtk/operations/tri_mesh/internal/VertexSmoothNewtonMethod.cpp @@ -0,0 +1,62 @@ +#include "VertexSmoothNewtonMethod.hpp" +#include +namespace wmtk::operations::tri_mesh::internal { +VertexSmoothNewtonMethod::VertexSmoothNewtonMethod( + Mesh& m, + const Tuple& t, + const OperationSettings& settings) + : VertexSmoothUsingDifferentiableEnergy(m, t, settings) +{} +std::string VertexSmoothNewtonMethod::name() const +{ + return "tri_mesh_vertex_smooth_newton_method"; +} +Eigen::VectorXd VertexSmoothNewtonMethod::get_descent_direction( + function::utils::DifferentiableFunctionEvaluator& f) const +{ + return -f.get_hessian().ldlt().solve(f.get_gradient()); +} + + +bool VertexSmoothNewtonMethod::execute() +{ + auto accessor = coordinate_accessor(); + auto evaluator = get_function_evaluator(accessor); + + auto pos = evaluator.get_coordinate().eval(); + double value = evaluator.get_value(pos); + auto dir = get_descent_direction(evaluator); + Eigen::Vector2d next_pos = pos + m_settings.step_size * dir; + double new_value = evaluator.get_value(next_pos); + + /* + spdlog::info( + "Went from f({},{})={} to f({},{})={} ====== +={} * {},{}", + pos.x(), + pos.y(), + value, + next_pos.x(), + next_pos.y(), + new_value, + m_settings.step_size, + dir.x(), + dir.y()); + + */ + evaluator.store(next_pos); + + + if (!tri_mesh::VertexSmoothUsingDifferentiableEnergy::execute()) { + return false; + } + return true; +} +std::vector VertexSmoothNewtonMethod::priority() const +{ + double gradnorm = m_settings.energy->get_gradient(input_tuple()).norm(); + std::vector r; + r.emplace_back(-gradnorm); + return r; +} + +} // namespace wmtk::operations::tri_mesh::internal diff --git a/src/wmtk/operations/tri_mesh/internal/VertexSmoothNewtonMethod.hpp b/src/wmtk/operations/tri_mesh/internal/VertexSmoothNewtonMethod.hpp new file mode 100644 index 0000000000..2fea5b2f0c --- /dev/null +++ b/src/wmtk/operations/tri_mesh/internal/VertexSmoothNewtonMethod.hpp @@ -0,0 +1,21 @@ +#pragma once +#include + +namespace wmtk::operations::tri_mesh::internal { + +class VertexSmoothNewtonMethod : public VertexSmoothUsingDifferentiableEnergy +{ +public: + VertexSmoothNewtonMethod( + Mesh& m, + const Tuple& t, + const OperationSettings& settings); + + std::vector priority() const; + +protected: + bool execute() override; + Eigen::VectorXd get_descent_direction(function::utils::DifferentiableFunctionEvaluator&) const; + std::string name() const; +}; +} // namespace wmtk::operations::tri_mesh::internal diff --git a/src/wmtk/operations/tri_mesh/internal/VertexSmoothNewtonMethodWithLineSearch.cpp b/src/wmtk/operations/tri_mesh/internal/VertexSmoothNewtonMethodWithLineSearch.cpp new file mode 100644 index 0000000000..bc0bb36a37 --- /dev/null +++ b/src/wmtk/operations/tri_mesh/internal/VertexSmoothNewtonMethodWithLineSearch.cpp @@ -0,0 +1,38 @@ +#include "VertexSmoothNewtonMethodWithLineSearch.hpp" +#include + +namespace wmtk::operations::tri_mesh::internal { + +VertexSmoothNewtonMethodWithLineSearch::VertexSmoothNewtonMethodWithLineSearch( + Mesh& m, + const Tuple& t, + const OperationSettings& settings) + : VertexSmoothNewtonMethod(m, t, settings) +{} + +bool VertexSmoothNewtonMethodWithLineSearch::execute() +{ + auto accessor = coordinate_accessor(); + auto evaluator = get_function_evaluator(accessor); + + Eigen::Vector2d direction = get_descent_direction(evaluator); + + optimization::LineSearch line_search(evaluator, invariants()); + + line_search.set_create_scope( + false); // since we're in an operation we will fail if the seach doesn't do waht we want + double distance = line_search.run(direction, 1.0); + if (distance == 0.0) { + return false; + } + + + if (!tri_mesh::VertexSmoothUsingDifferentiableEnergy::execute()) { + return false; + } + return true; +} + + +} // namespace wmtk::operations::tri_mesh::internal + diff --git a/src/wmtk/operations/tri_mesh/internal/VertexSmoothNewtonMethodWithLineSearch.hpp b/src/wmtk/operations/tri_mesh/internal/VertexSmoothNewtonMethodWithLineSearch.hpp new file mode 100644 index 0000000000..c409542237 --- /dev/null +++ b/src/wmtk/operations/tri_mesh/internal/VertexSmoothNewtonMethodWithLineSearch.hpp @@ -0,0 +1,15 @@ +#pragma once +#include "VertexSmoothNewtonMethod.hpp" +namespace wmtk::operations::tri_mesh::internal { +class VertexSmoothNewtonMethodWithLineSearch : public VertexSmoothNewtonMethod +{ +public: + VertexSmoothNewtonMethodWithLineSearch( + Mesh& m, + const Tuple& t, + const OperationSettings& settings); + +protected: + bool execute() override; +}; +} // namespace wmtk::operations::tri_mesh::internal diff --git a/src/wmtk/operations/tri_mesh/internal/VertexSmoothUsingDifferentiableEnergyFactory.cpp b/src/wmtk/operations/tri_mesh/internal/VertexSmoothUsingDifferentiableEnergyFactory.cpp new file mode 100644 index 0000000000..6c0a3442f5 --- /dev/null +++ b/src/wmtk/operations/tri_mesh/internal/VertexSmoothUsingDifferentiableEnergyFactory.cpp @@ -0,0 +1,27 @@ +#include "VertexSmoothUsingDifferentiableEnergyFactory.hpp" +#include +#include +#include +#include + + +namespace wmtk::operations { + +template <> + +std::unique_ptr OperationFactory< + tri_mesh::VertexSmoothUsingDifferentiableEnergy>::create(wmtk::Mesh& m, const Tuple& t) const +{ + if (m_settings.second_order) { + if (m_settings.line_search) { + return std::make_unique( + m, + t, + m_settings); + } else { + return std::make_unique(m, t, m_settings); + } + } + return std::make_unique(m, t, m_settings); +} +} // namespace wmtk::operations diff --git a/src/wmtk/operations/tri_mesh/internal/VertexSmoothUsingDifferentiableEnergyFactory.hpp b/src/wmtk/operations/tri_mesh/internal/VertexSmoothUsingDifferentiableEnergyFactory.hpp new file mode 100644 index 0000000000..585c8edc25 --- /dev/null +++ b/src/wmtk/operations/tri_mesh/internal/VertexSmoothUsingDifferentiableEnergyFactory.hpp @@ -0,0 +1,13 @@ +#pragma once +#include +#include + +namespace wmtk::operations { +namespace tri_mesh { +class VertexSmoothUsingDifferentiableEnergy; +} + +template <> +std::unique_ptr OperationFactory< + tri_mesh::VertexSmoothUsingDifferentiableEnergy>::create(wmtk::Mesh& m, const Tuple& t) const; +} // namespace wmtk::operations diff --git a/src/wmtk/optimization/CMakeLists.txt b/src/wmtk/optimization/CMakeLists.txt new file mode 100644 index 0000000000..3529f50732 --- /dev/null +++ b/src/wmtk/optimization/CMakeLists.txt @@ -0,0 +1,5 @@ +set(SRC_FILES + LineSearch.hpp + LineSearch.cpp +) +target_sources(wildmeshing_toolkit PRIVATE ${SRC_FILES}) diff --git a/src/wmtk/optimization/LineSearch.cpp b/src/wmtk/optimization/LineSearch.cpp new file mode 100644 index 0000000000..18dab8ce90 --- /dev/null +++ b/src/wmtk/optimization/LineSearch.cpp @@ -0,0 +1,75 @@ +#include "LineSearch.hpp" +#include +#include + +namespace wmtk::optimization { + +LineSearch::LineSearch( + function::utils::DifferentiableFunctionEvaluator& interface, + const InvariantCollection& invariants) + : m_interface(interface) + , m_invariants(invariants) +{} + +std::vector LineSearch::modified_simplices(PrimitiveType pt) const +{ + return wmtk::simplex::cofaces_single_dimension_tuples( + m_interface.mesh(), + m_interface.simplex(), + pt); + // return m_interface.upper_level_cofaces(); +} + +bool LineSearch::check_state() const +{ + PrimitiveType top_type = m_interface.mesh().top_simplex_type(); + bool before_pass = m_invariants.before(m_interface.tuple()); + bool after_pass = true; + for (const PrimitiveType pt : wmtk::utils::primitive_below(top_type)) { + after_pass |= m_invariants.after(pt, modified_simplices(pt)); + } + return before_pass && after_pass; +} +double LineSearch::run(const Eigen::VectorXd& direction, double step_size) +{ + if (!check_state()) { + return 0; + } + if (m_create_scope) { + { + auto scope = m_interface.mesh().create_scope(); + double retval = _run(direction, step_size); + if (retval == 0) { + scope.mark_failed(); + } + return retval; + } + + } else { + return _run(direction, step_size); + } +} + +double LineSearch::_run(const Eigen::VectorXd& direction, double init_step_size) +{ + int steps = 0; + // just to make sure we try the initial stepsize + double step_size = init_step_size; + double next_step_size = step_size; + double min_step_size = m_min_step_size_ratio * step_size; + Vector current_pos = m_interface.get_const_coordinate(); + Vector new_pos; + do { + new_pos = current_pos + direction * step_size; + m_interface.store(new_pos); + + step_size = next_step_size; + next_step_size /= 2; + } while (steps++ < m_max_steps && step_size > min_step_size && !check_state()); + if (steps == m_max_steps || step_size < min_step_size) { + return 0; + } else { + return step_size; + } +} +} // namespace wmtk::optimization diff --git a/src/wmtk/optimization/LineSearch.hpp b/src/wmtk/optimization/LineSearch.hpp new file mode 100644 index 0000000000..8237fd9799 --- /dev/null +++ b/src/wmtk/optimization/LineSearch.hpp @@ -0,0 +1,49 @@ +#pragma once + +#include +#include + + +namespace wmtk::optimization { + + +class LineSearch +{ +public: + using InvariantCollection = wmtk::InvariantCollection; + LineSearch( + function::utils::DifferentiableFunctionEvaluator& interface, + const InvariantCollection& invariants); + + using Vector = Eigen::VectorXd; + + double run(const Eigen::VectorXd& direction, double step_size); + double _run(const Eigen::VectorXd& direction, double step_size); + + void set_create_scope(bool enable) { m_create_scope = enable; } + void set_max_steps(long max_steps) { m_max_steps = max_steps; } + void set_min_step_size_ratio(long min_step_size_ratio) + { + m_min_step_size_ratio = min_step_size_ratio; + } + +protected: + function::utils::DifferentiableFunctionEvaluator& m_interface; + const InvariantCollection& m_invariants; + + bool m_create_scope = true; + long m_max_steps = 10; + double m_min_step_size_ratio = 1e-6; + + std::vector modified_simplices(PrimitiveType pt) const; + + // TODO: formally define what checking the state means + // we currently make sure that we pass before on the input tuple and after on all top level + // simplices, but should we be passing all of that every time? + bool check_state() const; + +public: +protected: +}; + +} // namespace wmtk::optimization diff --git a/src/wmtk/simplex/CMakeLists.txt b/src/wmtk/simplex/CMakeLists.txt index e7b43883e7..1b43004ec1 100644 --- a/src/wmtk/simplex/CMakeLists.txt +++ b/src/wmtk/simplex/CMakeLists.txt @@ -11,6 +11,10 @@ set(SRC_FILES top_level_cofaces.cpp top_level_cofaces_iterable.hpp top_level_cofaces_iterable.cpp + + cofaces_single_dimension.hpp + cofaces_single_dimension.cpp + link.hpp link.cpp link_iterable.hpp @@ -19,6 +23,8 @@ set(SRC_FILES open_star.cpp open_star_iterable.hpp open_star_iterable.cpp + + faces.hpp faces.cpp faces_iterable.hpp diff --git a/src/wmtk/simplex/SimplexCollection.cpp b/src/wmtk/simplex/SimplexCollection.cpp index 17f2a8fc9c..9830fc178f 100644 --- a/src/wmtk/simplex/SimplexCollection.cpp +++ b/src/wmtk/simplex/SimplexCollection.cpp @@ -31,6 +31,19 @@ void SimplexCollection::add(const SimplexCollection& simplex_collection) m_simplices.insert(m_simplices.end(), s.begin(), s.end()); } +std::vector SimplexCollection::tuple_vector() const +{ + std::vector tuples; + tuples.reserve(m_simplices.size()); // giving the vector some (hopefully) resonable size + + // add simplices to the vector + for (const Simplex& s : m_simplices) { + tuples.emplace_back(s.tuple()); + } + + return tuples; +} + void SimplexCollection::sort_and_clean() { std::sort(m_simplices.begin(), m_simplices.end(), m_simplex_is_less); @@ -38,6 +51,12 @@ void SimplexCollection::sort_and_clean() m_simplices.erase(last, m_simplices.end()); } +void SimplexCollection::sort() +{ + std::sort(m_simplices.begin(), m_simplices.end(), m_simplex_is_less); +} + + bool SimplexCollection::contains(const Simplex& simplex) const { // TODO this is O(n) but can and should be done in O(log n) diff --git a/src/wmtk/simplex/SimplexCollection.hpp b/src/wmtk/simplex/SimplexCollection.hpp index 00fc77e85c..2d9eb0e82c 100644 --- a/src/wmtk/simplex/SimplexCollection.hpp +++ b/src/wmtk/simplex/SimplexCollection.hpp @@ -34,11 +34,17 @@ class SimplexCollection void add(const Simplex& simplex); void add(const SimplexCollection& simplex_collection); - + /** + * @brief return the vector of tuples of the simplex collection. + * + * @return std::vector + */ + std::vector tuple_vector() const; /** * @brief Sort simplex vector and remove duplicates. */ void sort_and_clean(); + void sort(); /** * @brief Check if simplex is contained in collection. diff --git a/src/wmtk/simplex/closed_star.cpp b/src/wmtk/simplex/closed_star.cpp index ec1dada35f..8c8212e3c1 100644 --- a/src/wmtk/simplex/closed_star.cpp +++ b/src/wmtk/simplex/closed_star.cpp @@ -13,7 +13,7 @@ SimplexCollection closed_star(const Mesh& mesh, const Simplex& simplex, const bo SimplexCollection collection(mesh); collection.add(simplex); - + assert(mesh.is_valid_slow(simplex.tuple())); const SimplexCollection top_level_cofaces_collection = mesh.top_simplex_type() == PrimitiveType::Face ? top_level_cofaces(static_cast(mesh), simplex, false) diff --git a/src/wmtk/simplex/cofaces_single_dimension.cpp b/src/wmtk/simplex/cofaces_single_dimension.cpp new file mode 100644 index 0000000000..b68c186d81 --- /dev/null +++ b/src/wmtk/simplex/cofaces_single_dimension.cpp @@ -0,0 +1,59 @@ +#include "cofaces_single_dimension.hpp" +#include +#include +#include +#include +#include +#include +#include +#include +#include "link.hpp" +#include "top_level_cofaces.hpp" +namespace wmtk::simplex { + +std::vector cofaces_single_dimension_tuples( + const Mesh& mesh, + const Simplex& simplex, + PrimitiveType cofaces_type) +{ + switch (mesh.top_simplex_type()) { + case PrimitiveType::Face: + return cofaces_single_dimension_tuples(static_cast(mesh), simplex, cofaces_type); + case PrimitiveType::Tetrahedron: + //return cofaces_single_dimension_tuples(static_cast(mesh), simplex, cofaces_type); + case PrimitiveType::Edge: + case PrimitiveType::Vertex: + default: throw std::runtime_error("unknown mesh type in cofaces_single_dimension_tuples"); + } + return {}; +} + +std::vector cofaces_single_dimension_tuples( + const TriMesh& mesh, + const Simplex& my_simplex, + PrimitiveType cofaces_type) +{ + assert(my_simplex.primitive_type() < cofaces_type); + std::vector collection; + if (my_simplex.primitive_type() == PrimitiveType::Vertex && + (cofaces_type == PrimitiveType::Edge)) { + auto sc = link(mesh, my_simplex); + std::vector coface_edge_tuples; + for (const Simplex& edge : sc.simplex_vector(PrimitiveType::Edge)) { + coface_edge_tuples.emplace_back(mesh.switch_vertex(mesh.switch_edge(edge.tuple()))); + coface_edge_tuples.emplace_back( + mesh.switch_vertex(mesh.switch_edge(mesh.switch_vertex(edge.tuple())))); + } + SimplexCollection ec( + mesh, + utils::tuple_vector_to_homogeneous_simplex_vector( + coface_edge_tuples, + PrimitiveType::Edge)); + ec.sort_and_clean(); + collection = ec.tuple_vector(); + } else { + collection = top_level_cofaces_tuples(mesh, my_simplex); + } + return collection; +} +} // namespace wmtk::simplex diff --git a/src/wmtk/simplex/cofaces_single_dimension.hpp b/src/wmtk/simplex/cofaces_single_dimension.hpp new file mode 100644 index 0000000000..2b6e54ce8f --- /dev/null +++ b/src/wmtk/simplex/cofaces_single_dimension.hpp @@ -0,0 +1,20 @@ +#pragma once +#include +#include +#include +#include + +namespace wmtk::simplex { + + // Returns the cofaces of a provided simplex, but only providing the cofaces in the provided coface type + +std::vector cofaces_single_dimension_tuples( + const TriMesh& mesh, + const Simplex& my_simplex, + PrimitiveType cofaces_type); + +std::vector cofaces_single_dimension_tuples( + const Mesh& mesh, + const Simplex& my_simplex, + PrimitiveType cofaces_type); +} // namespace wmtk::simplex diff --git a/src/wmtk/simplex/top_level_cofaces.cpp b/src/wmtk/simplex/top_level_cofaces.cpp index 6e3e406fcd..8076ab1334 100644 --- a/src/wmtk/simplex/top_level_cofaces.cpp +++ b/src/wmtk/simplex/top_level_cofaces.cpp @@ -19,6 +19,7 @@ std::vector top_level_cofaces_tuples_vertex(const TriMesh& mesh, const Tu { std::vector collection; + assert(mesh.is_valid_slow(t)); std::set touched_cells; std::queue q; q.push(t); diff --git a/src/wmtk/utils/CMakeLists.txt b/src/wmtk/utils/CMakeLists.txt index f1521d163b..7c04dbe0cb 100644 --- a/src/wmtk/utils/CMakeLists.txt +++ b/src/wmtk/utils/CMakeLists.txt @@ -14,7 +14,13 @@ set(SRC_FILES mesh_utils.hpp mesh_utils.cpp TupleInspector.hpp + TupleInspector.cpp + triangle_areas.hpp + triangle_areas.cpp + + #Optimization.hpp + #Optimization.cpp metaprogramming/as_mesh_variant.hpp metaprogramming/as_mesh_variant.cpp @@ -26,6 +32,9 @@ set(SRC_FILES metaprogramming/ReferenceWrapperVariant.hpp metaprogramming/as_variant.hpp metaprogramming/unwrap_ref.hpp + + primitive_range.hpp + primitive_range.cpp ) target_sources(wildmeshing_toolkit PRIVATE ${SRC_FILES}) diff --git a/src/wmtk/utils/primitive_range.cpp b/src/wmtk/utils/primitive_range.cpp new file mode 100644 index 0000000000..7b2c6d34af --- /dev/null +++ b/src/wmtk/utils/primitive_range.cpp @@ -0,0 +1,64 @@ +#include "primitive_range.hpp" +namespace wmtk::utils { +std::vector primitive_range(PrimitiveType pt0, PrimitiveType pt1) +{ + std::vector r; + switch (pt0) { + case PrimitiveType::Vertex: + r.emplace_back(PrimitiveType::Vertex); + if (pt1 == r.back()) { + break; + } + [[fallthrough]]; + case PrimitiveType::Edge: + r.emplace_back(PrimitiveType::Edge); + if (pt1 == r.back()) { + break; + } + [[fallthrough]]; + case PrimitiveType::Face: + r.emplace_back(PrimitiveType::Face); + if (pt1 == r.back()) { + break; + } + [[fallthrough]]; + case PrimitiveType::Tetrahedron: + r.emplace_back(PrimitiveType::Tetrahedron); + if (pt1 == r.back()) { + break; + } + [[fallthrough]]; + case PrimitiveType::HalfEdge: + default: break; + } + return r; +} +std::vector primitive_above(PrimitiveType pt) +{ + std::vector r; + + switch (pt) { + case PrimitiveType::Vertex: r.emplace_back(PrimitiveType::Vertex); [[fallthrough]]; + case PrimitiveType::Edge: r.emplace_back(PrimitiveType::Edge); [[fallthrough]]; + case PrimitiveType::Face: r.emplace_back(PrimitiveType::Face); [[fallthrough]]; + case PrimitiveType::Tetrahedron: r.emplace_back(PrimitiveType::Tetrahedron); [[fallthrough]]; + case PrimitiveType::HalfEdge: + default: break; + } + return r; +} +std::vector primitive_below(PrimitiveType pt) +{ + std::vector r; + + switch (pt) { + case PrimitiveType::Tetrahedron: r.emplace_back(PrimitiveType::Tetrahedron); [[fallthrough]]; + case PrimitiveType::Face: r.emplace_back(PrimitiveType::Face); [[fallthrough]]; + case PrimitiveType::Edge: r.emplace_back(PrimitiveType::Edge); [[fallthrough]]; + case PrimitiveType::Vertex: r.emplace_back(PrimitiveType::Vertex); [[fallthrough]]; + case PrimitiveType::HalfEdge: + default: break; + } + return r; +} +} // namespace wmtk::utils diff --git a/src/wmtk/utils/primitive_range.hpp b/src/wmtk/utils/primitive_range.hpp new file mode 100644 index 0000000000..ac07f6e6b7 --- /dev/null +++ b/src/wmtk/utils/primitive_range.hpp @@ -0,0 +1,12 @@ +#pragma once + +#include +#include +namespace wmtk::utils { +// returns a vector of primitives including the endpoitns of the range +std::vector primitive_range(PrimitiveType pt0, PrimitiveType pt1); +// returns a vector of primitives including the endpoint +std::vector primitive_above(PrimitiveType pt0); +// returns a vector of primitives including the endpoint +std::vector primitive_below(PrimitiveType pt1); +} // namespace wmtk::utils diff --git a/src/wmtk/utils/triangle_areas.cpp b/src/wmtk/utils/triangle_areas.cpp new file mode 100644 index 0000000000..491135a446 --- /dev/null +++ b/src/wmtk/utils/triangle_areas.cpp @@ -0,0 +1,6 @@ +#include "triangle_areas.hpp" +#include +#include + +namespace wmtk::utils { +} diff --git a/src/wmtk/utils/triangle_areas.hpp b/src/wmtk/utils/triangle_areas.hpp new file mode 100644 index 0000000000..f9740dc472 --- /dev/null +++ b/src/wmtk/utils/triangle_areas.hpp @@ -0,0 +1,62 @@ +#pragma once +#include +#include +namespace wmtk { + class Tuple; + class TriMesh; + namespace attribute { + template + class MeshAttributeHandle; + } +} +namespace wmtk::utils { + +// template get 3d tri area +template +auto triangle_3d_area( + const Eigen::MatrixBase& a, + const Eigen::MatrixBase& b, + const Eigen::MatrixBase& c) -> typename ADerived::Scalar +{ + + auto ba = b-a; + auto ca = c-a; + return typename ADerived::Scalar(.5) * ba.cross(ca).norm(); + +} + +// template get 3d tri area +template +auto triangle_signed_2d_area( + const Eigen::MatrixBase& a, + const Eigen::MatrixBase& b, + const Eigen::MatrixBase& c) -> typename ADerived::Scalar +{ + auto ba = (b - a).eval(); + auto ca = (c - a).eval(); + return typename ADerived::Scalar(.5) * ba.homogeneous().cross(ca.homogeneous()).z(); +} + +template +auto triangle_unsigned_2d_area( + const Eigen::MatrixBase& a, + const Eigen::MatrixBase& b, + const Eigen::MatrixBase& c) -> typename ADerived::Scalar +{ + return std::abs(triangle_signed_2d_area(a,b,c)); +} + +template +bool triangle_2d_orientation( + const Eigen::MatrixBase& a, + const Eigen::MatrixBase& b, + const Eigen::MatrixBase& c) +{ + auto res = igl::predicates::orient2d(a, b, c); + if (res == igl::predicates::Orientation::POSITIVE) + return true; + else + return false; +} + +} // namespace wmtk diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 3f9fe649e1..4af9d758c1 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -30,8 +30,10 @@ set(TEST_SOURCES test_multi_mesh_visitor.cpp test_mesh_variant.cpp test_variant_metaprogramming.cpp + test_autodiff.cpp test_tuple_metaprogramming.cpp test_local_switch_search.cpp + test_primitive.cpp tools/all_valid_local_tuples.hpp @@ -75,6 +77,10 @@ target_link_libraries(wmtk_tests PUBLIC add_subdirectory(components) source_group("components" REGULAR_EXPRESSION "components\/.*\.(cpp|h|hpp)?$") + +add_subdirectory(function) +source_group("function" REGULAR_EXPRESSION "function\/.*\.(cpp|h|hpp)?$") + wmtk_copy_dll(wmtk_tests) # Register unit tests diff --git a/tests/components/CMakeLists.txt b/tests/components/CMakeLists.txt index dafebe9ec1..ac052c8ac8 100644 --- a/tests/components/CMakeLists.txt +++ b/tests/components/CMakeLists.txt @@ -5,10 +5,11 @@ set(TEST_SOURCES test_component_mesh_info.cpp test_component_output.cpp test_component_isotropic_remeshing.cpp + #test_smoothing.cpp test_component_regular_space.cpp ) target_sources(wmtk_tests PRIVATE ${TEST_SOURCES}) target_link_libraries(wmtk_tests PUBLIC wildmeshing_components -) \ No newline at end of file +) diff --git a/tests/components/test_component_isotropic_remeshing.cpp b/tests/components/test_component_isotropic_remeshing.cpp index 3ee561e187..6d96cb88d7 100644 --- a/tests/components/test_component_isotropic_remeshing.cpp +++ b/tests/components/test_component_isotropic_remeshing.cpp @@ -49,6 +49,7 @@ TEST_CASE("smoothing_bunny", "[components][isotropic_remeshing][2D]") OperationSettings op_settings; op_settings.position = mesh.get_attribute_handle("position", PrimitiveType::Vertex); + op_settings.initialize_invariants(mesh); Scheduler scheduler(mesh); scheduler.add_operation_type("vertex_smooth", op_settings); @@ -75,6 +76,7 @@ TEST_CASE("smoothing_simple_examples", "[components][isotropic_remeshing][2D]") OperationSettings op_settings; op_settings.position = mesh.get_attribute_handle("position", PrimitiveType::Vertex); + op_settings.initialize_invariants(mesh); // offset interior vertex auto pos = mesh.create_accessor(op_settings.position); @@ -97,6 +99,7 @@ TEST_CASE("smoothing_simple_examples", "[components][isotropic_remeshing][2D]") OperationSettings op_settings; op_settings.position = mesh.get_attribute_handle("position", PrimitiveType::Vertex); + op_settings.initialize_invariants(mesh); // offset interior vertex auto pos = mesh.create_accessor(op_settings.position); @@ -133,6 +136,7 @@ TEST_CASE("tangential_smoothing", "[components][isotropic_remeshing][2D]") OperationSettings op_settings; op_settings.smooth_settings.position = mesh.get_attribute_handle("position", PrimitiveType::Vertex); + op_settings.smooth_settings.initialize_invariants(mesh); // offset interior vertex auto pos = mesh.create_accessor(op_settings.smooth_settings.position); @@ -163,7 +167,9 @@ TEST_CASE("tangential_smoothing", "[components][isotropic_remeshing][2D]") v4 = mesh.tuple_from_id(PrimitiveType::Vertex, 4); Eigen::Vector3d after_smooth = pos.vector_attribute(v4); - CHECK((after_smooth - Eigen::Vector3d{1, 0, p_init[2]}).squaredNorm() == 0); + Eigen::Vector3d target = Eigen::Vector3d{1, 0, p_init[2]}; + std::cout << after_smooth.transpose() << " == " << target.transpose() << std::endl; + CHECK((after_smooth - target).squaredNorm() == 0); } TEST_CASE("tangential_smoothing_boundary", "[components][isotropic_remeshing][2D]") @@ -178,6 +184,8 @@ TEST_CASE("tangential_smoothing_boundary", "[components][isotropic_remeshing][2D mesh.get_attribute_handle("position", PrimitiveType::Vertex); op_settings.smooth_settings.smooth_boundary = true; + op_settings.smooth_settings.initialize_invariants(mesh); + // offset interior vertex auto pos = mesh.create_accessor(op_settings.smooth_settings.position); Tuple v1 = mesh.tuple_from_id(PrimitiveType::Vertex, 1); @@ -547,6 +555,7 @@ TEST_CASE("swap_edge_for_valence", "[components][isotropic_remeshing][swap][2D]" { OperationSettings op_settings; op_settings.must_improve_valence = true; + // op_settings.initialize_invariants(mesh); const Tuple e = mesh.edge_tuple_between_v1_v2(6, 7, 5); EdgeSwap op(mesh, e, op_settings); const bool success = op(); diff --git a/tests/components/test_smoothing.cpp b/tests/components/test_smoothing.cpp new file mode 100644 index 0000000000..03609c009e --- /dev/null +++ b/tests/components/test_smoothing.cpp @@ -0,0 +1,85 @@ +#include +#include +#include +#include +#include +#include "../tools/DEBUG_TriMesh.hpp" +#include "../tools/TriMesh_examples.hpp" +using namespace wmtk; +using namespace wmtk::tests; +using namespace wmtk::operations; + +TEST_CASE("smoothing_Newton_Method") +{ + DEBUG_TriMesh mesh = ten_triangles_with_position(2); + OperationSettings op_settings; + op_settings.coordinate_handle = + mesh.get_attribute_handle("position", PrimitiveType::Vertex); + op_settings.smooth_boundary = false; + op_settings.second_order = true; + op_settings.line_search = false; + op_settings.step_size = 0.1; + op_settings.energy = std::make_unique(mesh, op_settings.coordinate_handle); + op_settings.initialize_invariants(mesh); + + + spdlog::info("HJELL?"); + Scheduler scheduler(mesh); + const auto& factory = + scheduler.add_operation_type( + "optimize_vertices", + std::move(op_settings)); + Tuple tuple = mesh.face_tuple_from_vids(2, 4, 5); + spdlog::warn("Initial valuenorm: {}", factory.settings().energy->get_one_ring_value(tuple)); + spdlog::warn( + "Initial gradient: norm: {}", + factory.settings().energy->get_local_gradient(tuple).norm()); + while (factory.settings().energy->get_local_gradient(tuple).norm() > 1e-10) { + scheduler.run_operation_on_all(PrimitiveType::Vertex, "optimize_vertices"); + REQUIRE(scheduler.number_of_successful_operations() > 0); + tuple = mesh.face_tuple_from_vids(2, 4, 5); + } + ConstAccessor pos = mesh.create_const_accessor(op_settings.coordinate_handle); + + Eigen::Vector2d uv0 = pos.const_vector_attribute(tuple); + Eigen::Vector2d uv1 = pos.const_vector_attribute(mesh.switch_vertex(tuple)); + Eigen::Vector2d uv2 = pos.const_vector_attribute(mesh.switch_vertex(mesh.switch_edge(tuple))); + + CHECK((uv0 - uv1).norm() - (uv1 - uv2).norm() < 1e-6); + CHECK((uv0 - uv1).norm() - (uv0 - uv2).norm() < 1e-6); + CHECK((uv1 - uv2).norm() - (uv0 - uv2).norm() < 1e-6); +} + +TEST_CASE("smoothing_Newton_Method_line_search") +{ + DEBUG_TriMesh mesh = ten_triangles_with_position(2); + OperationSettings op_settings; + op_settings.coordinate_handle = + mesh.get_attribute_handle("position", PrimitiveType::Vertex); + op_settings.smooth_boundary = false; + op_settings.second_order = true; + op_settings.line_search = true; + op_settings.energy = std::make_unique(mesh, op_settings.coordinate_handle); + op_settings.initialize_invariants(mesh); + Scheduler scheduler(mesh); + const auto& factory = + scheduler.add_operation_type( + "optimize_vertices", + std::move(op_settings)); + Tuple tuple = mesh.face_tuple_from_vids(2, 4, 5); + spdlog::warn( + "Initial gradient: norm: {}", + factory.settings().energy->get_one_ring_gradient(tuple).norm()); + while (factory.settings().energy->get_one_ring_gradient(tuple).norm() > 1e-10) { + scheduler.run_operation_on_all(PrimitiveType::Vertex, "optimize_vertices"); + tuple = mesh.face_tuple_from_vids(2, 4, 5); + } + ConstAccessor pos = mesh.create_const_accessor(op_settings.coordinate_handle); + + Eigen::Vector2d uv0 = pos.const_vector_attribute(tuple); + Eigen::Vector2d uv1 = pos.const_vector_attribute(mesh.switch_vertex(tuple)); + Eigen::Vector2d uv2 = pos.const_vector_attribute(mesh.switch_vertex(mesh.switch_edge(tuple))); + CHECK((uv0 - uv1).norm() - (uv1 - uv2).norm() < 1e-6); + CHECK((uv0 - uv1).norm() - (uv0 - uv2).norm() < 1e-6); + CHECK((uv1 - uv2).norm() - (uv0 - uv2).norm() < 1e-6); +} diff --git a/tests/function/CMakeLists.txt b/tests/function/CMakeLists.txt new file mode 100644 index 0000000000..216eb0e86a --- /dev/null +++ b/tests/function/CMakeLists.txt @@ -0,0 +1,7 @@ +# Sources +set(TEST_SOURCES + test_2d_energy.cpp + test_amips.cpp +) +target_sources(wmtk_tests PRIVATE ${TEST_SOURCES}) + diff --git a/tests/function/test_2d_energy.cpp b/tests/function/test_2d_energy.cpp new file mode 100644 index 0000000000..4c0ac259a1 --- /dev/null +++ b/tests/function/test_2d_energy.cpp @@ -0,0 +1,115 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "../tools/DEBUG_TriMesh.hpp" +#include "../tools/TriMesh_examples.hpp" +using namespace wmtk; +using namespace wmtk::function; +using namespace wmtk::tests; +using namespace wmtk::simplex; +TEST_CASE("energy_valence") +{ + // 0---1---2 + // /0\1/2\3/4\ . + // 3---4---5---6 + // \5/6\7/ . + // 7---8 + const DEBUG_TriMesh example_mesh = hex_plus_two_with_position(); + auto e1 = example_mesh.edge_tuple_between_v1_v2(3, 4, 0); + auto e2 = example_mesh.edge_tuple_between_v1_v2(4, 0, 1); + auto e3 = example_mesh.edge_tuple_between_v1_v2(4, 5, 2); + auto e4 = example_mesh.edge_tuple_between_v1_v2(5, 1, 3); + + + const TriMesh tri_mesh = static_cast(example_mesh); + + + ValenceEnergyPerEdge valence_energy(tri_mesh); + + + REQUIRE(valence_energy.get_value(Simplex(PrimitiveType::Edge, e1)) == 2); + REQUIRE(valence_energy.get_value(Simplex(PrimitiveType::Edge, e2)) == 2); + REQUIRE(valence_energy.get_value(Simplex(PrimitiveType::Edge, e3)) == 2); + REQUIRE(valence_energy.get_value(Simplex(PrimitiveType::Edge, e4)) == 2); +} + +TEST_CASE("amips2d_values") +{ + SECTION("equilateral_triangle") + { + const DEBUG_TriMesh example_mesh = single_equilateral_triangle(2); + + auto uv_handle = + example_mesh.get_attribute_handle("position", PrimitiveType::Vertex); + auto e1 = example_mesh.edge_tuple_between_v1_v2(0, 1, 0); + const TriMesh tri_mesh = static_cast(example_mesh); + + AMIPS2D amips2d(tri_mesh, uv_handle); + + CHECK(amips2d.get_value(Simplex(PrimitiveType::Face, e1)) == 2.0); + } + SECTION("random_triangle") + { + for (int i = 0; i < 50; i++) { + const DEBUG_TriMesh example_mesh = single_2d_triangle_with_random_positions(123); + + auto uv_handle = + example_mesh.get_attribute_handle("position", PrimitiveType::Vertex); + auto e1 = example_mesh.edge_tuple_between_v1_v2(0, 1, 0); + const TriMesh tri_mesh = static_cast(example_mesh); + + AMIPS2D amips2d(tri_mesh, uv_handle); + CHECK(amips2d.get_value(Simplex(PrimitiveType::Face, e1)) >= 2.); + } + } +} + +TEST_CASE("PositionMapAMIPS_values") +{ + SECTION("equilateral_triangle") + { + const DEBUG_TriMesh example_mesh = single_equilateral_triangle(2); + + auto e1 = example_mesh.edge_tuple_between_v1_v2(0, 1, 0); + const TriMesh tri_mesh = static_cast(example_mesh); + auto uv_handle = + example_mesh.get_attribute_handle("position", PrimitiveType::Vertex); + + PositionMapAMIPS2D amips3d( + tri_mesh, + uv_handle, + wmtk::image::SamplingAnalyticFunction::FunctionType::Linear, + 0.0, + 0.0, + 1.0); + + CHECK(amips3d.get_value(Simplex(PrimitiveType::Face, e1)) == 2.0); + } + SECTION("random_triangle") + { + for (int i = 0; i < 50; i++) { + const DEBUG_TriMesh example_mesh = single_2d_triangle_with_random_positions(123); + + auto uv_handle = + example_mesh.get_attribute_handle("position", PrimitiveType::Vertex); + auto e1 = example_mesh.edge_tuple_between_v1_v2(0, 1, 0); + const TriMesh tri_mesh = static_cast(example_mesh); + + PositionMapAMIPS2D amips3d( + tri_mesh, + uv_handle, + wmtk::image::SamplingAnalyticFunction::FunctionType::Linear, + 0.0, + 0.0, + 1.0); + + CHECK(amips3d.get_value(Simplex(PrimitiveType::Face, e1)) >= 2.0); + } + } +} diff --git a/tests/function/test_amips.cpp b/tests/function/test_amips.cpp new file mode 100644 index 0000000000..dd23789d5d --- /dev/null +++ b/tests/function/test_amips.cpp @@ -0,0 +1,30 @@ +#include +#include +#include + +TEST_CASE("amips2d") +{ + SECTION("equilateral_triangle") + { + Eigen::Vector2d uv0 = {0,0}; + Eigen::Vector2d uv1 = {1,0}; + Eigen::Vector2d uv2 = {0.5,std::sqrt(3)/2}; + + + std::cout << uv0 << std::endl; + + CHECK(wmtk::function::utils::amips(uv0,uv1,uv2) == 2.0); + } +} + +TEST_CASE("amips3d") +{ + SECTION("equilateral_triangle") + { + Eigen::Vector2d uv0 = {0,0}; + Eigen::Vector2d uv1 = {1,0}; + Eigen::Vector2d uv2 = {0.5,std::sqrt(3)/2}; + + CHECK(wmtk::function::utils::amips(uv0,uv1,uv2) == 2.0); + } +} diff --git a/tests/test_2d_operations.cpp b/tests/test_2d_operations.cpp index 5c423fc7c6..0d58999057 100644 --- a/tests/test_2d_operations.cpp +++ b/tests/test_2d_operations.cpp @@ -1318,7 +1318,7 @@ TEST_CASE("split_face", "[operations][split][2D]") // V.row(0) << 0, 0, 0; // V.row(1) << 1, 0, 0; // V.row(2) << 0.5, 0.866, 0; - DEBUG_TriMesh m = single_triangle_with_position(); + DEBUG_TriMesh m = single_equilateral_triangle(3); Tuple f = m.edge_tuple_between_v1_v2(1, 2, 0); OperationSettings settings; settings.initialize_invariants(m); @@ -1387,7 +1387,7 @@ TEST_CASE("split_face", "[operations][split][2D]") // V.row(0) << 0, 0, 0; // V.row(1) << 1, 0, 0; // V.row(2) << 0.5, 0.866, 0; - DEBUG_TriMesh m = single_triangle_with_position(); + DEBUG_TriMesh m = single_equilateral_triangle(3); Tuple f = m.edge_tuple_between_v1_v2(1, 2, 0); MeshAttributeHandle pos_handle = m.get_attribute_handle("position", PV); MeshAttributeHandle todo_handle = m.register_attribute("todo_face", PF, 1); @@ -1422,7 +1422,7 @@ TEST_CASE("split_face", "[operations][split][2D]") } SECTION("should fail with todo tag 0") { - DEBUG_TriMesh m = single_triangle_with_position(); + DEBUG_TriMesh m = single_equilateral_triangle(3); Tuple f = m.edge_tuple_between_v1_v2(1, 2, 0); MeshAttributeHandle pos_handle = m.get_attribute_handle("position", PV); MeshAttributeHandle todo_handle = m.register_attribute("todo_face", PF, 1); @@ -1585,4 +1585,4 @@ TEST_CASE("split_edge_operation_with_tag", "[operations][split][2D]") } CHECK(success_num == 2); } -} \ No newline at end of file +} diff --git a/tests/test_autodiff.cpp b/tests/test_autodiff.cpp new file mode 100644 index 0000000000..8af524730a --- /dev/null +++ b/tests/test_autodiff.cpp @@ -0,0 +1,23 @@ + +#include +#include +#include +#include +#include + + +TEST_CASE("analytic_autodiff", "[autodiff]") +{ + auto raii = wmtk::function::utils::AutoDiffRAII(2); + REQUIRE(DiffScalarBase::getVariableCount() == 2); + Eigen::Vector2d x{2, 3}; + using DScalar = DScalar2; + auto xD = wmtk::function::utils::as_DScalar(x); + + auto v = xD.x() * xD.y() * xD.y(); + + CHECK(v.getValue() == 18); + + Eigen::Vector2d grad{9., 12.}; + CHECK(v.getGradient() == grad); +} diff --git a/tests/test_execution.cpp b/tests/test_execution.cpp index 79fd7997fc..2ee0287bac 100644 --- a/tests/test_execution.cpp +++ b/tests/test_execution.cpp @@ -54,7 +54,7 @@ TEST_CASE("operation_with_settings", "[scheduler][operations][2D]") operations::OperationSettings op_settings; op_settings.position = m.get_attribute_handle("position", PrimitiveType::Vertex); - op_settings.base_settings.initialize_invariants(m); + op_settings.initialize_invariants(m); Scheduler scheduler(m); scheduler.add_operation_type("vertex_smooth", op_settings); @@ -74,14 +74,14 @@ TEST_CASE("scheduler_success_report", "[scheduler][operations][2D]") operations::OperationSettings op_settings; SECTION("single_triangle_with_boundary") { - m = single_triangle_with_position(); + m = single_equilateral_triangle(); expected_op_success = 1; expected_op_fail = 2; op_settings.smooth_boundary = true; } SECTION("single_triangle_without_boundary") { - m = single_triangle_with_position(); + m = single_equilateral_triangle(); expected_op_success = 0; expected_op_fail = 3; op_settings.smooth_boundary = false; @@ -96,7 +96,7 @@ TEST_CASE("scheduler_success_report", "[scheduler][operations][2D]") const long expected_op_sum = expected_op_success + expected_op_fail; op_settings.position = m.get_attribute_handle("position", PrimitiveType::Vertex); - op_settings.base_settings.initialize_invariants(m); + op_settings.initialize_invariants(m); Scheduler scheduler(m); scheduler.add_operation_type("vertex_smooth", op_settings); @@ -109,11 +109,11 @@ TEST_CASE("scheduler_success_report", "[scheduler][operations][2D]") } SECTION("multiple_runs") { - DEBUG_TriMesh m = single_triangle_with_position(); + DEBUG_TriMesh m = single_equilateral_triangle(); operations::OperationSettings op_settings; op_settings.smooth_boundary = true; op_settings.position = m.get_attribute_handle("position", PrimitiveType::Vertex); - op_settings.base_settings.initialize_invariants(m); + op_settings.initialize_invariants(m); Scheduler scheduler(m); scheduler.add_operation_type("vertex_smooth", op_settings); @@ -125,4 +125,4 @@ TEST_CASE("scheduler_success_report", "[scheduler][operations][2D]") CHECK(scheduler.number_of_failed_operations() == 2); } } -} \ No newline at end of file +} diff --git a/tests/test_io.cpp b/tests/test_io.cpp index 0b59fad1f4..9582f89e6a 100644 --- a/tests/test_io.cpp +++ b/tests/test_io.cpp @@ -99,7 +99,7 @@ TEST_CASE("paraview_3d", "[io]") TEST_CASE("attribute_after_split", "[io]") { - DEBUG_TriMesh m = single_triangle_with_position(); + DEBUG_TriMesh m = single_equilateral_triangle(); wmtk::MeshAttributeHandle attribute_handle = m.register_attribute(std::string("test_attribute"), PE, 1); wmtk::MeshAttributeHandle pos_handle = diff --git a/tests/test_primitive.cpp b/tests/test_primitive.cpp new file mode 100644 index 0000000000..eb464f0c47 --- /dev/null +++ b/tests/test_primitive.cpp @@ -0,0 +1,109 @@ +#include +#include +#include + +using namespace wmtk; +TEST_CASE("primitive_range", "[primitive]") +{ + for (PrimitiveType pt : + {PrimitiveType::Vertex, + PrimitiveType::Edge, + PrimitiveType::Face, + PrimitiveType::Tetrahedron}) { + { + auto a = wmtk::utils::primitive_range(pt, PrimitiveType::Tetrahedron); + auto b = wmtk::utils::primitive_above(pt); + CHECK(a == b); + } + { + auto a = wmtk::utils::primitive_range(PrimitiveType::Vertex, pt); + auto b = wmtk::utils::primitive_below(pt); + std::reverse(b.begin(), b.end()); + CHECK(a == b); + } + } + // 1,1 + // 1,2 + // 2,2 + { + auto a = wmtk::utils::primitive_range(PrimitiveType::Edge, PrimitiveType::Edge); + std::vector b{PrimitiveType::Edge}; + CHECK(a == b); + } + { + auto a = wmtk::utils::primitive_range(PrimitiveType::Face, PrimitiveType::Face); + std::vector b{PrimitiveType::Face}; + CHECK(a == b); + } + { + auto a = wmtk::utils::primitive_range(PrimitiveType::Edge, PrimitiveType::Face); + std::vector b{PrimitiveType::Edge, PrimitiveType::Face}; + CHECK(a == b); + } +} +TEST_CASE("primitive_above", "[primitive]") +{ + { + auto a = wmtk::utils::primitive_above(PrimitiveType::Tetrahedron); + std::vector b{PrimitiveType::Tetrahedron}; + CHECK(a == b); + } + { + auto a = wmtk::utils::primitive_above(PrimitiveType::Face); + std::vector b{PrimitiveType::Face, PrimitiveType::Tetrahedron}; + CHECK(a == b); + } + { + auto a = wmtk::utils::primitive_above(PrimitiveType::Edge); + std::vector b{ + PrimitiveType::Edge, + PrimitiveType::Face, + PrimitiveType::Tetrahedron, + }; + CHECK(a == b); + } + { + auto a = wmtk::utils::primitive_above(PrimitiveType::Vertex); + std::vector b{ + PrimitiveType::Vertex, + PrimitiveType::Edge, + PrimitiveType::Face, + PrimitiveType::Tetrahedron}; + CHECK(a == b); + } +} +TEST_CASE("primitive_below", "[primitive]") +{ + { + auto a = wmtk::utils::primitive_below(PrimitiveType::Tetrahedron); + std::vector b{ + PrimitiveType::Tetrahedron, + PrimitiveType::Face, + PrimitiveType::Edge, + PrimitiveType::Vertex, + }; + CHECK(a == b); + } + { + auto a = wmtk::utils::primitive_below(PrimitiveType::Face); + std::vector b{ + PrimitiveType::Face, + PrimitiveType::Edge, + PrimitiveType::Vertex, + }; + CHECK(a == b); + } + { + auto a = wmtk::utils::primitive_below(PrimitiveType::Edge); + std::vector b{ + PrimitiveType::Edge, + PrimitiveType::Vertex, + }; + CHECK(a == b); + } + { + auto a = wmtk::utils::primitive_below(PrimitiveType::Vertex); + std::vector b{PrimitiveType::Vertex}; + CHECK(a == b); + } +} diff --git a/tests/test_simplex_collection.cpp b/tests/test_simplex_collection.cpp index e3ae4774ba..0787434c65 100644 --- a/tests/test_simplex_collection.cpp +++ b/tests/test_simplex_collection.cpp @@ -5,14 +5,16 @@ #include #include #include +#include +#include +#include #include #include #include #include -#include -#include #include #include +#include #include "tools/DEBUG_TriMesh.hpp" #include "tools/TriMesh_examples.hpp" @@ -537,9 +539,7 @@ TEST_CASE("simplex_closed_star", "[simplex_collection][2D]") for (size_t i = 7; i < 19; ++i) { const Simplex& e = simplices[i]; const Tuple center = m.switch_vertex(m.next_edge(e.tuple())); - CHECK( - (faces(m, e).contains(v) || - m.simplices_are_equal(v, Simplex::vertex(center)))); + CHECK((faces(m, e).contains(v) || m.simplices_are_equal(v, Simplex::vertex(center)))); } CHECK(m.id(simplices[19]) == 0); @@ -570,9 +570,7 @@ TEST_CASE("simplex_closed_star", "[simplex_collection][2D]") for (size_t i = 4; i < 9; ++i) { const Simplex& e = simplices[i]; const Tuple center = m.switch_vertex(m.next_edge(e.tuple())); - CHECK( - (faces(m, e).contains(v) || - m.simplices_are_equal(v, Simplex::vertex(center)))); + CHECK((faces(m, e).contains(v) || m.simplices_are_equal(v, Simplex::vertex(center)))); } CHECK(m.id(simplices[9]) == 0); @@ -855,3 +853,62 @@ TEST_CASE("simplex_link_iterable", "[simplex_collection][2D]") CHECK(m.simplices_are_equal(itrb_collection.simplex_vector()[i], coll.simplex_vector()[i])); } } + + +TEST_CASE("simplex_cofaces_single_dimension", "[simplex_collection][2D]") +{ + tests::DEBUG_TriMesh m = tests::hex_plus_two(); + + SECTION("vertex_interior") + { + const Tuple t = m.edge_tuple_between_v1_v2(4, 5, 2); + const simplex::Simplex input = simplex::Simplex::vertex(t); + std::vector tc = cofaces_single_dimension_tuples(m, input, PrimitiveType::Edge); + REQUIRE(tc.size() == 6); + + SimplexCollection sc( + m, + simplex::utils::tuple_vector_to_homogeneous_simplex_vector(tc, PrimitiveType::Face)); + sc.sort(); + const auto& cells = sc.simplex_vector(); + std::set target_vids({0, 3, 1, 5, 7, 8}); + std::set vids; + std::transform( + cells.begin(), + cells.end(), + std::inserter(vids, vids.end()), + [&](const Simplex& s) { + return m.id(m.switch_vertex(s.tuple()), PrimitiveType::Vertex); + }); + + CHECK(target_vids == vids); + + // check the lower dimension coface is the same as input + for (const Tuple& tup : tc) { + CHECK(m.id(tup, PrimitiveType::Vertex) == m.id(t, PrimitiveType::Vertex)); + } + } + + SECTION("vertex_boundary") + { + const Tuple t = m.edge_tuple_between_v1_v2(3, 4, 0); + const simplex::Simplex input = simplex::Simplex::vertex(t); + std::vector tc = cofaces_single_dimension_tuples(m, input, PrimitiveType::Edge); + REQUIRE(tc.size() == 3); + SimplexCollection sc( + m, + simplex::utils::tuple_vector_to_homogeneous_simplex_vector(tc, PrimitiveType::Face)); + sc.sort(); + + const auto& cells = sc.simplex_vector(); + + // check the lower dimension coface is the same as input + for (const Tuple& tup : tc) { + CHECK(m.id(tup, PrimitiveType::Vertex) == m.id(t, PrimitiveType::Vertex)); + } + + CHECK(m.id(m.switch_vertex(cells[0].tuple()), PrimitiveType::Vertex) == 0); + CHECK(m.id(m.switch_vertex(cells[1].tuple()), PrimitiveType::Vertex) == 4); + CHECK(m.id(m.switch_vertex(cells[2].tuple()), PrimitiveType::Vertex) == 7); + } +} diff --git a/tests/tools/TriMesh_examples.cpp b/tests/tools/TriMesh_examples.cpp index aa146d962b..036f7e29dc 100644 --- a/tests/tools/TriMesh_examples.cpp +++ b/tests/tools/TriMesh_examples.cpp @@ -1,5 +1,8 @@ #include "TriMesh_examples.hpp" +#include +#include #include +#include namespace wmtk::tests { @@ -13,19 +16,60 @@ TriMesh single_triangle() return m; } -TriMesh single_triangle_with_position() +TriMesh single_equilateral_triangle(int dimension) { + assert(dimension == 2 || dimension == 3); TriMesh m = single_triangle(); + Eigen::Matrix V; + + V.row(0) << 0., 0., 0; + V.row(1) << 1., 0, 0; + V.row(2) << 0.5, sqrt(3) / 2., 0; + +#if !defined(NDEBUG) + auto xt = V.row(0); + auto yt = V.row(1); + auto zt = V.row(2); + auto xth = xt.head<2>(); + auto yth = yt.head<2>(); + auto zth = zt.head<2>(); + auto x = xth.transpose(); + auto y = yth.transpose(); + auto z = zth.transpose(); + assert(wmtk::utils::triangle_signed_2d_area(x, y, z) >= 0); +#endif + + auto V2 = V.leftCols(dimension).eval(); + mesh_utils::set_matrix_attribute(V2, "position", PrimitiveType::Vertex, m); + return m; +} + +TriMesh single_2d_triangle_with_random_positions(size_t seed) +{ + TriMesh m = single_triangle(); + Eigen::Matrix V; + + std::mt19937 generator(seed); + std::uniform_real_distribution distribution(0., 1.); + + auto xt = V.row(0); + auto yt = V.row(1); + auto zt = V.row(2); + + auto x = xt.transpose(); + auto y = yt.transpose(); + auto z = zt.transpose(); + auto gen = [&](int, int) { return distribution(generator); }; + do { + V = Eigen::MatrixXd::NullaryExpr(V.rows(), V.cols(), gen); + } while (wmtk::utils::triangle_signed_2d_area(x, y, z) <= 0); + - Eigen::MatrixXd V; - V.resize(3, 3); - V.row(0) << 0, 0, 0; - V.row(1) << 1, 0, 0; - V.row(2) << 0.5, 0.866, 0; mesh_utils::set_matrix_attribute(V, "position", PrimitiveType::Vertex, m); return m; } + TriMesh quad() { TriMesh m; @@ -134,9 +178,9 @@ TriMesh interior_edge() TriMesh hex_plus_two() { // 0---1---2 - // / \ / \ / \ . + // /0\1/2\3/4\ . // 3---4---5---6 - // \ / \ / . + // \5/6\7/ . // 7---8 TriMesh m; RowVectors3l tris; @@ -309,6 +353,44 @@ TriMesh nine_triangles_with_a_hole() m.initialize(tris); return m; } + +TriMesh ten_triangles_with_position(int dimension) +{ + TriMesh m; + RowVectors3l tris; + tris.resize(10, 3); + tris.row(0) << 0, 1, 2; + tris.row(1) << 0, 2, 3; + tris.row(2) << 1, 4, 2; + tris.row(3) << 1, 6, 4; + tris.row(4) << 6, 7, 4; + tris.row(5) << 4, 7, 5; + tris.row(6) << 7, 8, 5; + tris.row(7) << 5, 8, 3; + tris.row(8) << 5, 3, 2; + tris.row(9) << 2, 4, 5; + m.initialize(tris); + + Eigen::MatrixXd V; + V.resize(9, 3); + V.row(0) << 0, 1, 0; + V.row(1) << -1, 0, 0; + V.row(2) << 0, 0, 0; + V.row(3) << 1, 0, 0; + V.row(4) << -0.8, -0.3, 0; + V.row(5) << 1, -1, 0; + V.row(6) << -3, -3, 0; + V.row(7) << 0, -3, 0; + V.row(8) << 1.5, -2, 0; + + if (dimension != 2 && dimension != 3) assert(false); + + V.conservativeResize(9, dimension); + + mesh_utils::set_matrix_attribute(V, "position", PrimitiveType::Vertex, m); + return m; +} + TriMesh three_individuals() { TriMesh m; diff --git a/tests/tools/TriMesh_examples.hpp b/tests/tools/TriMesh_examples.hpp index 45fc515080..88b6b6b3f1 100644 --- a/tests/tools/TriMesh_examples.hpp +++ b/tests/tools/TriMesh_examples.hpp @@ -14,8 +14,10 @@ namespace wmtk::tests { // TriMesh single_triangle(); -TriMesh single_triangle_with_position(); +TriMesh single_equilateral_triangle(int dimension = 3); +// a single triangle with position +TriMesh single_2d_triangle_with_random_positions(size_t seed = 123); // 3--1--- 0 // | / \ . // 2 f1 /2 1 @@ -137,6 +139,8 @@ TriMesh three_triangles_with_two_components(); // ⠀⠀⠀⠀⠀⠀⠀⠉⠉⠉⠉⠉⠙⠿⠋⠛7⠒⠒⠒⠒⠒⠒⠒⠒⠒⠒⠒⠒⠢⠤⠼⢦⡇⠀8⠀ TriMesh nine_triangles_with_a_hole(); +TriMesh ten_triangles_with_position(int dimension); + TriMesh edge_region_with_position(); // 1---2