Skip to content

[GR-17457] We no longer need to look into a Refinement's ancestors for method lookup #3561

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Apr 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
178 changes: 55 additions & 123 deletions src/main/java/org/truffleruby/core/module/ModuleOperations.java
Original file line number Diff line number Diff line change
Expand Up @@ -150,8 +150,7 @@ public static ConstantLookupResult lookupConstant(RubyContext context, RubyModul
private static ConstantLookupResult lookupConstant(RubyContext context, RubyModule module, String name,
ArrayList<Assumption> assumptions) {
// Look in the current module
ModuleFields fields = module.fields;
ConstantEntry constantEntry = fields.getOrComputeConstantEntry(name);
ConstantEntry constantEntry = module.fields.getOrComputeConstantEntry(name);
assumptions.add(constantEntry.getAssumption());
if (constantExists(constantEntry, assumptions)) {
return new ConstantLookupResult(constantEntry.getConstant(), toArray(assumptions));
Expand All @@ -162,8 +161,7 @@ private static ConstantLookupResult lookupConstant(RubyContext context, RubyModu
if (ancestor == module) {
continue;
}
fields = ancestor.fields;
constantEntry = fields.getOrComputeConstantEntry(name);
constantEntry = ancestor.fields.getOrComputeConstantEntry(name);
assumptions.add(constantEntry.getAssumption());
if (constantExists(constantEntry, assumptions)) {
return new ConstantLookupResult(constantEntry.getConstant(), toArray(assumptions));
Expand All @@ -179,16 +177,14 @@ public static ConstantLookupResult lookupConstantInObject(RubyContext context, S
ArrayList<Assumption> assumptions) {
final RubyClass objectClass = context.getCoreLibrary().objectClass;

ModuleFields fields = objectClass.fields;
ConstantEntry constantEntry = fields.getOrComputeConstantEntry(name);
ConstantEntry constantEntry = objectClass.fields.getOrComputeConstantEntry(name);
assumptions.add(constantEntry.getAssumption());
if (constantExists(constantEntry, assumptions)) {
return new ConstantLookupResult(constantEntry.getConstant(), toArray(assumptions));
}

for (RubyModule ancestor : objectClass.fields.prependedAndIncludedModules()) {
fields = ancestor.fields;
constantEntry = fields.getOrComputeConstantEntry(name);
constantEntry = ancestor.fields.getOrComputeConstantEntry(name);
assumptions.add(constantEntry.getAssumption());
if (constantExists(constantEntry, assumptions)) {
return new ConstantLookupResult(constantEntry.getConstant(), toArray(assumptions));
Expand All @@ -202,15 +198,13 @@ public static ConstantLookupResult lookupConstantInObject(RubyContext context, S
public static RubyConstant lookupConstantInObjectUncached(RubyContext context, String name) {
final RubyClass objectClass = context.getCoreLibrary().objectClass;

ModuleFields fields = objectClass.fields;
RubyConstant constant = fields.getConstant(name);
RubyConstant constant = objectClass.fields.getConstant(name);
if (constantExists(constant, null)) {
return constant;
}

for (RubyModule ancestor : objectClass.fields.prependedAndIncludedModules()) {
fields = ancestor.fields;
constant = fields.getConstant(name);
constant = ancestor.fields.getConstant(name);
if (constantExists(constant, null)) {
return constant;
}
Expand Down Expand Up @@ -243,8 +237,7 @@ public static ConstantLookupResult lookupConstantWithLexicalScope(RubyContext co

// Look in lexical scope
while (lexicalScope != context.getRootLexicalScope()) {
final ModuleFields fields = lexicalScope.getLiveModule().fields;
final ConstantEntry constantEntry = fields.getOrComputeConstantEntry(name);
final ConstantEntry constantEntry = lexicalScope.getLiveModule().fields.getOrComputeConstantEntry(name);
assumptions.add(constantEntry.getAssumption());
if (constantExists(constantEntry, assumptions)) {
return new ConstantLookupResult(constantEntry.getConstant(), toArray(assumptions));
Expand Down Expand Up @@ -328,8 +321,7 @@ public static ConstantLookupResult lookupConstantWithInherit(RubyContext context
return ModuleOperations.lookupConstant(context, module, name, assumptions);
}
} else {
final ModuleFields fields = module.fields;
final ConstantEntry constantEntry = fields.getOrComputeConstantEntry(name);
final ConstantEntry constantEntry = module.fields.getOrComputeConstantEntry(name);
assumptions.add(constantEntry.getAssumption());
if (constantExists(constantEntry, assumptions)) {
return new ConstantLookupResult(constantEntry.getConstant(), toArray(assumptions));
Expand Down Expand Up @@ -426,49 +418,25 @@ public static Map<String, InternalMethod> withoutUndefinedMethods(Map<String, In
return definedMethods;
}

public static MethodLookupResult lookupMethodCached(RubyModule module, String name,
DeclarationContext declarationContext) {
return lookupMethodCached(module, null, name, declarationContext);
}

@TruffleBoundary
private static MethodLookupResult lookupMethodCached(RubyModule module, RubyModule lookupTo, String name,
public static MethodLookupResult lookupMethodCached(RubyModule module, String name,
DeclarationContext declarationContext) {
final ArrayList<Assumption> assumptions = new ArrayList<>();
var assumptions = new ArrayList<Assumption>();

// Look in ancestors
for (RubyModule ancestor : module.fields.ancestors()) {
if (ancestor == lookupTo) {
return new MethodLookupResult(null, toArray(assumptions));
}
final RubyModule[] refinements = getRefinementsFor(declarationContext, ancestor);

var refinements = getRefinementsFor(declarationContext, ancestor);
if (refinements != null) {
for (RubyModule refinement : refinements) {
// If we have more then one active refinement for C (where C is refined module):
// R1.ancestors = [R1, A, C, ...]
// R2.ancestors = [R2, B, C, ...]
// R3.ancestors = [R3, D, C, ...]
// we are only looking up to C
// R3 -> D -> R2 -> B -> R1 -> A
final MethodLookupResult refinedMethod = lookupMethodCached(
refinement,
ancestor,
name,
null);
for (Assumption assumption : refinedMethod.getAssumptions()) {
assumptions.add(assumption);
}
if (refinedMethod.isDefined()) {
InternalMethod method = rememberUsedRefinements(refinedMethod.getMethod(), declarationContext);
var refinedMethod = refinement.fields.getMethodAndAssumption(name, assumptions);
if (refinedMethod != null) {
InternalMethod method = rememberUsedRefinements(refinedMethod, declarationContext);
return new MethodLookupResult(method, toArray(assumptions));
}
}
}

final ModuleFields fields = ancestor.fields;
final InternalMethod method = fields.getMethodAndAssumption(name, assumptions);

var method = ancestor.fields.getMethodAndAssumption(name, assumptions);
if (method != null) {
return new MethodLookupResult(method, toArray(assumptions));
}
Expand All @@ -478,37 +446,22 @@ private static MethodLookupResult lookupMethodCached(RubyModule module, RubyModu
return new MethodLookupResult(null, toArray(assumptions));
}

public static InternalMethod lookupMethodUncached(RubyModule module, String name,
DeclarationContext declarationContext) {
return lookupMethodUncached(module, null, name, declarationContext);
}

@TruffleBoundary
private static InternalMethod lookupMethodUncached(RubyModule module, RubyModule lookupTo, String name,
public static InternalMethod lookupMethodUncached(RubyModule module, String name,
DeclarationContext declarationContext) {

// Look in ancestors
for (RubyModule ancestor : module.fields.ancestors()) {
if (ancestor == lookupTo) {
return null;
}
final RubyModule[] refinements = getRefinementsFor(declarationContext, ancestor);

var refinements = getRefinementsFor(declarationContext, ancestor);
if (refinements != null) {
for (RubyModule refinement : refinements) {
final InternalMethod refinedMethod = lookupMethodUncached(
refinement,
ancestor,
name,
null);
var refinedMethod = refinement.fields.getMethod(name);
if (refinedMethod != null) {
return rememberUsedRefinements(refinedMethod, declarationContext);
}
}
}

final ModuleFields fields = ancestor.fields;
final InternalMethod method = fields.getMethod(name);

var method = ancestor.fields.getMethod(name);
if (method != null) {
return method;
}
Expand All @@ -518,82 +471,62 @@ private static InternalMethod lookupMethodUncached(RubyModule module, RubyModule
return null;
}

@TruffleBoundary
public static MethodLookupResult lookupSuperMethod(InternalMethod currentMethod, RubyModule objectMetaClass) {
final String name = currentMethod.getSharedMethodInfo().getMethodNameForNotBlock(); // use the original name
var name = currentMethod.getSharedMethodInfo().getMethodNameForNotBlock(); // use the original name

Memo<Boolean> foundDeclaringModule = new Memo<>(false);
return lookupSuperMethod(
currentMethod.getDeclaringModule(),
null,
name,
objectMetaClass,
foundDeclaringModule,
currentMethod.getDeclarationContext(),
currentMethod.getActiveRefinements());
}


@TruffleBoundary
private static MethodLookupResult lookupSuperMethod(RubyModule declaringModule, RubyModule lookupTo,
String name, RubyModule objectMetaClass, Memo<Boolean> foundDeclaringModule,
DeclarationContext declarationContext, DeclarationContext callerDeclaringContext) {
final ArrayList<Assumption> assumptions = new ArrayList<>();
final boolean isRefinedMethod = declaringModule.fields.isRefinement();
var foundDeclaringModule = new Memo<>(false);
var declaringModule = currentMethod.getDeclaringModule();
var declarationContext = currentMethod.getDeclarationContext();
var assumptions = new ArrayList<Assumption>();

// First we need to skip all ancestors until we find declaringModule,
// and then we return the first ancestor after declaringModule which has the method defined.
for (RubyModule ancestor : objectMetaClass.fields.ancestors()) {
if (ancestor == lookupTo) {
return new MethodLookupResult(null, toArray(assumptions));
}

final RubyModule[] refinements = getRefinementsFor(declarationContext, callerDeclaringContext, ancestor);

var refinements = getRefinementsFor(declarationContext, currentMethod.getActiveRefinements(), ancestor);
if (refinements != null) {
for (RubyModule refinement : refinements) {
final MethodLookupResult superMethodInRefinement = lookupSuperMethod(
declaringModule,
ancestor,
name,
refinement,
foundDeclaringModule,
null,
null);
for (Assumption assumption : superMethodInRefinement.getAssumptions()) {
assumptions.add(assumption);
}
if (superMethodInRefinement.isDefined()) {
InternalMethod method = superMethodInRefinement.getMethod();
var refinedMethod = lookupSuperMethodInModule(declaringModule, name, foundDeclaringModule,
refinement, assumptions);
if (refinedMethod != null) {
return new MethodLookupResult(
rememberUsedRefinements(method, declarationContext, refinements, ancestor),
rememberUsedRefinements(refinedMethod, declarationContext, refinements, ancestor),
toArray(assumptions));
}
if (foundDeclaringModule.get() && isRefinedMethod) {
if (foundDeclaringModule.get() && declaringModule.fields.isRefinement()) {
// if method is defined in refinement module (R)
// we should lookup only in this active refinement and skip other
// we should lookup only in this active refinement and skip others
break;
}
}
}

if (!foundDeclaringModule.get()) {
if (ancestor == declaringModule) {
// The declaring module's assumption needs to appended for cases where a newly included module
// should invalidate previous super lookups.
ancestor.fields.getMethodAndAssumption(name, assumptions);
foundDeclaringModule.set(true);
}
} else {
final ModuleFields fields = ancestor.fields;
final InternalMethod method = fields.getMethodAndAssumption(name, assumptions);
if (method != null) {
return new MethodLookupResult(method, toArray(assumptions));
}
var method = lookupSuperMethodInModule(declaringModule, name, foundDeclaringModule, ancestor, assumptions);
if (method != null) {
return new MethodLookupResult(method, toArray(assumptions));
}
}

// Nothing found
return new MethodLookupResult(null, toArray(assumptions));
}


private static InternalMethod lookupSuperMethodInModule(RubyModule declaringModule, String name,
Memo<Boolean> foundDeclaringModule, RubyModule module, ArrayList<Assumption> assumptions) {
if (!foundDeclaringModule.get()) {
if (module == declaringModule) {
// The declaring module's assumption needs to appended for cases where a newly included module
// should invalidate previous super lookups.
module.fields.getMethodAndAssumption(name, assumptions);
foundDeclaringModule.set(true);
}
return null;
} else {
return module.fields.getMethodAndAssumption(name, assumptions);
}
}

private static InternalMethod rememberUsedRefinements(InternalMethod method,
DeclarationContext declarationContext) {
return method.withActiveRefinements(declarationContext);
Expand All @@ -603,8 +536,7 @@ private static InternalMethod rememberUsedRefinements(InternalMethod method,
DeclarationContext declarationContext, RubyModule[] refinements, RubyModule ancestor) {
assert refinements != null;

final Map<RubyModule, RubyModule[]> currentRefinements = new HashMap<>(
declarationContext.getRefinements());
final Map<RubyModule, RubyModule[]> currentRefinements = new HashMap<>(declarationContext.getRefinements());
currentRefinements.put(ancestor, refinements);

return method.withActiveRefinements(declarationContext.withRefinements(currentRefinements));
Expand Down
25 changes: 10 additions & 15 deletions test/mri/excludes/TestRefinement.rb
Original file line number Diff line number Diff line change
@@ -1,17 +1,12 @@
exclude :test_eval_with_binding_scoping, "needs investigation"
exclude :test_include_refinement, "needs investigation"
exclude :test_prepend_after_refine, "needs investigation"
exclude :test_refine_prepended_class, "needs investigation"
exclude :test_refine_with_proc, "needs investigation"
exclude :test_undef_original_method, "needs investigation"
exclude :test_warn_setconst_in_refinmenet, "needs investigation"
exclude :test_refine_in_using, "needs investigation"
exclude :test_used_modules, "needs investigation"
exclude :test_unbound_refine_method, "needs investigation"
exclude :test_ancestors, "[ruby-core:86949] [Bug #14744]."
exclude :test_import_methods, "NoMethodError: undefined method `bar' for #<TestRefinement::TestImport::A:0x17bee78>"
exclude :test_eval_with_binding_scoping, "pid 123017 exit 0."
exclude :test_import_methods, "ArgumentError expected but nothing was raised."
exclude :test_prepend_after_refine, "<\"refined\"> expected but was"
exclude :test_refine_prepended_class, "<[:c, :m1, :m2]> expected but was"
exclude :test_refine_with_proc, "ArgumentError expected but nothing was raised."
exclude :test_unbound_refine_method, "TypeError expected but nothing was raised."
exclude :test_used_modules, "<[TestRefinement::VisibleRefinements::RefB,"
exclude :test_refinements, "TruffleRuby does not guarantee refinement list ordering"
exclude :test_refined_class, "TruffleRuby does not guarantee refinement list ordering"
exclude :test_prepend_into_refinement, "TypeError expected but nothing was raised."
exclude :test_include_into_refinement, "TypeError expected but nothing was raised."
exclude :test_refined_protected_methods, "assert_separately failed with error message"
exclude :test_warn_setconst_in_refinmenet, "[ruby-core:64143] [Bug #10103]"
exclude :test_refine_in_using, "NoMethodError: undefined method `foo' for #<TestRefinement::RefineInUsing:0xd88>"
exclude :test_refined_protected_methods, "NoMethodError: protected method `foo' called for #<C:0x2c8>"
Loading