Skip to content

Commit

Permalink
Add support for shared field constraints
Browse files Browse the repository at this point in the history
  • Loading branch information
jchadwick-buf committed Sep 4, 2024
1 parent a01c79d commit 6dc455a
Show file tree
Hide file tree
Showing 15 changed files with 229 additions and 43 deletions.
5 changes: 2 additions & 3 deletions buf/validate/conformance/runner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,7 @@
namespace buf::validate::conformance {

harness::TestConformanceResponse TestRunner::runTest(
const harness::TestConformanceRequest& request,
const google::protobuf::DescriptorPool* descriptorPool) {
const harness::TestConformanceRequest& request) {
harness::TestConformanceResponse response;
for (const auto& tc : request.cases()) {
auto& result = response.mutable_results()->operator[](tc.first);
Expand All @@ -32,7 +31,7 @@ harness::TestConformanceResponse TestRunner::runTest(
*result.mutable_unexpected_error() = "could not parse type url " + dyn.type_url();
continue;
}
const auto* desc = descriptorPool->FindMessageTypeByName(dyn.type_url().substr(pos + 1));
const auto* desc = descriptorPool_->FindMessageTypeByName(dyn.type_url().substr(pos + 1));
if (desc == nullptr) {
*result.mutable_unexpected_error() = "could not find descriptor for type " + dyn.type_url();
} else {
Expand Down
10 changes: 6 additions & 4 deletions buf/validate/conformance/runner.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,17 +22,19 @@ namespace buf::validate::conformance {

class TestRunner {
public:
explicit TestRunner() : validatorFactory_(ValidatorFactory::New().value()) {}
explicit TestRunner(const google::protobuf::DescriptorPool* descriptorPool)
: descriptorPool_(descriptorPool), validatorFactory_(ValidatorFactory::New().value()) {
validatorFactory_->SetMessageFactory(&messageFactory_, descriptorPool_);
}

harness::TestConformanceResponse runTest(
const harness::TestConformanceRequest& request,
const google::protobuf::DescriptorPool* descriptorPool);
harness::TestConformanceResponse runTest(const harness::TestConformanceRequest& request);
harness::TestResult runTestCase(
const google::protobuf::Descriptor* desc, const google::protobuf::Any& dyn);
harness::TestResult runTestCase(const google::protobuf::Message& message);

private:
google::protobuf::DynamicMessageFactory messageFactory_;
const google::protobuf::DescriptorPool* descriptorPool_;
std::unique_ptr<ValidatorFactory> validatorFactory_;
google::protobuf::Arena arena_;
};
Expand Down
7 changes: 4 additions & 3 deletions buf/validate/conformance/runner_main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,15 @@
#include "buf/validate/conformance/runner.h"

int main(int argc, char** argv) {
google::protobuf::DescriptorPool descriptorPool;
buf::validate::conformance::TestRunner runner;
google::protobuf::DescriptorPool descriptorPool{
google::protobuf::DescriptorPool::generated_pool()};
buf::validate::conformance::harness::TestConformanceRequest request;
request.ParseFromIstream(&std::cin);
for (const auto& file : request.fdset().file()) {
descriptorPool.BuildFile(file);
}
auto response = runner.runTest(request, &descriptorPool);
buf::validate::conformance::TestRunner runner{&descriptorPool};
auto response = runner.runTest(request);
response.SerializeToOstream(&std::cout);
return 0;
}
15 changes: 15 additions & 0 deletions buf/validate/internal/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,11 @@ cc_library(
"@com_google_cel_cpp//eval/public:activation",
"@com_google_cel_cpp//eval/public:cel_expression",
"@com_google_cel_cpp//eval/public/structs:cel_proto_wrapper",
"@com_google_cel_cpp//eval/public/containers:field_access",
"@com_google_cel_cpp//eval/public/containers:field_backed_list_impl",
"@com_google_cel_cpp//eval/public/containers:field_backed_map_impl",
"@com_google_cel_cpp//parser",
"@com_google_cel_cpp//base:value"
],
)

Expand All @@ -44,15 +48,26 @@ cc_library(
deps = [
"@com_google_absl//absl/status",
"@com_google_protobuf//:protobuf",
":message_factory",
],
)

cc_library(
name = "message_factory",
srcs = ["message_factory.cc"],
hdrs = ["message_factory.h"],
deps = [
"@com_google_protobuf//:protobuf",
]
)

cc_library(
name = "message_rules",
srcs = ["message_rules.cc"],
hdrs = ["message_rules.h"],
deps = [
":field_rules",
":message_factory",
"@com_github_bufbuild_protovalidate//proto/protovalidate/buf/validate:validate_proto_cc",
"@com_google_absl//absl/status",
"@com_google_cel_cpp//eval/public:cel_expression",
Expand Down
40 changes: 36 additions & 4 deletions buf/validate/internal/cel_constraint_rules.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@

#include "buf/validate/internal/cel_constraint_rules.h"

#include "base/values/struct_value.h"
#include "eval/public/containers/field_access.h"
#include "eval/public/containers/field_backed_list_impl.h"
#include "eval/public/containers/field_backed_map_impl.h"
#include "eval/public/structs/cel_proto_wrapper.h"
#include "parser/parser.h"

Expand Down Expand Up @@ -57,10 +61,31 @@ absl::Status ProcessConstraint(
return absl::OkStatus();
}

cel::runtime::CelValue ProtoFieldToCelValue(
const google::protobuf::Message* message,
const google::protobuf::FieldDescriptor* field,
google::protobuf::Arena* arena) {
if (field->is_map()) {
return cel::runtime::CelValue::CreateMap(
google::protobuf::Arena::Create<cel::runtime::FieldBackedMapImpl>(
arena, message, field, arena));
} else if (field->is_repeated()) {
return cel::runtime::CelValue::CreateList(
google::protobuf::Arena::Create<cel::runtime::FieldBackedListImpl>(
arena, message, field, arena));
} else if (cel::runtime::CelValue result;
cel::runtime::CreateValueFromSingleField(message, field, arena, &result).ok()) {
return result;
}
return cel::runtime::CelValue::CreateNull();
}

} // namespace

absl::Status CelConstraintRules::Add(
google::api::expr::runtime::CelExpressionBuilder& builder, Constraint constraint) {
google::api::expr::runtime::CelExpressionBuilder& builder,
Constraint constraint,
const google::protobuf::FieldDescriptor* rule) {
auto pexpr_or = cel::parser::Parse(constraint.expression());
if (!pexpr_or.ok()) {
return pexpr_or.status();
Expand All @@ -71,20 +96,21 @@ absl::Status CelConstraintRules::Add(
return expr_or.status();
}
std::unique_ptr<cel::runtime::CelExpression> expr = std::move(expr_or).value();
exprs_.emplace_back(CompiledConstraint{std::move(constraint), std::move(expr)});
exprs_.emplace_back(CompiledConstraint{std::move(constraint), std::move(expr), rule});
return absl::OkStatus();
}

absl::Status CelConstraintRules::Add(
google::api::expr::runtime::CelExpressionBuilder& builder,
std::string_view id,
std::string_view message,
std::string_view expression) {
std::string_view expression,
const google::protobuf::FieldDescriptor* rule) {
Constraint constraint;
*constraint.mutable_id() = id;
*constraint.mutable_message() = message;
*constraint.mutable_expression() = expression;
return Add(builder, constraint);
return Add(builder, constraint, rule);
}

absl::Status CelConstraintRules::ValidateCel(
Expand All @@ -94,11 +120,17 @@ absl::Status CelConstraintRules::ValidateCel(
activation.InsertValue("rules", rules_);
activation.InsertValue("now", cel::runtime::CelValue::CreateTimestamp(absl::Now()));
absl::Status status = absl::OkStatus();

for (const auto& expr : exprs_) {
if (rules_.IsMessage() && expr.rule) {
activation.InsertValue(
"rule", ProtoFieldToCelValue(rules_.MessageOrDie(), expr.rule, ctx.arena));
}
status = ProcessConstraint(ctx, fieldName, activation, expr);
if (ctx.shouldReturn(status)) {
break;
}
activation.RemoveValueEntry("rule");
}
activation.RemoveValueEntry("rules");
return status;
Expand Down
8 changes: 6 additions & 2 deletions buf/validate/internal/cel_constraint_rules.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ namespace buf::validate::internal {
struct CompiledConstraint {
buf::validate::Constraint constraint;
std::unique_ptr<google::api::expr::runtime::CelExpression> expr;
const google::protobuf::FieldDescriptor* rule;
};

// An abstract base class for constraint with rules that are compiled into CEL expressions.
Expand All @@ -38,12 +39,15 @@ class CelConstraintRules : public ConstraintRules {
using Base::Base;

absl::Status Add(
google::api::expr::runtime::CelExpressionBuilder& builder, Constraint constraint);
google::api::expr::runtime::CelExpressionBuilder& builder,
Constraint constraint,
const google::protobuf::FieldDescriptor* rule);
absl::Status Add(
google::api::expr::runtime::CelExpressionBuilder& builder,
std::string_view id,
std::string_view message,
std::string_view expression);
std::string_view expression,
const google::protobuf::FieldDescriptor* rule);
[[nodiscard]] const std::vector<CompiledConstraint>& getExprs() const { return exprs_; }

// Validate all the cel rules given the activation that already has 'this' bound.
Expand Down
25 changes: 21 additions & 4 deletions buf/validate/internal/cel_rules.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

#include "absl/status/status.h"
#include "buf/validate/internal/cel_constraint_rules.h"
#include "buf/validate/internal/message_factory.h"
#include "buf/validate/validate.pb.h"
#include "google/protobuf/arena.h"
#include "google/protobuf/descriptor.h"
Expand All @@ -24,22 +25,38 @@ namespace buf::validate::internal {

template <typename R>
absl::Status BuildCelRules(
std::unique_ptr<MessageFactory>& messageFactory,
google::protobuf::Arena* arena,
google::api::expr::runtime::CelExpressionBuilder& builder,
const R& rules,
CelConstraintRules& result) {
result.setRules(&rules, arena);
// Look for constraints on the set fields.
std::vector<const google::protobuf::FieldDescriptor*> fields;
R::GetReflection()->ListFields(rules, &fields);
google::protobuf::Message* reparsedRules;
if (messageFactory) {
reparsedRules = messageFactory->messageFactory()
->GetPrototype(messageFactory->descriptorPool()->FindMessageTypeByName(
rules.GetTypeName()))
->New(arena);
if (!Reparse(*messageFactory, rules, reparsedRules)) {
reparsedRules = nullptr;
}
}
if (reparsedRules) {
result.setRules(reparsedRules, arena);
reparsedRules->GetReflection()->ListFields(*reparsedRules, &fields);
} else {
result.setRules(&rules, arena);
R::GetReflection()->ListFields(rules, &fields);
}
for (const auto* field : fields) {
if (!field->options().HasExtension(buf::validate::priv::field)) {
continue;
}
const auto& fieldLvl = field->options().GetExtension(buf::validate::priv::field);
for (const auto& constraint : fieldLvl.cel()) {
auto status =
result.Add(builder, constraint.id(), constraint.message(), constraint.expression());
auto status = result.Add(
builder, constraint.id(), constraint.message(), constraint.expression(), field);
if (!status.ok()) {
return status;
}
Expand Down
Loading

0 comments on commit 6dc455a

Please sign in to comment.