Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
1) Flag the use of `sum(&)` with non-additive types at run-time.
2) Flag uses of `sum()` with strings with a compile-time warning.
   Eventually, change this warning into an error.
  • Loading branch information
rvprasad committed Jan 11, 2025
1 parent b499bb2 commit 235f1bf
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 37 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ deref_symlinks ?= ## Deference symbolic links for `make install`
O := .build
SOURCES := $(shell find src -name '*.cr')
SPEC_SOURCES := $(shell find spec -name '*.cr')
override FLAGS += --error-trace -D strict_multi_assign -D preview_overload_order $(if $(release),--release )$(if $(stats),--stats )$(if $(progress),--progress )$(if $(threads),--threads $(threads) )$(if $(debug),-d )$(if $(static),--static )$(if $(LDFLAGS),--link-flags="$(LDFLAGS)" )$(if $(target),--cross-compile --target $(target) )$(if $(interpreter),,-Dwithout_interpreter )
override FLAGS += -D strict_multi_assign -D preview_overload_order $(if $(release),--release )$(if $(stats),--stats )$(if $(progress),--progress )$(if $(threads),--threads $(threads) )$(if $(debug),-d )$(if $(static),--static )$(if $(LDFLAGS),--link-flags="$(LDFLAGS)" )$(if $(target),--cross-compile --target $(target) )$(if $(interpreter),,-Dwithout_interpreter )
SPEC_WARNINGS_OFF := --exclude-warnings spec/std --exclude-warnings spec/compiler --exclude-warnings spec/primitives
override SPEC_FLAGS += $(if $(verbose),-v )$(if $(junit_output),--junit_output $(junit_output) )$(if $(order),--order=$(order) )
CRYSTAL_CONFIG_LIBRARY_PATH := '$$ORIGIN/../lib/crystal'
Expand Down
35 changes: 14 additions & 21 deletions spec/std/enumerable_spec.cr
Original file line number Diff line number Diff line change
Expand Up @@ -1369,7 +1369,7 @@ describe "Enumerable" do
it { [1, 3].sum(0_u64).should eq(4_u64) }
it { [1, 10000000000_u64].sum(0_u64).should eq(10000000001) }
it "raises if union types are summed", tags: %w[slow] do
exc = assert_error <<-CRYSTAL,
assert_error <<-CRYSTAL,
require "prelude"
[1, 10000000000_u64].sum
CRYSTAL
Expand All @@ -1384,28 +1384,21 @@ describe "Enumerable" do
it { [1, 2].sum { |x| 2_u64 * x }.should eq(6_u64) }
it { [1, 2].sum(1) { |x| 2_u64 * x }.should eq(7_i32) }

it { {"a", "b", 3}.sum(&.to_s).should eq("ab3") }
it { ["a", "b", 3].sum(&.to_s).should eq("ab3") }
it { [1, 2, 3].sum(&.to_s).should eq("123") }
it { ["a", "b", "c"].sum(&.length).should eq(3) }
it "raises if enumerable of different types are summed", tags: %w[slow] do
exc = assert_error <<-CRYSTAL,
it { ["a", "b", "c"].sum(&.size).should eq(3) }

it "raises if non-additive values are summed", tags: %w[slow] do
assert_warning <<-CRYSTAL,
require "prelude"
["a", "b", 3].sum()
["a", "b", "c"].sum
CRYSTAL
"`Enumerable#sum()` and `#product()` do not support heterogeneous " +
"enumerables. Instead, use `Enumerable#sum(&)` and `#product(&)`, " +
"respectively, with an appropriate transformation block."
"`Enumerable#sum` does not support non-additive types. " +
"To join an enumerable of strings, use `Enumerable#join`."
end

it "raises if enumerable of union types are summed", tags: %w[slow] do
exc = assert_error <<-CRYSTAL,
require "prelude"
["a", "b", 3].sum("")
CRYSTAL
"`Enumerable#sum()` and `#product()` do not support heterogeneous " +
"enumerables. Instead, use `Enumerable#sum(&)` and `#product(&)`, " +
"respectively, with an appropriate transformation block."
it "raises if non-additive values are summed", tags: %w[slow] do
expect_raises ArgumentError, "`Enumerable#sum` does not support non-additive types." do
[1, 2, 3].sum(&.to_s)
end
end

it "uses additive_identity from type" do
Expand Down Expand Up @@ -1453,7 +1446,7 @@ describe "Enumerable" do
it { [1, 3].product(3_u64).should eq(9_u64) }
it { [1, 10000000000_u64].product(3_u64).should eq(30000000000_u64) }
it "raises if union types are multiplied", tags: %w[slow] do
exc = assert_error <<-CRYSTAL,
assert_error <<-CRYSTAL,
require "prelude"
[1, 10000000000_u64].product
CRYSTAL
Expand All @@ -1463,7 +1456,7 @@ describe "Enumerable" do
"of the intended type of the call."
end
it { [1, 2].product { |x| 2_u64 * x }.should eq(8_u64) }
it { [1, 2].product(2) { |x| 2_u64 * x}.should eq(16_i32) }
it { [1, 2].product(2) { |x| 2_u64 * x }.should eq(16_i32) }
it { [1, 2].product { |x| 2_u64 * x }.should eq(8_u64) }
it { [1, 2].product(2) { |x| 2_u64 * x }.should eq(16_i32) }
end
Expand Down
24 changes: 9 additions & 15 deletions src/enumerable.cr
Original file line number Diff line number Diff line change
Expand Up @@ -1757,7 +1757,10 @@ module Enumerable(T)
# ```
def sum
{% if T == String %}
# optimize for string
{{
warning("`Enumerable#sum` does not support non-additive types. " +
"To join an enumerable of strings, use `Enumerable#join`.")
}}
join
{% elsif T < Array %}
# optimize for array
Expand All @@ -1775,7 +1778,8 @@ module Enumerable(T)
if type.responds_to? :additive_identity
type.additive_identity
else
type.zero
raise ArgumentError.new("`Enumerable#sum` does not support " +
"non-additive types.")
end
end

Expand Down Expand Up @@ -1818,17 +1822,8 @@ module Enumerable(T)
# ```
# ([] of Int32).sum { |x| x + 1 } # => 0
# ```
def sum(& : T ->)
reflect = Reflect(typeof(yield Enumerable.element_type(self)))
if reflect.type == String
sum("") do |value|
yield value
end
else
sum(additive_identity(reflect)) do |value|
yield value
end
end
def sum(&block : T -> _)
sum(additive_identity(Reflect(typeof(yield Enumerable.element_type(self)))), &block)
end

# Adds *initial* and all results of the passed block for each element in the collection.
Expand All @@ -1844,7 +1839,7 @@ module Enumerable(T)
# ```
# ([] of String).sum(1) { |name| name.size } # => 1
# ```
def sum(initial, & : T ->)
def sum(initial, & : T -> _)
reduce(initial) { |memo, e| memo + (yield e) }
end

Expand Down Expand Up @@ -2303,7 +2298,6 @@ module Enumerable(T)
# For now, Reflect is used to reject union types in `#sum()` and
# `#product()` methods.
def self.type
{{ p!(X) }}
{% if X.union? %}
{{
raise("`Enumerable#sum()` and `#product()` do not support Union " +
Expand Down

0 comments on commit 235f1bf

Please sign in to comment.