Skip to content

Commit

Permalink
Add equality comparison support for bool. (#4701)
Browse files Browse the repository at this point in the history
  • Loading branch information
zygoloid authored Dec 18, 2024
1 parent 4d0a6db commit c1590f8
Show file tree
Hide file tree
Showing 10 changed files with 980 additions and 8 deletions.
8 changes: 8 additions & 0 deletions core/prelude/operators/comparison.carbon
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,11 @@ interface Ordered {
fn Greater[self: Self](other: Self) -> bool;
fn GreaterOrEquivalent[self: Self](other: Self) -> bool;
}

// Equality comparison for `bool`.
// Note that this must be provided in this library as `bool` doesn't have any
// associated libraries of its own.
impl bool as Eq {
fn Equal[self: Self](other: Self) -> bool = "bool.eq";
fn NotEqual[self: Self](other: Self) -> bool = "bool.neq";
}
22 changes: 22 additions & 0 deletions toolchain/check/eval.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1097,6 +1097,18 @@ static auto PerformBuiltinFloatComparison(
return MakeBoolResult(context, bool_type_id, result);
}

// Performs a builtin boolean comparison.
static auto PerformBuiltinBoolComparison(
Context& context, SemIR::BuiltinFunctionKind builtin_kind,
SemIR::InstId lhs_id, SemIR::InstId rhs_id, SemIR::TypeId bool_type_id) {
bool lhs = context.insts().GetAs<SemIR::BoolLiteral>(lhs_id).value.ToBool();
bool rhs = context.insts().GetAs<SemIR::BoolLiteral>(rhs_id).value.ToBool();
return MakeBoolResult(context, bool_type_id,
builtin_kind == SemIR::BuiltinFunctionKind::BoolEq
? lhs == rhs
: lhs != rhs);
}

// Returns a constant for a call to a builtin function.
static auto MakeConstantForBuiltinCall(Context& context, SemIRLoc loc,
SemIR::Call call,
Expand Down Expand Up @@ -1235,6 +1247,16 @@ static auto MakeConstantForBuiltinCall(Context& context, SemIRLoc loc,
return PerformBuiltinFloatComparison(context, builtin_kind, arg_ids[0],
arg_ids[1], call.type_id);
}

// Bool comparisons.
case SemIR::BuiltinFunctionKind::BoolEq:
case SemIR::BuiltinFunctionKind::BoolNeq: {
if (phase != Phase::Template) {
break;
}
return PerformBuiltinBoolComparison(context, builtin_kind, arg_ids[0],
arg_ids[1], call.type_id);
}
}

return SemIR::ConstantId::NotConstant;
Expand Down
428 changes: 428 additions & 0 deletions toolchain/check/testdata/builtins/bool/eq.carbon

Large diffs are not rendered by default.

428 changes: 428 additions & 0 deletions toolchain/check/testdata/builtins/bool/neq.carbon

Large diffs are not rendered by default.

10 changes: 5 additions & 5 deletions toolchain/check/testdata/operators/overloaded/eq.carbon
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ fn TestLhsBad(a: D, b: C) -> bool {
// CHECK:STDOUT: }
// CHECK:STDOUT: %Core.import = import Core
// CHECK:STDOUT: %C.decl: type = class_decl @C [template = constants.%C] {} {}
// CHECK:STDOUT: impl_decl @impl [template] {} {
// CHECK:STDOUT: impl_decl @impl.1 [template] {} {
// CHECK:STDOUT: %C.ref: type = name_ref C, file.%C.decl [template = constants.%C]
// CHECK:STDOUT: %Core.ref: <namespace> = name_ref Core, imports.%Core [template = imports.%Core]
// CHECK:STDOUT: %Eq.ref: type = name_ref Eq, imports.%import_ref.1 [template = constants.%Eq.type]
Expand Down Expand Up @@ -170,7 +170,7 @@ fn TestLhsBad(a: D, b: C) -> bool {
// CHECK:STDOUT: }
// CHECK:STDOUT: }
// CHECK:STDOUT:
// CHECK:STDOUT: impl @impl: %C.ref as %Eq.ref {
// CHECK:STDOUT: impl @impl.1: %C.ref as %Eq.ref {
// CHECK:STDOUT: %Equal.decl: %Equal.type.1 = fn_decl @Equal.1 [template = constants.%Equal.1] {
// CHECK:STDOUT: %self.patt: %C = binding_pattern self
// CHECK:STDOUT: %self.param_patt: %C = value_param_pattern %self.patt, runtime_param0
Expand Down Expand Up @@ -377,7 +377,7 @@ fn TestLhsBad(a: D, b: C) -> bool {
// CHECK:STDOUT: %Core: <namespace> = namespace file.%Core.import, [template] {
// CHECK:STDOUT: .Eq = %import_ref.1
// CHECK:STDOUT: .Bool = %import_ref.7
// CHECK:STDOUT: .ImplicitAs = %import_ref.9
// CHECK:STDOUT: .ImplicitAs = %import_ref.12
// CHECK:STDOUT: import Core//prelude
// CHECK:STDOUT: import Core//prelude/...
// CHECK:STDOUT: }
Expand All @@ -395,7 +395,7 @@ fn TestLhsBad(a: D, b: C) -> bool {
// CHECK:STDOUT: %Core.import = import Core
// CHECK:STDOUT: %C.decl: type = class_decl @C [template = constants.%C] {} {}
// CHECK:STDOUT: %D.decl: type = class_decl @D [template = constants.%D] {} {}
// CHECK:STDOUT: impl_decl @impl [template] {} {
// CHECK:STDOUT: impl_decl @impl.1 [template] {} {
// CHECK:STDOUT: %C.ref: type = name_ref C, file.%C.decl [template = constants.%C]
// CHECK:STDOUT: %Core.ref: <namespace> = name_ref Core, imports.%Core [template = imports.%Core]
// CHECK:STDOUT: %Eq.ref: type = name_ref Eq, imports.%import_ref.1 [template = constants.%Eq.type]
Expand Down Expand Up @@ -442,7 +442,7 @@ fn TestLhsBad(a: D, b: C) -> bool {
// CHECK:STDOUT: }
// CHECK:STDOUT: }
// CHECK:STDOUT:
// CHECK:STDOUT: impl @impl: %C.ref as %Eq.ref {
// CHECK:STDOUT: impl @impl.1: %C.ref as %Eq.ref {
// CHECK:STDOUT: %Equal.decl: %Equal.type.1 = fn_decl @Equal.1 [template = constants.%Equal.1] {
// CHECK:STDOUT: %self.patt: %C = binding_pattern self
// CHECK:STDOUT: %self.param_patt: %C = value_param_pattern %self.patt, runtime_param0
Expand Down
4 changes: 2 additions & 2 deletions toolchain/check/testdata/operators/overloaded/ordered.carbon
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ fn TestGreaterEqual(a: D, b: D) -> bool {
// CHECK:STDOUT: }
// CHECK:STDOUT: %Core.import = import Core
// CHECK:STDOUT: %C.decl: type = class_decl @C [template = constants.%C] {} {}
// CHECK:STDOUT: impl_decl @impl [template] {} {
// CHECK:STDOUT: impl_decl @impl.1 [template] {} {
// CHECK:STDOUT: %C.ref: type = name_ref C, file.%C.decl [template = constants.%C]
// CHECK:STDOUT: %Core.ref: <namespace> = name_ref Core, imports.%Core [template = imports.%Core]
// CHECK:STDOUT: %Ordered.ref: type = name_ref Ordered, imports.%import_ref.1 [template = constants.%Ordered.type]
Expand Down Expand Up @@ -214,7 +214,7 @@ fn TestGreaterEqual(a: D, b: D) -> bool {
// CHECK:STDOUT: }
// CHECK:STDOUT: }
// CHECK:STDOUT:
// CHECK:STDOUT: impl @impl: %C.ref as %Ordered.ref {
// CHECK:STDOUT: impl @impl.1: %C.ref as %Ordered.ref {
// CHECK:STDOUT: %Less.decl: %Less.type.1 = fn_decl @Less.1 [template = constants.%Less.1] {
// CHECK:STDOUT: %self.patt: %C = binding_pattern self
// CHECK:STDOUT: %self.param_patt: %C = value_param_pattern %self.patt, runtime_param0
Expand Down
6 changes: 5 additions & 1 deletion toolchain/lower/handle_call.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,10 @@ static auto GetBuiltinICmpPredicate(SemIR::BuiltinFunctionKind builtin_kind,
-> llvm::CmpInst::Predicate {
switch (builtin_kind) {
case SemIR::BuiltinFunctionKind::IntEq:
case SemIR::BuiltinFunctionKind::BoolEq:
return llvm::CmpInst::ICMP_EQ;
case SemIR::BuiltinFunctionKind::IntNeq:
case SemIR::BuiltinFunctionKind::BoolNeq:
return llvm::CmpInst::ICMP_NE;
case SemIR::BuiltinFunctionKind::IntLess:
return is_signed ? llvm::CmpInst::ICMP_SLT : llvm::CmpInst::ICMP_ULT;
Expand Down Expand Up @@ -263,7 +265,9 @@ static auto HandleBuiltinCall(FunctionContext& context, SemIR::InstId inst_id,
case SemIR::BuiltinFunctionKind::IntLess:
case SemIR::BuiltinFunctionKind::IntLessEq:
case SemIR::BuiltinFunctionKind::IntGreater:
case SemIR::BuiltinFunctionKind::IntGreaterEq: {
case SemIR::BuiltinFunctionKind::IntGreaterEq:
case SemIR::BuiltinFunctionKind::BoolEq:
case SemIR::BuiltinFunctionKind::BoolNeq: {
context.SetLocal(
inst_id,
context.builder().CreateICmp(
Expand Down
70 changes: 70 additions & 0 deletions toolchain/lower/testdata/builtins/no_prelude/bool.carbon
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
// Part of the Carbon Language project, under the Apache License v2.0 with LLVM
// Exceptions. See /LICENSE for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
// AUTOUPDATE
// TIP: To test this file alone, run:
// TIP: bazel test //toolchain/testing:file_test --test_arg=--file_tests=toolchain/lower/testdata/builtins/no_prelude/bool.carbon
// TIP: To dump output, run:
// TIP: bazel run //toolchain/testing:file_test -- --dump_output --file_tests=toolchain/lower/testdata/builtins/no_prelude/bool.carbon

fn Bool() -> type = "bool.make_type";

fn Eq(a: Bool(), b: Bool()) -> Bool() = "bool.eq";
fn TestEq(a: Bool(), b: Bool()) -> Bool() { return Eq(a, b); }

fn Neq(a: Bool(), b: Bool()) -> Bool() = "bool.neq";
fn TestNeq(a: Bool(), b: Bool()) -> Bool() { return Neq(a, b); }

fn IfEq(a: Bool(), b: Bool()) -> Bool() {
if (Eq(a, b)) { return true; }
return false;
}

// CHECK:STDOUT: ; ModuleID = 'bool.carbon'
// CHECK:STDOUT: source_filename = "bool.carbon"
// CHECK:STDOUT:
// CHECK:STDOUT: define i1 @_CTestEq.Main(i1 %a, i1 %b) !dbg !4 {
// CHECK:STDOUT: entry:
// CHECK:STDOUT: %bool.eq = icmp eq i1 %a, %b, !dbg !7
// CHECK:STDOUT: ret i1 %bool.eq, !dbg !8
// CHECK:STDOUT: }
// CHECK:STDOUT:
// CHECK:STDOUT: define i1 @_CTestNeq.Main(i1 %a, i1 %b) !dbg !9 {
// CHECK:STDOUT: entry:
// CHECK:STDOUT: %bool.neq = icmp ne i1 %a, %b, !dbg !10
// CHECK:STDOUT: ret i1 %bool.neq, !dbg !11
// CHECK:STDOUT: }
// CHECK:STDOUT:
// CHECK:STDOUT: define i1 @_CIfEq.Main(i1 %a, i1 %b) !dbg !12 {
// CHECK:STDOUT: entry:
// CHECK:STDOUT: %bool.eq = icmp eq i1 %a, %b, !dbg !13
// CHECK:STDOUT: br i1 %bool.eq, label %if.then, label %if.else, !dbg !14
// CHECK:STDOUT:
// CHECK:STDOUT: if.then: ; preds = %entry
// CHECK:STDOUT: ret i1 true, !dbg !15
// CHECK:STDOUT:
// CHECK:STDOUT: if.else: ; preds = %entry
// CHECK:STDOUT: ret i1 false, !dbg !16
// CHECK:STDOUT: }
// CHECK:STDOUT:
// CHECK:STDOUT: !llvm.module.flags = !{!0, !1}
// CHECK:STDOUT: !llvm.dbg.cu = !{!2}
// CHECK:STDOUT:
// CHECK:STDOUT: !0 = !{i32 7, !"Dwarf Version", i32 5}
// CHECK:STDOUT: !1 = !{i32 2, !"Debug Info Version", i32 3}
// CHECK:STDOUT: !2 = distinct !DICompileUnit(language: DW_LANG_C, file: !3, producer: "carbon", isOptimized: false, runtimeVersion: 0, emissionKind: FullDebug)
// CHECK:STDOUT: !3 = !DIFile(filename: "bool.carbon", directory: "")
// CHECK:STDOUT: !4 = distinct !DISubprogram(name: "TestEq", linkageName: "_CTestEq.Main", scope: null, file: !3, line: 14, type: !5, spFlags: DISPFlagDefinition, unit: !2)
// CHECK:STDOUT: !5 = !DISubroutineType(types: !6)
// CHECK:STDOUT: !6 = !{}
// CHECK:STDOUT: !7 = !DILocation(line: 14, column: 52, scope: !4)
// CHECK:STDOUT: !8 = !DILocation(line: 14, column: 45, scope: !4)
// CHECK:STDOUT: !9 = distinct !DISubprogram(name: "TestNeq", linkageName: "_CTestNeq.Main", scope: null, file: !3, line: 17, type: !5, spFlags: DISPFlagDefinition, unit: !2)
// CHECK:STDOUT: !10 = !DILocation(line: 17, column: 53, scope: !9)
// CHECK:STDOUT: !11 = !DILocation(line: 17, column: 46, scope: !9)
// CHECK:STDOUT: !12 = distinct !DISubprogram(name: "IfEq", linkageName: "_CIfEq.Main", scope: null, file: !3, line: 19, type: !5, spFlags: DISPFlagDefinition, unit: !2)
// CHECK:STDOUT: !13 = !DILocation(line: 20, column: 7, scope: !12)
// CHECK:STDOUT: !14 = !DILocation(line: 20, column: 6, scope: !12)
// CHECK:STDOUT: !15 = !DILocation(line: 20, column: 19, scope: !12)
// CHECK:STDOUT: !16 = !DILocation(line: 21, column: 3, scope: !12)
8 changes: 8 additions & 0 deletions toolchain/sem_ir/builtin_function_kind.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -372,6 +372,14 @@ constexpr BuiltinInfo FloatGreater = {
constexpr BuiltinInfo FloatGreaterEq = {
"float.greater_eq", ValidateSignature<auto(FloatT, FloatT)->Bool>};

// "bool.eq": bool equality comparison.
constexpr BuiltinInfo BoolEq = {"bool.eq",
ValidateSignature<auto(Bool, Bool)->Bool>};

// "bool.neq": bool non-equality comparison.
constexpr BuiltinInfo BoolNeq = {"bool.neq",
ValidateSignature<auto(Bool, Bool)->Bool>};

} // namespace BuiltinFunctionInfo

CARBON_DEFINE_ENUM_CLASS_NAMES(BuiltinFunctionKind) = {
Expand Down
4 changes: 4 additions & 0 deletions toolchain/sem_ir/builtin_function_kind.def
Original file line number Diff line number Diff line change
Expand Up @@ -78,4 +78,8 @@ CARBON_SEM_IR_BUILTIN_FUNCTION_KIND(FloatLessEq)
CARBON_SEM_IR_BUILTIN_FUNCTION_KIND(FloatGreater)
CARBON_SEM_IR_BUILTIN_FUNCTION_KIND(FloatGreaterEq)

// Bool comparison.
CARBON_SEM_IR_BUILTIN_FUNCTION_KIND(BoolEq)
CARBON_SEM_IR_BUILTIN_FUNCTION_KIND(BoolNeq)

#undef CARBON_SEM_IR_BUILTIN_FUNCTION_KIND

0 comments on commit c1590f8

Please sign in to comment.