Skip to content

Commit

Permalink
Add LOAD_METHOD_MODULE
Browse files Browse the repository at this point in the history
This speeds up things like `collections.deque()` where `collections` is
a module.

Add tests.
  • Loading branch information
tekknolagi committed May 31, 2023
1 parent 4208ecd commit 9558882
Show file tree
Hide file tree
Showing 6 changed files with 360 additions and 16 deletions.
3 changes: 2 additions & 1 deletion runtime/bytecode.h
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ namespace py {
V(UNUSED_BYTECODE_171, 171, doInvalidBytecode) \
V(UNUSED_BYTECODE_172, 172, doInvalidBytecode) \
V(UNUSED_BYTECODE_173, 173, doInvalidBytecode) \
V(UNUSED_BYTECODE_174, 174, doInvalidBytecode) \
V(LOAD_METHOD_MODULE, 174, doLoadMethodModule) \
V(CALL_FUNCTION_TYPE_INIT, 175, doCallFunctionTypeInit) \
V(CALL_FUNCTION_TYPE_NEW, 176, doCallFunctionTypeNew) \
V(CALL_FUNCTION_ANAMORPHIC, 177, doCallFunctionAnamorphic) \
Expand Down Expand Up @@ -393,6 +393,7 @@ inline bool isByteCodeWithCache(const Bytecode bc) {
case LOAD_ATTR_INSTANCE_TYPE_BOUND_METHOD:
case LOAD_ATTR_INSTANCE_TYPE_DESCR:
case LOAD_ATTR_MODULE:
case LOAD_METHOD_MODULE:
case LOAD_ATTR_TYPE:
case LOAD_ATTR_ANAMORPHIC:
case LOAD_METHOD_ANAMORPHIC:
Expand Down
28 changes: 28 additions & 0 deletions runtime/ic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,25 @@ void icUpdateAttrModule(Thread* thread, const MutableTuple& caches, word cache,
icInsertDependentToValueCellDependencyLink(thread, dependent, value_cell);
}

void icUpdateMethodModule(Thread* thread, const MutableTuple& caches,
word cache, const Object& receiver,
const ValueCell& value_cell,
const Function& dependent) {
DCHECK(icIsCacheEmpty(caches, cache), "cache must be empty\n");
HandleScope scope(thread);
word index = cache * kIcPointersPerEntry;
Module module(&scope, *receiver);
caches.atPut(index + kIcEntryKeyOffset, SmallInt::fromWord(module.id()));
caches.atPut(index + kIcEntryValueOffset, *value_cell);
RawMutableBytes bytecode =
RawMutableBytes::cast(dependent.rewrittenBytecode());
word pc = thread->currentFrame()->virtualPC() - kCodeUnitSize;
DCHECK(bytecode.byteAt(pc) == LOAD_METHOD_ANAMORPHIC,
"current opcode must be LOAD_METHOD_ANAMORPHIC");
bytecode.byteAtPut(pc, LOAD_METHOD_MODULE);
icInsertDependentToValueCellDependencyLink(thread, dependent, value_cell);
}

void icUpdateAttrType(Thread* thread, const MutableTuple& caches, word cache,
const Object& receiver, const Object& selector,
const Object& value, const Function& dependent) {
Expand Down Expand Up @@ -822,6 +841,15 @@ void icInvalidateGlobalVar(Thread* thread, const ValueCell& value_cell) {
}
break;
}
case LOAD_METHOD_MODULE: {
original_bc = LOAD_METHOD_ANAMORPHIC;
word index = op.cache * kIcPointersPerEntry;
if (caches.at(index + kIcEntryValueOffset) == *value_cell) {
caches.atPut(index + kIcEntryKeyOffset, NoneType::object());
caches.atPut(index + kIcEntryValueOffset, NoneType::object());
}
break;
}
case LOAD_GLOBAL_CACHED:
original_bc = LOAD_GLOBAL;
if (op.bc != original_bc && op.arg == name_index_found) {
Expand Down
10 changes: 9 additions & 1 deletion runtime/ic.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,11 @@ void icUpdateAttrModule(Thread* thread, const MutableTuple& caches, word cache,
const Object& receiver, const ValueCell& value_cell,
const Function& dependent);

void icUpdateMethodModule(Thread* thread, const MutableTuple& caches,
word cache, const Object& receiver,
const ValueCell& value_cell,
const Function& dependent);

void icUpdateAttrType(Thread* thread, const MutableTuple& caches, word cache,
const Object& receiver, const Object& selector,
const Object& value, const Function& dependent);
Expand Down Expand Up @@ -361,7 +366,10 @@ class IcIterator {
}
}

bool isModuleAttrCache() const { return bytecode_op_.bc == LOAD_ATTR_MODULE; }
bool isModuleAttrCache() const {
return bytecode_op_.bc == LOAD_ATTR_MODULE ||
bytecode_op_.bc == LOAD_METHOD_MODULE;
}

bool isBinaryOpCache() const {
switch (bytecode_op_.bc) {
Expand Down
232 changes: 232 additions & 0 deletions runtime/interpreter-test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6737,6 +6737,238 @@ c = C()
EXPECT_TRUE(isIntEqualsWord(Interpreter::call0(thread_, test_function), 70));
}

TEST_F(InterpreterTest, LoadMethodCachedModuleFunction) {
EXPECT_FALSE(runFromCStr(runtime_, R"(
import sys
class C:
def getdefaultencoding(self):
return "no-utf8"
def test(obj):
return obj.getdefaultencoding()
cached = sys.getdefaultencoding
obj = C()
)")
.isError());
HandleScope scope(thread_);
Function test_function(&scope, mainModuleAt(runtime_, "test"));
Function expected_value(&scope, mainModuleAt(runtime_, "cached"));
MutableBytes bytecode(&scope, test_function.rewrittenBytecode());
ASSERT_EQ(rewrittenBytecodeOpAt(bytecode, 1), LOAD_METHOD_ANAMORPHIC);
ASSERT_EQ(rewrittenBytecodeOpAt(bytecode, 2), CALL_METHOD);

// Cache miss.
Module sys_module(&scope, runtime_->findModuleById(ID(sys)));
MutableTuple caches(&scope, test_function.caches());
word cache_index =
rewrittenBytecodeCacheAt(bytecode, 1) * kIcPointersPerEntry;
Object key(&scope, caches.at(cache_index + kIcEntryKeyOffset));
EXPECT_EQ(*key, NoneType::object());

// Call.
EXPECT_TRUE(isStrEqualsCStr(
Interpreter::call1(thread_, test_function, sys_module), "utf-8"));
EXPECT_EQ(rewrittenBytecodeOpAt(bytecode, 1), LOAD_METHOD_MODULE);

// Cache hit.
key = caches.at(cache_index + kIcEntryKeyOffset);
EXPECT_TRUE(isIntEqualsWord(*key, sys_module.id()));
Object value(&scope, caches.at(cache_index + kIcEntryValueOffset));
ASSERT_TRUE(value.isValueCell());
EXPECT_EQ(ValueCell::cast(*value).value(), *expected_value);

// Call.
EXPECT_TRUE(isStrEqualsCStr(
Interpreter::call1(thread_, test_function, sys_module), "utf-8"));

// Rewrite.
Object obj(&scope, mainModuleAt(runtime_, "obj"));
EXPECT_TRUE(isStrEqualsCStr(Interpreter::call1(thread_, test_function, obj),
"no-utf8"));
EXPECT_EQ(rewrittenBytecodeOpAt(bytecode, 1), LOAD_METHOD_INSTANCE_FUNCTION);
key = caches.at(cache_index + kIcEntryKeyOffset);
EXPECT_FALSE(key.isValueCell());
}

TEST_F(InterpreterTest,
LoadMethodWithModuleAndNonFunctionRewritesToLoadMethodModule) {
EXPECT_FALSE(runFromCStr(runtime_, R"(
import sys
class C:
def __call__(self):
return 123
mymodule = type(sys)("mymodule")
mymodule.getdefaultencoding = C()
def test(obj):
return obj.getdefaultencoding()
)")
.isError());
HandleScope scope(thread_);
Function test_function(&scope, mainModuleAt(runtime_, "test"));
MutableBytes bytecode(&scope, test_function.rewrittenBytecode());
Module mymodule(&scope, mainModuleAt(runtime_, "mymodule"));
ASSERT_EQ(rewrittenBytecodeOpAt(bytecode, 1), LOAD_METHOD_ANAMORPHIC);
ASSERT_EQ(rewrittenBytecodeOpAt(bytecode, 2), CALL_METHOD);

// Cache miss.
MutableTuple caches(&scope, test_function.caches());
word cache_index =
rewrittenBytecodeCacheAt(bytecode, 1) * kIcPointersPerEntry;
Object key(&scope, caches.at(cache_index + kIcEntryKeyOffset));
EXPECT_EQ(*key, NoneType::object());

// Call.
EXPECT_TRUE(isIntEqualsWord(
Interpreter::call1(thread_, test_function, mymodule), 123));
EXPECT_EQ(rewrittenBytecodeOpAt(bytecode, 1), LOAD_METHOD_MODULE);
}

TEST_F(InterpreterTest, LoadMethodModuleGetsEvicted) {
EXPECT_FALSE(runFromCStr(runtime_, R"(
import sys
def test(obj):
return obj.getdefaultencoding()
)")
.isError());
HandleScope scope(thread_);
Function test_function(&scope, mainModuleAt(runtime_, "test"));
MutableBytes bytecode(&scope, test_function.rewrittenBytecode());
ASSERT_EQ(rewrittenBytecodeOpAt(bytecode, 1), LOAD_METHOD_ANAMORPHIC);
ASSERT_EQ(rewrittenBytecodeOpAt(bytecode, 2), CALL_METHOD);

// Cache miss.
Module sys_module(&scope, runtime_->findModuleById(ID(sys)));
MutableTuple caches(&scope, test_function.caches());
word cache_index =
rewrittenBytecodeCacheAt(bytecode, 1) * kIcPointersPerEntry;
Object key(&scope, caches.at(cache_index + kIcEntryKeyOffset));
EXPECT_EQ(*key, NoneType::object());

// Call.
EXPECT_TRUE(isStrEqualsCStr(
Interpreter::call1(thread_, test_function, sys_module), "utf-8"));
EXPECT_EQ(rewrittenBytecodeOpAt(bytecode, 1), LOAD_METHOD_MODULE);

// Update module.
Str getdefaultencoding(
&scope, runtime_->internStrFromCStr(thread_, "getdefaultencoding"));
Object result(&scope,
moduleDeleteAttribute(thread_, sys_module, getdefaultencoding));
ASSERT_TRUE(result.isNoneType());

// Cache is empty.
key = caches.at(cache_index + kIcEntryKeyOffset);
EXPECT_TRUE(key.isNoneType());

// Cache miss.
EXPECT_TRUE(
raisedWithStr(Interpreter::call1(thread_, test_function, sys_module),
LayoutId::kAttributeError,
"module 'sys' has no attribute 'getdefaultencoding'"));

// Bytecode gets rewritten after next call.
EXPECT_EQ(rewrittenBytecodeOpAt(bytecode, 1), LOAD_METHOD_ANAMORPHIC);
}

TEST_F(InterpreterTest, LoadMethodModuleWithModuleMismatchUpdatesCache) {
EXPECT_FALSE(runFromCStr(runtime_, R"(
import sys
mymodule = type(sys)("mymodule")
mymodule.getdefaultencoding = lambda: "hello"
def test(obj):
return obj.getdefaultencoding()
)")
.isError());
HandleScope scope(thread_);
Function test_function(&scope, mainModuleAt(runtime_, "test"));
Module mymodule(&scope, mainModuleAt(runtime_, "mymodule"));
MutableBytes bytecode(&scope, test_function.rewrittenBytecode());
ASSERT_EQ(rewrittenBytecodeOpAt(bytecode, 1), LOAD_METHOD_ANAMORPHIC);
ASSERT_EQ(rewrittenBytecodeOpAt(bytecode, 2), CALL_METHOD);

// Cache miss.
Module sys_module(&scope, runtime_->findModuleById(ID(sys)));
MutableTuple caches(&scope, test_function.caches());
word cache_index =
rewrittenBytecodeCacheAt(bytecode, 1) * kIcPointersPerEntry;
Object key(&scope, caches.at(cache_index + kIcEntryKeyOffset));
EXPECT_EQ(*key, NoneType::object());

// Call.
EXPECT_TRUE(isStrEqualsCStr(
Interpreter::call1(thread_, test_function, sys_module), "utf-8"));
EXPECT_EQ(rewrittenBytecodeOpAt(bytecode, 1), LOAD_METHOD_MODULE);

// Cache contains sys.
key = caches.at(cache_index + kIcEntryKeyOffset);
EXPECT_TRUE(isIntEqualsWord(*key, sys_module.id()));

// Call.
EXPECT_TRUE(isStrEqualsCStr(
Interpreter::call1(thread_, test_function, mymodule), "hello"));
EXPECT_EQ(rewrittenBytecodeOpAt(bytecode, 1), LOAD_METHOD_MODULE);

// Cache contains mymodule.
key = caches.at(cache_index + kIcEntryKeyOffset);
EXPECT_TRUE(isIntEqualsWord(*key, mymodule.id()));
}

TEST_F(InterpreterTest, LoadMethodModuleGetsScannedInOtherEviction) {
EXPECT_FALSE(runFromCStr(runtime_, R"(
import sys
class C:
def __init__(self):
self.foo = 123
c = C()
def test(obj):
c.foo
return obj.getdefaultencoding()
def invalidate():
C.foo = property(lambda self: 456)
)")
.isError());
HandleScope scope(thread_);
Function test_function(&scope, mainModuleAt(runtime_, "test"));
Function invalidate(&scope, mainModuleAt(runtime_, "invalidate"));
MutableBytes bytecode(&scope, test_function.rewrittenBytecode());
ASSERT_EQ(rewrittenBytecodeOpAt(bytecode, 4), LOAD_METHOD_ANAMORPHIC);
ASSERT_EQ(rewrittenBytecodeOpAt(bytecode, 5), CALL_METHOD);

// Cache miss.
Module sys_module(&scope, runtime_->findModuleById(ID(sys)));
MutableTuple caches(&scope, test_function.caches());
word cache_index =
rewrittenBytecodeCacheAt(bytecode, 4) * kIcPointersPerEntry;
Object key(&scope, caches.at(cache_index + kIcEntryKeyOffset));
EXPECT_EQ(*key, NoneType::object());

// Call.
EXPECT_TRUE(isStrEqualsCStr(
Interpreter::call1(thread_, test_function, sys_module), "utf-8"));
EXPECT_EQ(rewrittenBytecodeOpAt(bytecode, 4), LOAD_METHOD_MODULE);

// Evict the caches in the `test' function.
ASSERT_TRUE(Interpreter::call0(thread_, invalidate).isNoneType());

// The LOAD_METHOD_MODULE is not affected.
EXPECT_EQ(rewrittenBytecodeOpAt(bytecode, 4), LOAD_METHOD_MODULE);
EXPECT_TRUE(isStrEqualsCStr(
Interpreter::call1(thread_, test_function, sys_module), "utf-8"));
EXPECT_EQ(rewrittenBytecodeOpAt(bytecode, 4), LOAD_METHOD_MODULE);
}

TEST_F(InterpreterTest, LoadMethodCachedDoesNotCacheProperty) {
HandleScope scope(thread_);
EXPECT_FALSE(runFromCStr(runtime_, R"(
Expand Down
Loading

0 comments on commit 9558882

Please sign in to comment.