-
-
Notifications
You must be signed in to change notification settings - Fork 188
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #205 from stan-dev/feature/issue-204-unit_transform
Fixes #204. Feature/issue 204 unit transform
- Loading branch information
Showing
12 changed files
with
399 additions
and
93 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
#ifndef STAN_MATH_FWD_MAT_FUN_UNIT_VECTOR_CONSTRAIN_HPP | ||
#define STAN_MATH_FWD_MAT_FUN_UNIT_VECTOR_CONSTRAIN_HPP | ||
|
||
#include <stan/math/fwd/core.hpp> | ||
#include <stan/math/fwd/mat/fun/divide.hpp> | ||
#include <stan/math/fwd/mat/fun/dot_self.hpp> | ||
#include <stan/math/fwd/mat/fun/tcrossprod.hpp> | ||
#include <stan/math/fwd/scal/fun/sqrt.hpp> | ||
#include <stan/math/prim/mat/fun/divide.hpp> | ||
#include <stan/math/prim/mat/fun/Eigen.hpp> | ||
#include <stan/math/prim/mat/fun/tcrossprod.hpp> | ||
#include <stan/math/prim/mat/fun/unit_vector_constrain.hpp> | ||
#include <stan/math/prim/scal/fun/inv.hpp> | ||
|
||
namespace stan { | ||
namespace math { | ||
|
||
template <typename T, int R, int C> | ||
inline Eigen::Matrix<fvar<T>, R, C> | ||
unit_vector_constrain(const Eigen::Matrix<fvar<T>, R, C>& y) { | ||
using std::sqrt; | ||
using Eigen::Matrix; | ||
|
||
Matrix<T, R, C> y_t(y.size()); | ||
for (int k = 0; k < y.size(); ++k) | ||
y_t.coeffRef(k) = y.coeff(k).val_; | ||
|
||
Matrix<T, R, C> unit_vector_y_t | ||
= unit_vector_constrain(y_t); | ||
Matrix<fvar<T>, R, C> unit_vector_y(y.size()); | ||
for (int k = 0; k < y.size(); ++k) | ||
unit_vector_y.coeffRef(k).val_ = unit_vector_y_t.coeff(k); | ||
|
||
const T squared_norm = dot_self(y_t); | ||
const T norm = sqrt(squared_norm); | ||
const T inv_norm = inv(norm); | ||
Matrix<T, Eigen::Dynamic, Eigen::Dynamic> J | ||
= divide(tcrossprod(y_t), -norm * squared_norm); | ||
|
||
// for each input position | ||
for (int m = 0; m < y.size(); ++m) { | ||
J.coeffRef(m, m) += inv_norm; | ||
// for each output position | ||
for (int k = 0; k < y.size(); ++k) { | ||
// chain from input to output | ||
unit_vector_y.coeffRef(k).d_ = J.coeff(k, m); | ||
} | ||
} | ||
return unit_vector_y; | ||
} | ||
|
||
template <typename T, int R, int C> | ||
inline Eigen::Matrix<fvar<T>, R, C> | ||
unit_vector_constrain(const Eigen::Matrix<fvar<T>, R, C>& y, fvar<T>& lp) { | ||
const fvar<T> squared_norm = dot_self(y); | ||
lp -= 0.5 * squared_norm; | ||
return unit_vector_constrain(y); | ||
} | ||
|
||
} | ||
} | ||
#endif |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,122 @@ | ||
#ifndef STAN_MATH_PRIM_MAT_FUN_UNIT_VECTOR_CONSTRAIN_HPP | ||
#define STAN_MATH_PRIM_MAT_FUN_UNIT_VECTOR_CONSTRAIN_HPP | ||
|
||
#include <stan/math/prim/mat/fun/Eigen.hpp> | ||
#include <stan/math/prim/mat/fun/dot_self.hpp> | ||
#include <stan/math/prim/mat/err/check_vector.hpp> | ||
#include <stan/math/prim/scal/err/check_positive_finite.hpp> | ||
#include <stan/math/prim/scal/err/check_nonzero_size.hpp> | ||
#include <stan/math/rev/core.hpp> | ||
#include <stan/math/rev/mat/fun/dot_self.hpp> | ||
#include <cmath> | ||
|
||
namespace stan { | ||
namespace math { | ||
|
||
namespace { | ||
class unit_vector_elt_vari : public vari { | ||
private: | ||
vari** y_; | ||
const double* unit_vector_y_; | ||
const int size_; | ||
const int idx_; | ||
const double norm_; | ||
|
||
public: | ||
unit_vector_elt_vari(double val, | ||
vari** y, | ||
const double* unit_vector_y, | ||
int size, | ||
int idx, | ||
const double norm) | ||
: vari(val), | ||
y_(y), | ||
unit_vector_y_(unit_vector_y), | ||
size_(size), | ||
idx_(idx), | ||
norm_(norm) { | ||
} | ||
void chain() { | ||
const double cubed_norm = norm_ * norm_ * norm_; | ||
for (int m = 0; m < size_; ++m) { | ||
y_[m]->adj_ | ||
-= adj_ * unit_vector_y_[m] * unit_vector_y_[idx_] / cubed_norm; | ||
if (m == idx_) | ||
y_[m]->adj_ += adj_ / norm_; | ||
} | ||
} | ||
}; | ||
} | ||
|
||
|
||
// Unit vector | ||
|
||
/** | ||
* Return the unit length vector corresponding to the free vector y. | ||
* See https://en.wikipedia.org/wiki/N-sphere#Generating_random_points | ||
* | ||
* @param y vector of K unrestricted variables | ||
* @return Unit length vector of dimension K | ||
* @tparam T Scalar type. | ||
**/ | ||
template <int R, int C> | ||
Eigen::Matrix<var, R, C> | ||
unit_vector_constrain(const Eigen::Matrix<var, R, C>& y) { | ||
stan::math::check_vector("unit_vector", "y", y); | ||
stan::math::check_nonzero_size("unit_vector", "y", y); | ||
|
||
vari** y_vi_array | ||
= reinterpret_cast<vari**>(ChainableStack::memalloc_ | ||
.alloc(sizeof(vari*) * y.size())); | ||
for (int i = 0; i < y.size(); ++i) | ||
y_vi_array[i] = y.coeff(i).vi_; | ||
|
||
Eigen::VectorXd y_d(y.size()); | ||
for (int i = 0; i < y.size(); ++i) | ||
y_d.coeffRef(i) = y.coeff(i).val(); | ||
|
||
|
||
const double norm = y_d.norm(); | ||
stan::math::check_positive_finite("unit_vector", "norm", norm); | ||
Eigen::VectorXd unit_vector_d = y_d / norm; | ||
|
||
double* unit_vector_y_d_array | ||
= reinterpret_cast<double*>(ChainableStack::memalloc_ | ||
.alloc(sizeof(double) * y_d.size())); | ||
for (int i = 0; i < y_d.size(); ++i) | ||
unit_vector_y_d_array[i] = unit_vector_d.coeff(i); | ||
|
||
Eigen::Matrix<var, R, C> unit_vector_y(y.size()); | ||
for (int k = 0; k < y.size(); ++k) | ||
unit_vector_y.coeffRef(k) | ||
= var(new unit_vector_elt_vari(unit_vector_d[k], | ||
y_vi_array, | ||
unit_vector_y_d_array, | ||
y.size(), | ||
k, | ||
norm)); | ||
return unit_vector_y; | ||
} | ||
|
||
/** | ||
* Return the unit length vector corresponding to the free vector y. | ||
* See https://en.wikipedia.org/wiki/N-sphere#Generating_random_points | ||
* | ||
* @param y vector of K unrestricted variables | ||
* @return Unit length vector of dimension K | ||
* @param lp Log probability reference to increment. | ||
* @tparam T Scalar type. | ||
**/ | ||
template <int R, int C> | ||
Eigen::Matrix<var, R, C> | ||
unit_vector_constrain(const Eigen::Matrix<var, R, C>& y, var &lp) { | ||
Eigen::Matrix<var, R, C> x = unit_vector_constrain(y); | ||
lp -= 0.5 * stan::math::dot_self(y); | ||
return x; | ||
} | ||
|
||
} | ||
|
||
} | ||
|
||
#endif |
Oops, something went wrong.