Skip to content

Commit

Permalink
Use a variant comparer and std::set to finish array predicate
Browse files Browse the repository at this point in the history
Signed-off-by: Kunlin Yu <[email protected]>
  • Loading branch information
kunlinyu committed Dec 31, 2024
1 parent d9c16a8 commit af7998e
Show file tree
Hide file tree
Showing 6 changed files with 196 additions and 35 deletions.
1 change: 1 addition & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ set(CQL2CPP_SRC
src/id_generator.cc
src/ast_node.cc
src/global_yylex.cc
src/value.cc
${FLEX_OUTPUT}
${BISON_OUTPUT}
)
Expand Down
100 changes: 72 additions & 28 deletions include/cql2cpp/node_evaluator.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,18 +10,18 @@

#pragma once

#include <geos/geom/Envelope.h>
#include <geos/geom/Geometry.h>
#include <geos/geom/GeometryFactory.h>

#include <cmath>
#include <variant>

#include "evaluator.h"
#include "geos/geom/Envelope.h"
#include "geos/geom/Geometry.h"
#include "geos/geom/GeometryFactory.h"
#include "value_compare.h"

namespace cql2cpp {

constexpr double kEpsilon = 1e-5;

template <typename ValueType>
bool CheckValueNumberType(const std::string& op, size_t num,
const std::vector<ValueT>& vs, std::string* errmsg) {
Expand All @@ -37,29 +37,6 @@ bool CheckValueNumberType(const std::string& op, size_t num,
return true;
}

template <typename T>
inline bool TypedEqual(const ValueT& a, const ValueT& b) {
return std::get<T>(a) == std::get<T>(b);
}

inline bool isVariantEqual(const ValueT& a, const ValueT& b) {
if (a.index() != b.index()) return false;

if (std::holds_alternative<bool>(a)) return TypedEqual<bool>(a, b);

if (std::holds_alternative<int64_t>(a)) return TypedEqual<int64_t>(a, b);

if (std::holds_alternative<uint64_t>(a)) return TypedEqual<uint64_t>(a, b);

if (std::holds_alternative<double>(a))
return fabs(std::get<double>(a) - std::get<double>(b)) < kEpsilon;

if (std::holds_alternative<std::string>(a))
return TypedEqual<std::string>(a, b);

return false;
}

const std::map<NodeType, std::map<Operator, NodeEval>> node_evals = {
{Literal,
{{NullOp,
Expand Down Expand Up @@ -145,6 +122,73 @@ const std::map<NodeType, std::map<Operator, NodeEval>> node_evals = {
return ret;
}},
}},
{Array,
{{NullOp,
[](auto n, auto vs, auto fs, auto value, auto errmsg) -> bool {
ArrayType result;
for (const auto& v : vs) result.insert(ArrayElement(v));
*value = result;
return true;
}}}},
{
ArrayPred,
{{A_Equals,
[](auto n, auto vs, auto fs, auto value, auto errmsg) -> bool {
const auto& contains = node_evals.at(ArrayPred).at(A_Contains);
const auto& contained = node_evals.at(ArrayPred).at(A_ContainedBy);

bool ret1 = contains.operator()(n, vs, fs, value, errmsg);
bool result_1 = std::get<bool>(*value);
if (not ret1) return false;
bool ret2 = contained.operator()(n, vs, fs, value, errmsg);
bool result_2 = std::get<bool>(*value);
if (not ret2) return false;

*value = result_1 and result_2;
return true;
}},
{A_Contains,
[](auto n, auto vs, auto fs, auto value, auto errmsg) -> bool {
if (not CheckValueNumberType<ArrayType>("Array Op", 2, vs, errmsg))
return false;
const auto& lhs = std::get<ArrayType>(vs.at(0));
const auto& rhs = std::get<ArrayType>(vs.at(1));
if (lhs.size() < rhs.size()) {
*value = false;
return true;
}
for (const auto& e : rhs)
if (lhs.find(e) == lhs.end()) {
*value = false;
return true;
}
*value = true;
return true;
}},
{A_ContainedBy,
[](auto n, auto vs, auto fs, auto value, auto errmsg) -> bool {
if (not CheckValueNumberType<ArrayType>("Array Op", 2, vs, errmsg))
return false;

const auto& lhs = std::get<ArrayType>(vs.at(0));
const auto& rhs = std::get<ArrayType>(vs.at(1));
if (lhs.size() > rhs.size()) {
*value = false;
return true;
}
for (const auto& e : lhs)
if (rhs.find(e) == rhs.end()) {
*value = false;
return true;
}
*value = true;
return true;
}},
{A_Overlaps,
[](auto n, auto vs, auto fs, auto value, auto errmsg) -> bool {
return true;
}}},
},
{SpatialPred,
{{S_Intersects,
[](auto n, auto vs, auto fs, auto value, auto errmsg) -> bool {
Expand Down
4 changes: 2 additions & 2 deletions include/cql2cpp/operator.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,9 +64,9 @@ enum Operator {
T_Starts,

// array operators
A_ContainedBy,
A_Contains,
A_Equals,
A_Contains,
A_ContainedBy,
A_Overlaps,
};

Expand Down
45 changes: 40 additions & 5 deletions include/cql2cpp/value.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* File Name: value_type.h
* File Name: value.h
*
* Copyright (c) 2024 IndoorSpatial
*
Expand All @@ -15,19 +15,44 @@
#include <geos/io/WKTWriter.h>

#include <cctype>
#include <set>
#include <sstream>
#include <string>
#include <variant>

namespace cql2cpp {

enum NullStruct { NullValue };

typedef std::variant<NullStruct, bool, int64_t, uint64_t, double, std::string,
const geos::geom::Geometry*, const geos::geom::Envelope*>
ValueT;
struct ArrayElement;

class ArrayElementComp {
public:
static constexpr double tolerance = 1e-9;

public:
bool operator()(const ArrayElement& a, const ArrayElement& b) const;

private:
template <typename T, typename U>
bool less(const U& lhs, const U& rhs) const {
return std::less()(std::get<T>(lhs), std::get<T>(rhs));
}
};

using ArrayType = std::set<ArrayElement, ArrayElementComp>;

using ValueT = std::variant<NullStruct, bool, int64_t, uint64_t, double,
std::string, ArrayType, const geos::geom::Geometry*,
const geos::geom::Envelope*>;

struct ArrayElement {
ValueT value;
ArrayElement(const ValueT& value) : value(value) {}
};

static std::string value_str(ValueT value, bool with_type = false) {
if (std::holds_alternative<NullStruct>(value)) return "?";
if (std::holds_alternative<NullStruct>(value)) return "null";

if (std::holds_alternative<bool>(value))
return (std::get<bool>(value) ? "T" : "F") +
Expand All @@ -49,6 +74,16 @@ static std::string value_str(ValueT value, bool with_type = false) {
return std::get<std::string>(value) +
std::string(with_type ? " string" : "");

if (std::holds_alternative<ArrayType>(value)) {
std::stringstream ss;
ss << "[";
for (const auto& element : std::get<ArrayType>(value))
ss << value_str(element.value) << ",";
ss.seekp(-1, std::ios_base::end);
ss << "]";
return ss.str() + std::string(with_type ? " array" : "");
}

if (std::holds_alternative<const geos::geom::Geometry*>(value)) {
geos::io::WKTWriter writer;
return writer.write(std::get<const geos::geom::Geometry*>(value));
Expand Down
44 changes: 44 additions & 0 deletions include/cql2cpp/value_compare.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
/*
* File Name: value_compare.h
*
* Copyright (c) 2024 IndoorSpatial
*
* Author: Kunlin Yu <[email protected]>
* Create Date: 2024/12/31
*
*/

#pragma once

// #include <functional>

#include "value.h"

namespace cql2cpp {

static const double kEpsilon = 1e-5;

template <typename T>
inline bool TypedEqual(const ValueT& a, const ValueT& b) {
return std::get<T>(a) == std::get<T>(b);
}

inline bool isVariantEqual(const ValueT& a, const ValueT& b) {
if (a.index() != b.index()) return false;

if (std::holds_alternative<bool>(a)) return TypedEqual<bool>(a, b);

if (std::holds_alternative<int64_t>(a)) return TypedEqual<int64_t>(a, b);

if (std::holds_alternative<uint64_t>(a)) return TypedEqual<uint64_t>(a, b);

if (std::holds_alternative<double>(a))
return fabs(std::get<double>(a) - std::get<double>(b)) < kEpsilon;

if (std::holds_alternative<std::string>(a))
return TypedEqual<std::string>(a, b);

return false;
}

} // namespace cql2cpp
37 changes: 37 additions & 0 deletions src/value.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
/*
* File Name: value_type.cc
*
* Copyright (c) 2024 IndoorSpatial
*
* Author: Kunlin Yu <[email protected]>
* Create Date: 2024/12/31
*
*/

#include <cql2cpp/value.h>

namespace cql2cpp {

bool ArrayElementComp::operator()(const ArrayElement& lhs,
const ArrayElement& rhs) const {
const auto& a = lhs.value;
const auto& b = rhs.value;
if (a.index() != b.index()) return std::less()(a.index(), b.index());

if (std::holds_alternative<NullStruct>(a)) return less<NullStruct>(a, b);
if (std::holds_alternative<bool>(a)) return less<bool>(a, b);
if (std::holds_alternative<int64_t>(a)) return less<int64_t>(a, b);
if (std::holds_alternative<uint64_t>(a)) return less<uint64_t>(a, b);
if (std::holds_alternative<std::string>(a)) return less<std::string>(a, b);

if (std::holds_alternative<double>(a)) {
double d_a = std::get<double>(a);
double d_b = std::get<double>(b);
if (std::fabs(d_a - d_b) < tolerance) return false;
return std::less()(d_a, d_b);
}

return true;
}

} // namespace cql2cpp

0 comments on commit af7998e

Please sign in to comment.