Skip to content

Commit

Permalink
Draft STDEV() aggregate function
Browse files Browse the repository at this point in the history
  • Loading branch information
ullingerc committed Oct 14, 2024
1 parent 414f50c commit fc56318
Show file tree
Hide file tree
Showing 16 changed files with 2,656 additions and 2,420 deletions.
2 changes: 2 additions & 0 deletions src/engine/GroupBy.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include "engine/sparqlExpressions/SampleExpression.h"
#include "engine/sparqlExpressions/SparqlExpression.h"
#include "engine/sparqlExpressions/SparqlExpressionGenerators.h"
#include "engine/sparqlExpressions/StdevExpression.h"
#include "global/RuntimeParameters.h"
#include "index/Index.h"
#include "index/IndexImpl.h"
Expand Down Expand Up @@ -1013,6 +1014,7 @@ GroupBy::isSupportedAggregate(sparqlExpression::SparqlExpression* expr) {
if (auto val = dynamic_cast<GroupConcatExpression*>(expr)) {
return H{GROUP_CONCAT, val->getSeparator()};
}
if (dynamic_cast<StdevExpression*>(expr)) return H{STDEV};
if (dynamic_cast<SampleExpression*>(expr)) return H{SAMPLE};

// `expr` is an unsupported aggregate
Expand Down
1 change: 1 addition & 0 deletions src/engine/GroupBy.h
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,7 @@ class GroupBy : public Operation {
MAX,
SUM,
GROUP_CONCAT,
STDEV,
SAMPLE
};

Expand Down
5 changes: 5 additions & 0 deletions src/engine/sparqlExpressions/AggregateExpression.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include "engine/sparqlExpressions/AggregateExpression.h"

#include "engine/sparqlExpressions/GroupConcatExpression.h"
#include "engine/sparqlExpressions/StdevExpression.h"

namespace sparqlExpression::detail {

Expand Down Expand Up @@ -180,6 +181,10 @@ AggregateExpression<AggregateOperation, FinalOperation>::getVariableForCount()
// Explicit instantiation for the AVG expression.
template class AggregateExpression<AvgOperation, decltype(avgFinalOperation)>;

// Explicit instantiation for the STDEV expression.
template class DeviationAggExpression<AvgOperation,
decltype(stdevFinalOperation)>;

// Explicit instantiations for the other aggregate expressions.
#define INSTANTIATE_AGG_EXP(Function, ValueGetter) \
template class AggregateExpression< \
Expand Down
163 changes: 163 additions & 0 deletions src/engine/sparqlExpressions/StdevExpression.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,163 @@
// Copyright 2024, University of Freiburg,
// Chair of Algorithms and Data Structures.
// Author: Christoph Ullinger <[email protected]>

#ifndef QLEVER_STDEVEXPRESSION_H
#define QLEVER_STDEVEXPRESSION_H

#include <cmath>
#include <functional>
#include <memory>
#include <variant>

#include "engine/sparqlExpressions/AggregateExpression.h"
#include "engine/sparqlExpressions/LiteralExpression.h"
#include "engine/sparqlExpressions/NaryExpression.h"
#include "engine/sparqlExpressions/SparqlExpression.h"
#include "engine/sparqlExpressions/SparqlExpressionTypes.h"
#include "engine/sparqlExpressions/SparqlExpressionValueGetters.h"
#include "global/ValueId.h"

namespace sparqlExpression {

namespace detail {

/// The STDEV Expression

// Helper function to extract a double from a NumericValue variant
auto inline numValToDouble =
[]<typename T>(T&& value) -> std::optional<double> {
if constexpr (ad_utility::isSimilar<T, double> ||
ad_utility::isSimilar<T, int64_t>) {
return static_cast<double>(value);
} else {
return std::nullopt;

Check warning on line 34 in src/engine/sparqlExpressions/StdevExpression.h

View check run for this annotation

Codecov / codecov/patch

src/engine/sparqlExpressions/StdevExpression.h#L29-L34

Added lines #L29 - L34 were not covered by tests
}
};

Check warning on line 36 in src/engine/sparqlExpressions/StdevExpression.h

View check run for this annotation

Codecov / codecov/patch

src/engine/sparqlExpressions/StdevExpression.h#L36

Added line #L36 was not covered by tests

// Helper expression: The individual deviation squares. A DeviationExpression
// over X corresponds to the value (X - AVG(X))^2.
class DeviationExpression : public SparqlExpression {
private:
Ptr child_;
bool distinct_;

public:
DeviationExpression(bool distinct, Ptr&& child)
: child_{std::move(child)}, distinct_{distinct} {}

Check warning on line 47 in src/engine/sparqlExpressions/StdevExpression.h

View check run for this annotation

Codecov / codecov/patch

src/engine/sparqlExpressions/StdevExpression.h#L47

Added line #L47 was not covered by tests

// __________________________________________________________________________
ExpressionResult evaluate(EvaluationContext* context) const override {
auto impl =
[this, context](SingleExpressionResult auto&& el) -> ExpressionResult {

Check warning on line 52 in src/engine/sparqlExpressions/StdevExpression.h

View check run for this annotation

Codecov / codecov/patch

src/engine/sparqlExpressions/StdevExpression.h#L50-L52

Added lines #L50 - L52 were not covered by tests
// Prepare space for result
VectorWithMemoryLimit<IdOrLiteralOrIri> exprResult{context->_allocator};
std::fill_n(std::back_inserter(exprResult), context->size(),
IdOrLiteralOrIri{Id::makeUndefined()});

Check warning on line 56 in src/engine/sparqlExpressions/StdevExpression.h

View check run for this annotation

Codecov / codecov/patch

src/engine/sparqlExpressions/StdevExpression.h#L54-L56

Added lines #L54 - L56 were not covered by tests

auto devImpl = [&exprResult, context](auto generator) {
double sum = 0.0;
std::vector<double> childResults = {};

Check warning on line 60 in src/engine/sparqlExpressions/StdevExpression.h

View check run for this annotation

Codecov / codecov/patch

src/engine/sparqlExpressions/StdevExpression.h#L58-L60

Added lines #L58 - L60 were not covered by tests

// Collect values as doubles
for (auto& inp : generator) {
const auto& s = detail::NumericValueGetter{}(std::move(inp), context);
auto v = std::visit(numValToDouble, s);

Check warning on line 65 in src/engine/sparqlExpressions/StdevExpression.h

View check run for this annotation

Codecov / codecov/patch

src/engine/sparqlExpressions/StdevExpression.h#L64-L65

Added lines #L64 - L65 were not covered by tests
if (v.has_value()) {
childResults.push_back(v.value());
sum += v.value();

Check warning on line 68 in src/engine/sparqlExpressions/StdevExpression.h

View check run for this annotation

Codecov / codecov/patch

src/engine/sparqlExpressions/StdevExpression.h#L67-L68

Added lines #L67 - L68 were not covered by tests
}
context->cancellationHandle_->throwIfCancelled();

Check warning on line 70 in src/engine/sparqlExpressions/StdevExpression.h

View check run for this annotation

Codecov / codecov/patch

src/engine/sparqlExpressions/StdevExpression.h#L70

Added line #L70 was not covered by tests
}

// Calculate squared deviation and save for result
double avg = sum / static_cast<double>(context->size());

Check warning on line 74 in src/engine/sparqlExpressions/StdevExpression.h

View check run for this annotation

Codecov / codecov/patch

src/engine/sparqlExpressions/StdevExpression.h#L74

Added line #L74 was not covered by tests
for (size_t i = 0; i < childResults.size(); i++) {
exprResult.at(i) = IdOrLiteralOrIri{
ValueId::makeFromDouble(std::pow(childResults.at(i) - avg, 2))};

Check warning on line 77 in src/engine/sparqlExpressions/StdevExpression.h

View check run for this annotation

Codecov / codecov/patch

src/engine/sparqlExpressions/StdevExpression.h#L76-L77

Added lines #L76 - L77 were not covered by tests
}
};

Check warning on line 79 in src/engine/sparqlExpressions/StdevExpression.h

View check run for this annotation

Codecov / codecov/patch

src/engine/sparqlExpressions/StdevExpression.h#L79

Added line #L79 was not covered by tests

auto generator =
detail::makeGenerator(AD_FWD(el), context->size(), context);

Check warning on line 82 in src/engine/sparqlExpressions/StdevExpression.h

View check run for this annotation

Codecov / codecov/patch

src/engine/sparqlExpressions/StdevExpression.h#L81-L82

Added lines #L81 - L82 were not covered by tests
if (distinct_) {
context->cancellationHandle_->throwIfCancelled();
devImpl(detail::getUniqueElements(context, context->size(),
std::move(generator)));
} else {
devImpl(std::move(generator));

Check warning on line 88 in src/engine/sparqlExpressions/StdevExpression.h

View check run for this annotation

Codecov / codecov/patch

src/engine/sparqlExpressions/StdevExpression.h#L84-L88

Added lines #L84 - L88 were not covered by tests
}

return exprResult;
};
auto childRes = child_->evaluate(context);
return std::visit(impl, std::move(childRes));
};

Check warning on line 95 in src/engine/sparqlExpressions/StdevExpression.h

View check run for this annotation

Codecov / codecov/patch

src/engine/sparqlExpressions/StdevExpression.h#L91-L95

Added lines #L91 - L95 were not covered by tests

// __________________________________________________________________________
AggregateStatus isAggregate() const override {
return SparqlExpression::AggregateStatus::NoAggregate;

Check warning on line 99 in src/engine/sparqlExpressions/StdevExpression.h

View check run for this annotation

Codecov / codecov/patch

src/engine/sparqlExpressions/StdevExpression.h#L98-L99

Added lines #L98 - L99 were not covered by tests
}

// __________________________________________________________________________
[[nodiscard]] string getCacheKey(
const VariableToColumnMap& varColMap) const override {

Check warning on line 104 in src/engine/sparqlExpressions/StdevExpression.h

View check run for this annotation

Codecov / codecov/patch

src/engine/sparqlExpressions/StdevExpression.h#L104

Added line #L104 was not covered by tests
return absl::StrCat("[ SQ.DEVIATION ", distinct_ ? " DISTINCT " : "", "]",
child_->getCacheKey(varColMap));

Check warning on line 106 in src/engine/sparqlExpressions/StdevExpression.h

View check run for this annotation

Codecov / codecov/patch

src/engine/sparqlExpressions/StdevExpression.h#L106

Added line #L106 was not covered by tests
}

private:
// _________________________________________________________________________
std::span<SparqlExpression::Ptr> childrenImpl() override {
return {&child_, 1};

Check warning on line 112 in src/engine/sparqlExpressions/StdevExpression.h

View check run for this annotation

Codecov / codecov/patch

src/engine/sparqlExpressions/StdevExpression.h#L111-L112

Added lines #L111 - L112 were not covered by tests
}
};

// Separate subclass of AggregateOperation, that replaces its child with a
// DeviationExpression of this child. Everything else is left untouched.
template <typename AggregateOperation,
typename FinalOperation = decltype(identity)>
class DeviationAggExpression
: public AggregateExpression<AggregateOperation, FinalOperation> {
public:
// __________________________________________________________________________
DeviationAggExpression(bool distinct, SparqlExpression::Ptr&& child,
AggregateOperation aggregateOp = AggregateOperation{})
: AggregateExpression<AggregateOperation, FinalOperation>(
distinct,
std::make_unique<DeviationExpression>(distinct, std::move(child)),
aggregateOp){};

Check warning on line 129 in src/engine/sparqlExpressions/StdevExpression.h

View check run for this annotation

Codecov / codecov/patch

src/engine/sparqlExpressions/StdevExpression.h#L129

Added line #L129 was not covered by tests
};

// The final operation for dividing by degrees of freedom and calculation square
// root after summing up the squared deviation
inline auto stdevFinalOperation = [](const NumericValue& aggregation,
size_t numElements) {
auto divAndRoot = [](double value, double degreesOfFreedom) {

Check warning on line 136 in src/engine/sparqlExpressions/StdevExpression.h

View check run for this annotation

Codecov / codecov/patch

src/engine/sparqlExpressions/StdevExpression.h#L135-L136

Added lines #L135 - L136 were not covered by tests
if (degreesOfFreedom <= 0) {
return 0.0;
} else {
return std::sqrt(value / degreesOfFreedom);

Check warning on line 140 in src/engine/sparqlExpressions/StdevExpression.h

View check run for this annotation

Codecov / codecov/patch

src/engine/sparqlExpressions/StdevExpression.h#L138-L140

Added lines #L138 - L140 were not covered by tests
}
};
return makeNumericExpressionForAggregate<decltype(divAndRoot)>()(
aggregation, NumericValue{static_cast<double>(numElements) - 1});
};

Check warning on line 145 in src/engine/sparqlExpressions/StdevExpression.h

View check run for this annotation

Codecov / codecov/patch

src/engine/sparqlExpressions/StdevExpression.h#L142-L145

Added lines #L142 - L145 were not covered by tests

// The actual Standard Deviation Expression
// Mind the explicit instantiation of StdevExpressionBase in
// AggregateExpression.cpp
using StdevExpressionBase =
DeviationAggExpression<AvgOperation, decltype(stdevFinalOperation)>;
class StdevExpression : public StdevExpressionBase {
using StdevExpressionBase::StdevExpressionBase;
ValueId resultForEmptyGroup() const override { return Id::makeFromInt(0); }

Check warning on line 154 in src/engine/sparqlExpressions/StdevExpression.h

View check run for this annotation

Codecov / codecov/patch

src/engine/sparqlExpressions/StdevExpression.h#L154

Added line #L154 was not covered by tests
};

} // namespace detail

using detail::StdevExpression;

} // namespace sparqlExpression

#endif // QLEVER_STDEVEXPRESSION_H
3 changes: 3 additions & 0 deletions src/parser/sparqlParser/SparqlQleverVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#include "engine/sparqlExpressions/RegexExpression.h"
#include "engine/sparqlExpressions/RelationalExpressions.h"
#include "engine/sparqlExpressions/SampleExpression.h"
#include "engine/sparqlExpressions/StdevExpression.h"
#include "engine/sparqlExpressions/UuidExpressions.h"
#include "parser/RdfParser.h"
#include "parser/SparqlParser.h"
Expand Down Expand Up @@ -2227,6 +2228,8 @@ ExpressionPtr Visitor::visit(Parser::AggregateContext* ctx) {
}

return makePtr.operator()<GroupConcatExpression>(std::move(separator));
} else if (functionName == "stdev") {
return makePtr.operator()<StdevExpression>();

Check warning on line 2232 in src/parser/sparqlParser/SparqlQleverVisitor.cpp

View check run for this annotation

Codecov / codecov/patch

src/parser/sparqlParser/SparqlQleverVisitor.cpp#L2232

Added line #L2232 was not covered by tests
} else {
AD_CORRECTNESS_CHECK(functionName == "sample");
return makePtr.operator()<SampleExpression>();
Expand Down
1 change: 1 addition & 0 deletions src/parser/sparqlParser/SparqlQleverVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

#include "engine/sparqlExpressions/AggregateExpression.h"
#include "engine/sparqlExpressions/NaryExpression.h"
#include "engine/sparqlExpressions/StdevExpression.h"
#include "parser/data/GraphRef.h"
#undef EOF
#include "parser/sparqlParser/generated/SparqlAutomaticVisitor.h"
Expand Down
2 changes: 2 additions & 0 deletions src/parser/sparqlParser/generated/SparqlAutomatic.g4
Original file line number Diff line number Diff line change
Expand Up @@ -582,6 +582,7 @@ aggregate : COUNT '(' DISTINCT? ( '*' | expression ) ')'
| MIN '(' DISTINCT? expression ')'
| MAX '(' DISTINCT? expression ')'
| AVG '(' DISTINCT? expression ')'
| STDEV '(' DISTINCT? expression ')'
| SAMPLE '(' DISTINCT? expression ')'
| GROUP_CONCAT '(' DISTINCT? expression ( ';' SEPARATOR '=' string )? ')' ;

Expand Down Expand Up @@ -763,6 +764,7 @@ SUM : S U M;
MIN : M I N;
MAX : M A X;
AVG : A V G;
STDEV : S T D E V ;
SAMPLE : S A M P L E;
SEPARATOR : S E P A R A T O R;

Expand Down
4 changes: 3 additions & 1 deletion src/parser/sparqlParser/generated/SparqlAutomatic.interp

Large diffs are not rendered by default.

75 changes: 38 additions & 37 deletions src/parser/sparqlParser/generated/SparqlAutomatic.tokens
Original file line number Diff line number Diff line change
Expand Up @@ -136,43 +136,44 @@ SUM=135
MIN=136
MAX=137
AVG=138
SAMPLE=139
SEPARATOR=140
IRI_REF=141
PNAME_NS=142
PNAME_LN=143
BLANK_NODE_LABEL=144
VAR1=145
VAR2=146
LANGTAG=147
PREFIX_LANGTAG=148
INTEGER=149
DECIMAL=150
DOUBLE=151
INTEGER_POSITIVE=152
DECIMAL_POSITIVE=153
DOUBLE_POSITIVE=154
INTEGER_NEGATIVE=155
DECIMAL_NEGATIVE=156
DOUBLE_NEGATIVE=157
EXPONENT=158
STRING_LITERAL1=159
STRING_LITERAL2=160
STRING_LITERAL_LONG1=161
STRING_LITERAL_LONG2=162
ECHAR=163
NIL=164
ANON=165
PN_CHARS_U=166
VARNAME=167
PN_PREFIX=168
PN_LOCAL=169
PLX=170
PERCENT=171
HEX=172
PN_LOCAL_ESC=173
WS=174
COMMENTS=175
STDEV=139
SAMPLE=140
SEPARATOR=141
IRI_REF=142
PNAME_NS=143
PNAME_LN=144
BLANK_NODE_LABEL=145
VAR1=146
VAR2=147
LANGTAG=148
PREFIX_LANGTAG=149
INTEGER=150
DECIMAL=151
DOUBLE=152
INTEGER_POSITIVE=153
DECIMAL_POSITIVE=154
DOUBLE_POSITIVE=155
INTEGER_NEGATIVE=156
DECIMAL_NEGATIVE=157
DOUBLE_NEGATIVE=158
EXPONENT=159
STRING_LITERAL1=160
STRING_LITERAL2=161
STRING_LITERAL_LONG1=162
STRING_LITERAL_LONG2=163
ECHAR=164
NIL=165
ANON=166
PN_CHARS_U=167
VARNAME=168
PN_PREFIX=169
PN_LOCAL=170
PLX=171
PERCENT=172
HEX=173
PN_LOCAL_ESC=174
WS=175
COMMENTS=176
'*'=1
'('=2
')'=3
Expand Down
Loading

0 comments on commit fc56318

Please sign in to comment.