diff --git a/src/orca-jedi/increment/Increment.cc b/src/orca-jedi/increment/Increment.cc index 4484b6d..8f13bf8 100644 --- a/src/orca-jedi/increment/Increment.cc +++ b/src/orca-jedi/increment/Increment.cc @@ -11,6 +11,7 @@ #include "atlas/array/MakeView.h" #include "atlas/field/Field.h" #include "atlas/field/FieldSet.h" +#include "atlas/field/MissingValue.h" #include "atlas/functionspace/StructuredColumns.h" #include "eckit/config/LocalConfiguration.h" @@ -25,11 +26,14 @@ #include "ufo/GeoVaLs.h" -#include "orca-jedi/errorcovariance/ErrorCovariance.h" #include "orca-jedi/geometry/Geometry.h" #include "orca-jedi/state/State.h" +#include "orca-jedi/state/StateIOUtils.h" #include "orca-jedi/increment/Increment.h" +#include "atlas/mesh.h" +#include "atlas-orca/grid/OrcaGrid.h" + namespace orcamodel { // ----------------------------------------------------------------------------- @@ -38,14 +42,24 @@ namespace orcamodel { Increment::Increment(const Geometry & geom, const oops::Variables & vars, const util::DateTime & time) - : geom_(new Geometry(geom)), vars_(), time_(time), + : geom_(new Geometry(geom)), vars_(vars), time_(time), incrementFields_() { - std::string err_message = - "orcamodel::Increment::constructor not implemented"; - throw eckit::NotImplemented(err_message, Here()); + if (geom_->getComm().size() != 1) { + throw eckit::NotImplemented("orcamodel::Increment::Increment: Cannot construct" + " an Increment with more than one MPI process."); + } + + incrementFields_ = atlas::FieldSet(); + + setupIncrementFields(); + + this->zero(); // may not be needed + + oops::Log::debug() << "Increment(ORCA)::Increment created for "<< validTime() + << std::endl; } -// ----------------------------------------------------------------------------- + Increment::Increment(const Geometry & geom, const Increment & other) : geom_(new Geometry(geom)), vars_(other.vars_), time_(other.time_), @@ -55,24 +69,511 @@ Increment::Increment(const Geometry & geom, "orcamodel::Increment::constructor(geom, other) not implemented"; throw eckit::NotImplemented(err_message, Here()); } -// ----------------------------------------------------------------------------- + +/// \brief Copy constructor. +/// \param other Increment to copy structure from. +/// \param copy Boolean flag copy contents if true. Increment::Increment(const Increment & other, const bool copy) : geom_(other.geom_), vars_(other.vars_), time_(other.time_), incrementFields_() { + oops::Log::debug() << "Increment(ORCA)::Increment copy " << copy << std::endl; + + incrementFields_ = atlas::FieldSet(); + + setupIncrementFields(); + + if (copy) { + for (size_t i=0; i < vars_.size(); ++i) { + // copy variable from _Fields to new field set + atlas::Field field = other.incrementFields_[i]; + oops::Log::debug() << "Copying increment field " << field.name() << std::endl; + incrementFields_->add(field); + } + } + + oops::Log::debug() << "Increment(ORCA)::Increment copied." << std::endl; + + oops::Log::debug() << "increment copy self print"; + print(oops::Log::debug()); + oops::Log::debug() << "increment copy other print"; + other.print(oops::Log::debug()); +} + +// Basic operators +Increment & Increment::operator=(const Increment & rhs) { + time_ = rhs.time_; + incrementFields_ = rhs.incrementFields_; + vars_ = rhs.vars_; + geom_.reset(); + geom_ = rhs.geom_; + + oops::Log::debug() << "Increment(ORCA)::= copy ended" << std::endl; + return *this; +} + +Increment & Increment::operator+=(const Increment & dx) { + ASSERT(this->validTime() == dx.validTime()); + + oops::Log::debug() << "increment add self print"; + print(oops::Log::debug()); + oops::Log::debug() << "increment add dx print"; + dx.print(oops::Log::debug()); + + auto ghost = atlas::array::make_view( + geom_->mesh().nodes().ghost()); + for (int i = 0; i< incrementFields_.size(); i++) + { + atlas::Field field = incrementFields_[i]; + atlas::Field field_dx = dx.incrementFields_[i]; + std::string fieldName = field.name(); + std::string fieldName_dx = field_dx.name(); + oops::Log::debug() << "orcamodel::Increment::add:: field name = " << fieldName + << " field name dx = " << fieldName_dx + << std::endl; + auto field_view = atlas::array::make_view(field); + auto field_view_dx = atlas::array::make_view(field_dx); + for (atlas::idx_t j = 0; j < field_view.shape(0); ++j) { + for (atlas::idx_t k = 0; k < field_view.shape(1); ++k) { + if (!ghost(j)) field_view(j, k) += field_view_dx(j, k); + } + } + } + + oops::Log::debug() << "increment add self print"; + print(oops::Log::debug()); + oops::Log::debug() << "increment add dx print"; + dx.print(oops::Log::debug()); + + oops::Log::debug() << "Increment(ORCA)::+ add ended" << std::endl; + return *this; +} + +Increment & Increment::operator-=(const Increment & dx) { + ASSERT(this->validTime() == dx.validTime()); + + auto ghost = atlas::array::make_view( + geom_->mesh().nodes().ghost()); + for (int i = 0; i< incrementFields_.size(); i++) + { + atlas::Field field = incrementFields_[i]; + atlas::Field field_dx = dx.incrementFields_[i]; + std::string fieldName = field.name(); + std::string fieldName_dx = field_dx.name(); + oops::Log::debug() << "orcamodel::Increment::subtract:: field name = " << fieldName + << " field name dx = " << fieldName_dx + << std::endl; + auto field_view = atlas::array::make_view(field); + auto field_view_dx = atlas::array::make_view(field_dx); + for (atlas::idx_t j = 0; j < field_view.shape(0); ++j) { + for (atlas::idx_t k = 0; k < field_view.shape(1); ++k) { + if (!ghost(j)) field_view(j, k) -= field_view_dx(j, k); + } + } + } + + oops::Log::debug() << "Increment(ORCA)::- subtract ended" << std::endl; + return *this; +} + +Increment & Increment::operator*=(const double & zz) { + oops::Log::debug() << "orcamodel::Increment:multiply start" << std::endl; + + auto ghost = atlas::array::make_view( + geom_->mesh().nodes().ghost()); + for (atlas::Field field : incrementFields_) { + std::string fieldName = field.name(); + oops::Log::debug() << "orcamodel::Increment::multiply:: field name = " << fieldName + << " zz " << zz + << std::endl; + auto field_view = atlas::array::make_view(field); + for (atlas::idx_t j = 0; j < field_view.shape(0); ++j) { + for (atlas::idx_t k = 0; k < field_view.shape(1); ++k) { + if (!ghost(j)) field_view(j, k) *= zz; + } + } + } + + oops::Log::debug() << "Increment(ORCA)::* multiplication ended" << std::endl; + return *this; +} + +/// \brief Create increment from the different of two state objects. +/// \param x1 State object. +/// \param x2 State object subtracted. +void Increment::diff(const State & x1, const State & x2) { + ASSERT(this->validTime() == x1.validTime()); + ASSERT(this->validTime() == x2.validTime()); + + auto ghost = atlas::array::make_view( + geom_->mesh().nodes().ghost()); + for (int i = 0; i< incrementFields_.size(); i++) + { + atlas::Field field1 = x1.getField(i); + atlas::Field field2 = x2.getField(i); + atlas::Field fieldi = incrementFields_[i]; + + std::string fieldName1 = field1.name(); + std::string fieldName2 = field2.name(); + std::string fieldNamei = fieldi.name(); + oops::Log::debug() << "orcamodel::Increment::diff:: field name 1 = " << fieldName1 + << " field name 2 = " << fieldName2 + << " field name inc = " << fieldNamei + << std::endl; + auto field_view1 = atlas::array::make_view(field1); + auto field_view2 = atlas::array::make_view(field2); + auto field_viewi = atlas::array::make_view(fieldi); + for (atlas::idx_t j = 0; j < field_viewi.shape(0); ++j) { + for (atlas::idx_t k = 0; k < field_viewi.shape(1); ++k) { + if (!ghost(j)) { field_viewi(j, k) = field_view1(j, k) - field_view2(j, k); + } else { field_viewi(j, k) = 0; } + } + } + } +} + +/// \brief Set increment fields to a uniform value. +/// \param val Value to use. +void Increment::setval(const double & val) { + oops::Log::trace() << "Increment(ORCA)::setval starting" << std::endl; + + auto ghost = atlas::array::make_view( + geom_->mesh().nodes().ghost()); + for (atlas::Field field : incrementFields_) { + std::string fieldName = field.name(); + oops::Log::debug() << "orcamodel::Increment::setval:: field name = " << fieldName + << "value " << val + << std::endl; + + atlas::field::MissingValue mv(incrementFields()[fieldName]); + bool has_mv = static_cast(mv); + + auto field_view = atlas::array::make_view(field); + for (atlas::idx_t j = 0; j < field_view.shape(0); ++j) { + for (atlas::idx_t k = 0; k < field_view.shape(1); ++k) { + if (!ghost(j)) { + if (!has_mv || (has_mv && !mv(field_view(j, k)))) { + field_view(j, k) = val; + } + } + } + } + } + + oops::Log::trace() << "Increment(ORCA)::setval done" << std::endl; +} + +void Increment::zero() { + oops::Log::trace() << "Increment(ORCA)::zero starting" << std::endl; + this->setval(0); + oops::Log::trace() << "Increment(ORCA)::zero done" << std::endl; +} + +void Increment::ones() { + oops::Log::trace() << "Increment(ORCA)::ones starting" << std::endl; + this->setval(1); + oops::Log::trace() << "Increment(ORCA)::ones done" << std::endl; +} + +void Increment::zero(const util::DateTime & vt) { + time_ = vt; + oops::Log::debug() << "orcamodel::Increment::zero at time " << vt << std::endl; + // NB currently no checking of the time just zeros everything + this->zero(); +} + +/// \brief multiply input increment object (x) by a scalar (a) and add onto self (y). +/// \param zz Scalar value (a). +/// \param dx Other increment object (x). +/// \param bool check Check (if true) the validity time of the increments fields matches. +void Increment::axpy(const double & zz, const Increment & dx, const bool check) { + ASSERT(!check || this->validTime() == dx.validTime()); + + auto ghost = atlas::array::make_view( + geom_->mesh().nodes().ghost()); + for (int i = 0; i< incrementFields_.size(); i++) + { + atlas::Field field = incrementFields_[i]; + atlas::Field field_dx = dx.incrementFields_[i]; + std::string fieldName = field.name(); + std::string fieldName_dx = field_dx.name(); + oops::Log::debug() << "orcamodel::Increment::axpy:: field name = " << fieldName + << " field name dx = " << fieldName_dx + << std::endl; + auto field_view = atlas::array::make_view(field); + auto field_view_dx = atlas::array::make_view(field_dx); + for (atlas::idx_t j = 0; j < field_view.shape(0); ++j) { + for (atlas::idx_t k = 0; k < field_view.shape(1); ++k) { + if (!ghost(j)) field_view(j, k) += zz * field_view_dx(j, k); + } + } + } +} + +/// \brief Dot product self increment object with another increment object +/// \param dx Other increment object. +double Increment::dot_product_with(const Increment & dx) const { + double zz = 0; + auto ghost = atlas::array::make_view( + geom_->mesh().nodes().ghost()); + // Deals with multiple fields + for (int i = 0; i< incrementFields_.size(); i++) + { + atlas::Field field = incrementFields_[i]; + atlas::Field field_dx = dx.incrementFields_[i]; + std::string fieldName = field.name(); + std::string fieldName_dx = field_dx.name(); + oops::Log::debug() << "orcamodel::Increment::dot_product_with:: field name = " << fieldName + << " field name dx = " << fieldName_dx + << std::endl; + auto field_view = atlas::array::make_view(field); + auto field_view_dx = atlas::array::make_view(field_dx); + for (atlas::idx_t j = 0; j < field_view.shape(0); ++j) { + for (atlas::idx_t k = 0; k < field_view.shape(1); ++k) { + if (!ghost(j)) zz += field_view(j, k) * field_view_dx(j, k); + } + } + } + + oops::Log::debug() << "orcamodel::Increment::dot_product_with ended :: zz = " << zz << std::endl; + + return zz; +} + +/// \brief Schur product self increment object with another increment object +/// \param dx Other increment object. +void Increment::schur_product_with(const Increment & dx) { + auto ghost = atlas::array::make_view( + geom_->mesh().nodes().ghost()); + for (int i = 0; i< incrementFields_.size(); i++) + { + atlas::Field field = incrementFields_[i]; + atlas::Field field_dx = dx.incrementFields_[i]; + std::string fieldName = field.name(); + std::string fieldName_dx = field_dx.name(); + oops::Log::debug() << "orcamodel::Increment::schur_product_with:: field name = " << fieldName + << " field name dx = " << fieldName_dx + << std::endl; + auto field_view = atlas::array::make_view(field); + auto field_view_dx = atlas::array::make_view(field_dx); + for (atlas::idx_t j = 0; j < field_view.shape(0); ++j) { + for (atlas::idx_t k = 0; k < field_view.shape(1); ++k) { + if (!ghost(j)) field_view(j, k) *= field_view_dx(j, k); + } + } + } +} + +/// \brief Initialise with a normally distributed random field with a mean of 0 and s.d. of 1. +void Increment::random() { + oops::Log::debug() << "orcamodel::Increment::random start" << std::endl; + oops::Log::debug() << "orcamodel::Increment::random seed_ " << seed_ << std::endl; + + auto ghost = atlas::array::make_view( + geom_->mesh().nodes().ghost()); + for (atlas::Field field : incrementFields_) { + std::string fieldName = field.name(); + oops::Log::debug() << "orcamodel::Increment::random:: field name = " << fieldName + << std::endl; + auto field_view = atlas::array::make_view(field); + // Seed currently hardwired in increment.h + util::NormalDistribution xx(field_view.shape(0)*field_view.shape(1), 0.0, 1.0, seed_); + int idx = 0; + for (atlas::idx_t j = 0; j < field_view.shape(0); ++j) { + for (atlas::idx_t k = 0; k < field_view.shape(1); ++k) { + if (!ghost(j)) field_view(j, k) = xx[idx]; + idx++; + } + } + } +} + +/// \brief Apply Dirac delta functions to configuration specified points. +void Increment::dirac(const eckit::Configuration & conf) { +// Adding a delta function at points specified by ixdir, iydir, izdir + const std::vector & ixdir = conf.getIntVector("ixdir"); + const std::vector & iydir = conf.getIntVector("iydir"); + const std::vector & izdir = conf.getIntVector("izdir"); + + ASSERT(ixdir.size() == iydir.size() && ixdir.size() == izdir.size()); + int ndir = ixdir.size(); + atlas::OrcaGrid orcaGrid = geom_->mesh().grid(); + int nx = orcaGrid.nx() + orcaGrid.haloWest() + orcaGrid.haloEast(); + std::vector jpt; + for (int i = 0; i < ndir; i++) { + jpt.push_back(iydir[i]*nx + ixdir[i]); + oops::Log::debug() << "orcamodel::Increment::dirac:: delta function " << i + << " at jpt = " << jpt[i] + << " kpt = " << izdir[i] << std::endl; + } + + auto ghost = atlas::array::make_view( + geom_->mesh().nodes().ghost()); + + this->zero(); + + for (atlas::Field field : incrementFields_) { + std::string fieldName = field.name(); + std::cout << "orcamodel::Increment::dirac:: field name = " << fieldName + << std::endl; + + auto field_view = atlas::array::make_view(field); + for (int i = 0; i < ndir; i++) { + if (!ghost(jpt[i])) { + field_view(jpt[i], izdir[i]) = 1; + } + } + } +} + +/// \brief Output increment fieldset as an atlas fieldset. +/// \param fset Atlas fieldset to output to. +void Increment::toFieldSet(atlas::FieldSet & fset) const { + oops::Log::debug() << "Increment toFieldSet starting" << std::endl; + + fset = atlas::FieldSet(); + + for (size_t i=0; i < vars_.size(); ++i) { + // copy variable from increments to new field set + atlas::Field fieldinc = incrementFields_[i]; + std::string fieldName = fieldinc.name(); + oops::Log::debug() << "Copy increment toFieldSet " << fieldName << std::endl; + + fset->add(fieldinc); + } + oops::Log::debug() << "Increment toFieldSet done" << std::endl; +} + +void Increment::toFieldSetAD(const atlas::FieldSet & fset) { + oops::Log::debug() << "Increment toFieldSetAD starting" << std::endl; + std::string err_message = - "orcamodel::Increment::constructor(other, copy) not implemented"; + "orcamodel::Increment::toFieldSetAD not implemented"; throw eckit::NotImplemented(err_message, Here()); + + oops::Log::debug() << "Increment toFieldSetAD done" << std::endl; } -// ----------------------------------------------------------------------------- -Increment::Increment(const Increment & other) - : geom_(other.geom_), vars_(other.vars_), time_(other.time_), - incrementFields_() -{ + +/// \brief Apply atlas fieldset to an increment fieldset. +/// \param fset Atlas fieldset to apply. +void Increment::fromFieldSet(const atlas::FieldSet & fset) { + oops::Log::debug() << "Increment fromFieldSet start" << std::endl; + + for (int i = 0; i< fset.size(); i++) { + atlas::Field field = fset[i]; + atlas::Field fieldinc = incrementFields_[i]; + oops::Log::debug() << "Increment fromFieldSet field " << i << " " << field.name() << std::endl; + oops::Log::debug() << "Increment fromFieldSet fieldinc " << i + << " " << fieldinc.name() << std::endl; + +// copy from field to incrementfields + + auto field_view_to = atlas::array::make_view(fieldinc); + auto field_view_from = atlas::array::make_view(field); + for (atlas::idx_t j = 0; j < field_view_to.shape(0); ++j) { + for (atlas::idx_t k = 0; k < field_view_to.shape(1); ++k) { + field_view_to(j, k) = field_view_from(j, k); + } + } + } + oops::Log::debug() << "Increment fromFieldSet done" << std::endl; +} + +/// \brief Setup variables and geometry for increment fields. +void Increment::setupIncrementFields() { + for (size_t i=0; i < vars_.size(); ++i) { + // add variable if it isn't already in incrementFields + std::vector varSizes = geom_->variableSizes(vars_); + if (!incrementFields_.has(vars_[i])) { + incrementFields_.add(geom_->functionSpace().createField( + atlas::option::name(vars_[i]) | + atlas::option::levels(varSizes[i]))); + oops::Log::trace() << "Increment(ORCA)::setupIncrementFields : " + << vars_[i] << "has dtype: " + << (*(incrementFields_.end()-1)).datatype().str() << std::endl; + geom_->log_status(); + } + } +} + +/// I/O and diagnostics +void Increment::read(const eckit::Configuration & conf) { std::string err_message = - "orcamodel::Increment::copy constructor not implemented"; + "orcamodel::Increment::read not implemented"; throw eckit::NotImplemented(err_message, Here()); } +void Increment::write(const eckit::Configuration & conf) const { + std::string err_message = + "orcamodel::Increment::write not implemented"; + throw eckit::NotImplemented(err_message, Here()); +} + +/// \brief Print some basic information about the self increment object. +void Increment::print(std::ostream & os) const { + oops::Log::trace() << "Increment(ORCA)::print starting" << std::endl; + + os << "Increment valid at time: " << validTime() << std::endl; + os << std::string(4, ' ') << vars_ << std::endl; + os << std::string(4, ' ') << "atlas field:" << std::endl; + for (atlas::Field field : incrementFields_) { + std::string fieldName = field.name(); + struct Increment::stats s = Increment::stats(fieldName); + os << std::string(8, ' ') << fieldName << + " num: " << s.valid_points << + " mean: " << std::setprecision(5) << s.sumx/s.valid_points << + " rms: " << sqrt(s.sumx2/s.valid_points) << + " min: " << s.min << " max: " << s.max << std::endl; + } + oops::Log::trace() << "Increment(ORCA)::print done" << std::endl; +} + +/// \brief Calculate some basic statistics of a field in the increment object. +/// \param fieldName Name of the field to use. +struct Increment::stats Increment::stats(const std::string & fieldName) const { + struct Increment::stats s; + s.valid_points = 0; + s.sumx = 0; + s.sumx2 = 0; + s.min = 1e30; + s.max = -1e30; + + auto field_view = atlas::array::make_view( + incrementFields_[fieldName]); + oops::Log::trace() << "Increment(ORCA):stats" << std::endl; + auto ghost = atlas::array::make_view( + geom_->mesh().nodes().ghost()); + atlas::field::MissingValue mv(incrementFields()[fieldName]); + bool has_mv = static_cast(mv); + for (atlas::idx_t j = 0; j < field_view.shape(0); ++j) { + for (atlas::idx_t k = 0; k < field_view.shape(1); ++k) { + if (!ghost(j)) { + if (!has_mv || (has_mv && !mv(field_view(j, k)))) { + if (field_view(j, k) > s.max) { s.max=field_view(j, k); } + if (field_view(j, k) < s.min) { s.min=field_view(j, k); } + s.sumx += field_view(j, k); + s.sumx2 += field_view(j, k)*field_view(j, k); + ++s.valid_points; + } + } + } + } + return s; +} + +/// \brief Output norm (RMS) of the self increment fields. +double Increment::norm() const { + int valid_points_all = 0; + double sumx2all = 0; + + for (atlas::Field field : incrementFields_) { + std::string fieldName = field.name(); + struct Increment::stats s = Increment::stats(fieldName); + sumx2all += s.sumx2; + valid_points_all += s.valid_points; + } + // return RMS + return sqrt(sumx2all/valid_points_all); +} } // namespace orcamodel diff --git a/src/orca-jedi/increment/Increment.h b/src/orca-jedi/increment/Increment.h index 73e7876..ce2f47d 100644 --- a/src/orca-jedi/increment/Increment.h +++ b/src/orca-jedi/increment/Increment.h @@ -40,7 +40,6 @@ namespace oops { namespace orcamodel { class Geometry; class ModelBiasIncrement; - class ErrorCovariance; class State; /// orcaModel Increment Class: Difference between two states @@ -64,29 +63,43 @@ class Increment : public util::Printable, const oops::Variables &, const util::DateTime &); Increment(const Geometry &, const Increment &); - Increment(const Increment &, const bool); - Increment(const Increment &); + Increment(const Increment &, const bool copy = true); /// Basic operators - void diff(const State &, const State &) {} - void zero() {} - void zero(const util::DateTime &) {} - void ones() {} - // Increment & operator =(const Increment &); - // Increment & operator+=(const Increment &); - // Increment & operator-=(const Increment &); - // Increment & operator*=(const double &); - void axpy(const double &, const Increment &, const bool check = true) {} - double dot_product_with(const Increment &) const {return 0.0; } - void schur_product_with(const Increment &) {} - void random() {} - void dirac(const eckit::Configuration &) {} + void diff(const State &, const State &); + void zero(); + void zero(const util::DateTime &); + void ones(); + Increment & operator =(const Increment &); + Increment & operator+=(const Increment &); + Increment & operator-=(const Increment &); + Increment & operator*=(const double &); + void axpy(const double &, const Increment &, const bool check = true); + double dot_product_with(const Increment &) const; + void schur_product_with(const Increment &); + void random(); + void dirac(const eckit::Configuration &); + +/// ATLAS + void toFieldSet(atlas::FieldSet &) const; + void toFieldSetAD(const atlas::FieldSet &); + void fromFieldSet(const atlas::FieldSet &); /// I/O and diagnostics - void read(const eckit::Configuration &) {} - void write(const eckit::Configuration &) const {} - double norm() const {return 0.0; } - void print(std::ostream & os) const override {os << "Not Implemented";} + + struct stats { + int valid_points; + double sumx; + double sumx2; + double min; + double max; + }; + + void read(const eckit::Configuration &); + void write(const eckit::Configuration &) const; + void print(std::ostream & os) const override; + struct stats stats(const std::string & field_name) const; + double norm() const; void updateTime(const util::Duration & dt) {time_ += dt;} @@ -103,7 +116,6 @@ class Increment : public util::Printable, /// Other void accumul(const double &, const State &) {} - /// Utilities std::shared_ptr geometry() const {return geom_;} @@ -114,17 +126,17 @@ class Increment : public util::Printable, const atlas::FieldSet & incrementFields() const {return incrementFields_;} atlas::FieldSet & incrementFields() {return incrementFields_;} - const oops::Variables & variables() const {return vars_;} - /// Data private: - // void print(std::ostream &) const override; + void setupIncrementFields(); + void setval(const double &); std::shared_ptr geom_; oops::Variables vars_; util::DateTime time_; atlas::FieldSet incrementFields_; + int seed_ = 7; }; // ----------------------------------------------------------------------------- diff --git a/src/orca-jedi/state/State.cc b/src/orca-jedi/state/State.cc index 547208b..29b0c89 100644 --- a/src/orca-jedi/state/State.cc +++ b/src/orca-jedi/state/State.cc @@ -347,4 +347,8 @@ template double State::norm(const std::string & field_name) const { template double State::norm(const std::string & field_name) const; template double State::norm(const std::string & field_name) const; +atlas::Field State::getField(int i) const { + return stateFields_[i]; +} + } // namespace orcamodel diff --git a/src/orca-jedi/state/State.h b/src/orca-jedi/state/State.h index 863f033..5989237 100644 --- a/src/orca-jedi/state/State.h +++ b/src/orca-jedi/state/State.h @@ -105,6 +105,8 @@ class State : public util::Printable, const oops::Variables & variables() const {return vars_;} oops::Variables & variables() {return vars_;} + atlas::Field getField(int) const; + private: void setupStateFields(); void print(std::ostream &) const override; diff --git a/src/tests/orca-jedi/test_increment.cc b/src/tests/orca-jedi/test_increment.cc index 8ad1cae..7a91a25 100644 --- a/src/tests/orca-jedi/test_increment.cc +++ b/src/tests/orca-jedi/test_increment.cc @@ -2,6 +2,8 @@ * (C) British Crown Copyright 2024 Met Office */ +#include + #include "eckit/log/Bytes.h" #include "eckit/config/LocalConfiguration.h" #include "eckit/mpi/Comm.h" @@ -10,6 +12,7 @@ #include "oops/base/Variables.h" +#include "atlas/field/FieldSet.h" #include "atlas/library/Library.h" #include "orca-jedi/increment/Increment.h" @@ -21,22 +24,25 @@ namespace test { //----------------------------------------------------------------------------- -CASE("test create increment") { +CASE("test increment") { EXPECT(eckit::system::Library::exists("atlas-orca")); eckit::LocalConfiguration config; std::vector nemo_var_mappings(4); nemo_var_mappings[0].set("name", "sea_ice_area_fraction") + .set("field precision", "double") .set("nemo field name", "iiceconc") .set("model space", "surface"); nemo_var_mappings[1].set("name", "sea_ice_area_fraction_error") .set("nemo field name", "sic_tot_var") + .set("field precision", "double") .set("model space", "surface"); nemo_var_mappings[2].set("name", "sea_surface_foundation_temperature") .set("nemo field name", "votemper") .set("model space", "surface"); nemo_var_mappings[3].set("name", "sea_water_potential_temperature") .set("nemo field name", "votemper") + .set("field precision", "double") .set("model space", "volume"); config.set("nemo variables", nemo_var_mappings); config.set("grid name", "ORCA2_T"); @@ -44,14 +50,117 @@ CASE("test create increment") { Geometry geometry(config, eckit::mpi::comm()); const std::vector channels{}; - std::vector varnames {"sea_ice_area_fraction", + std::vector varnames2 {"sea_ice_area_fraction", "sea_water_potential_temperature"}; + oops::Variables oops_vars2(varnames2, channels); + + std::vector varnames {"sea_ice_area_fraction"}; oops::Variables oops_vars(varnames, channels); - util::DateTime datetime; + util::DateTime datetime("2021-06-30T00:00:00Z"); + + SECTION("test constructor") { + std::cout << "----------------" << std::endl; + Increment increment(geometry, oops_vars2, datetime); + // copy + Increment increment2 = increment; + } + + SECTION("test setting increment value") { + std::cout << "----------------------------" << std::endl; + Increment increment(geometry, oops_vars, datetime); + std::cout << std::endl << "Increment ones: " << std::endl; + increment.ones(); + increment.print(std::cout); + EXPECT_EQUAL(increment.norm(), 1); + std::cout << std::endl << "Increment zero: " << std::endl; + increment.zero(); + increment.print(std::cout); + EXPECT_EQUAL(increment.norm(), 0); + std::cout << std::endl << "Increment random: " << std::endl; + increment.random(); + increment.print(std::cout); + } + + SECTION("test dirac") { + std::cout << "----------------" << std::endl; + eckit::LocalConfiguration dirac_config; + std::vector ix = {20, 30}; + std::vector iy = {10, 40}; + std::vector iz = {1, 3}; + dirac_config.set("ixdir", ix); + dirac_config.set("iydir", iy); + dirac_config.set("izdir", iz); + + Increment increment(geometry, oops_vars, datetime); + increment.dirac(dirac_config); + increment.print(std::cout); + EXPECT(std::abs(increment.norm() - 0.0086788) < 1e-6); + } + + SECTION("test mathematical operators") { + std::cout << "---------------------------" << std::endl; + Increment increment1(geometry, oops_vars, datetime); + Increment increment2(geometry, oops_vars, datetime); + increment1.ones(); + std::cout << std::endl << "Increment1:" << std::endl; + increment1.print(std::cout); + increment1 *= 3; + std::cout << std::endl << "Increment1 (*3):" << std::endl; + increment1.print(std::cout); + increment2.ones(); + increment2 *= 2; + std::cout << std::endl << "Increment2 (*2):" << std::endl; + increment2.print(std::cout); + increment1 -= increment2; + std::cout << std::endl << "Increment1 (-increment2):" << std::endl; + increment1.print(std::cout); + increment1 += increment2; + increment1 += increment2; + std::cout << std::endl << "Increment1 (+increment2*2):" << std::endl; + increment1.print(std::cout); + double zz = increment1.dot_product_with(increment2); + std::cout << std::endl << "Dot product increment1.increment2 = " << zz << std::endl; + increment1.schur_product_with(increment2); + std::cout << std::endl << "Increment1 (Schur product with increment 2):" << std::endl; + increment1.print(std::cout); + increment1.axpy(100, increment2, true); + std::cout << std::endl << "Increment1 axpy (increment2*100 + increment1):" << std::endl; + increment1.print(std::cout); + EXPECT_EQUAL(increment1.norm(), 210); + } + + SECTION("test increments to fieldset and back to increments") { + std::cout << "--------------------------------------------------" << std::endl; + Increment increment1(geometry, oops_vars, datetime); + increment1.ones(); + Increment increment2(geometry, oops_vars, datetime); + increment2.zero(); + atlas::FieldSet incfset = atlas::FieldSet(); + increment1.Increment::toFieldSet(incfset); + increment2.Increment::fromFieldSet(incfset); + increment1.print(std::cout); + increment2.print(std::cout); + EXPECT_EQUAL(increment1.norm(), increment2.norm()); + } - EXPECT_THROWS_AS(Increment increment(geometry, oops_vars, datetime), - eckit::NotImplemented); + SECTION("test increment diff with state inputs") { + std::cout << "-------------------------------------" << std::endl; + // Using the same variables and double type as the increments + // Code to deal with differing variables in state and increment not currently implemented + orcamodel::State state1(geometry, oops_vars, datetime); + orcamodel::State state2(geometry, oops_vars, datetime); + state1.zero(); + state2.zero(); + std::cout << "state1 norm:" << varnames[0]; + std::cout << state1.norm(varnames[0]) << std::endl; + std::cout << "state2 norm:" << varnames[0]; + std::cout << state2.norm(varnames[0]) << std::endl; + Increment increment(geometry, oops_vars, datetime); + increment.diff(state1, state2); + std::cout << "increment (diff state1 state2):" << std::endl; + increment.print(std::cout); + } } } // namespace test diff --git a/src/tests/orca-jedi/test_state.cc b/src/tests/orca-jedi/test_state.cc index 1dec9d6..0c04850 100644 --- a/src/tests/orca-jedi/test_state.cc +++ b/src/tests/orca-jedi/test_state.cc @@ -123,6 +123,12 @@ CASE("test basic state") { SECTION("test state write") { state.write(params); } + + SECTION("test state getField") { + atlas::Field field = state.getField(0); + std::cout << field.name() << std::endl; + EXPECT(field.name() == "sea_ice_area_fraction"); + } } } // namespace test