diff --git a/lib/mongoid/association/eager_loadable.rb b/lib/mongoid/association/eager_loadable.rb index f57d2489ee..116b4b36dd 100644 --- a/lib/mongoid/association/eager_loadable.rb +++ b/lib/mongoid/association/eager_loadable.rb @@ -42,6 +42,9 @@ def preload(associations, docs) docs_map = {} queue = [ klass.to_s ] + # account for single-collection inheritance + queue.push(klass.root_class.to_s) if klass != klass.root_class + while klass = queue.shift if as = assoc_map.delete(klass) as.each do |assoc| diff --git a/lib/mongoid/traversable.rb b/lib/mongoid/traversable.rb index 349305640e..584340de19 100644 --- a/lib/mongoid/traversable.rb +++ b/lib/mongoid/traversable.rb @@ -44,6 +44,18 @@ def hereditary? !!(superclass < Mongoid::Document) end + # Returns the root class of the STI tree that the current + # class participates in. If the class is not an STI subclass, this + # returns the class itself. + # + # @return [ Mongoid::Document ] the root of the STI tree + def root_class + root = self + root = root.superclass while root.hereditary? + + root + end + # When inheriting, we want to copy the fields from the parent class and # set the on the child to start, mimicking the behavior of the old # class_inheritable_accessor that was deprecated in Rails edge. diff --git a/spec/mongoid/association/eager_spec.rb b/spec/mongoid/association/eager_spec.rb index 1706b0c42f..24889c423e 100644 --- a/spec/mongoid/association/eager_spec.rb +++ b/spec/mongoid/association/eager_spec.rb @@ -15,14 +15,36 @@ Mongoid::Contextual::Mongo.new(criteria) end + let(:association_host) { Account } + let(:inclusions) do includes.map do |key| - Account.reflect_on_association(key) + association_host.reflect_on_association(key) end end let(:doc) { criteria.first } + context 'when root is an STI subclass' do + # Driver has_one Vehicle + # Vehicle belongs_to Driver + # Truck is a Vehicle + + before do + Driver.create!(vehicle: Truck.new) + end + + let(:criteria) { Truck.all } + let(:includes) { %i[ driver ] } + let(:association_host) { Truck } + + it 'preloads the driver' do + expect(doc.ivar(:driver)).to be false + context.preload(inclusions, [ doc ]) + expect(doc.ivar(:driver)).to be == Driver.first + end + end + context "when belongs_to" do let!(:account) do @@ -43,7 +65,7 @@ it "preloads the parent" do expect(doc.ivar(:person)).to be false context.preload(inclusions, [doc]) - expect(doc.ivar(:person)).to eq(doc.person) + expect(doc.ivar(:person)).to be == person end end