Use local type inference to devirtualize calls after inlining

This commit is contained in:
Alexey Andreev 2017-01-14 00:50:22 +03:00
parent 645b2b7cd5
commit 98b6fff2f0
5 changed files with 117 additions and 25 deletions

View File

@ -73,6 +73,7 @@ public class Linker {
if (!fieldRef.getClassName().equals(method.getOwnerName())) { if (!fieldRef.getClassName().equals(method.getOwnerName())) {
InitClassInstruction initInsn = new InitClassInstruction(); InitClassInstruction initInsn = new InitClassInstruction();
initInsn.setClassName(fieldRef.getClassName()); initInsn.setClassName(fieldRef.getClassName());
initInsn.setLocation(insn.getLocation());
insn.insertPrevious(initInsn); insn.insertPrevious(initInsn);
} }

View File

@ -231,15 +231,17 @@ public class ClassInference {
MethodReference resolvedMethodRef = resolvedMethod.getReference(); MethodReference resolvedMethodRef = resolvedMethod.getReference();
if (callSite.resolvedMethods.add(resolvedMethodRef)) { if (callSite.resolvedMethods.add(resolvedMethodRef)) {
MethodDependencyInfo methodDep = dependencyInfo.getMethod(resolvedMethodRef); MethodDependencyInfo methodDep = dependencyInfo.getMethod(resolvedMethodRef);
if (callSite.receiver >= 0) { if (methodDep != null) {
readValue(methodDep.getResult(), program.variableAt(callSite.receiver), queue); if (callSite.receiver >= 0) {
} readValue(methodDep.getResult(), program.variableAt(callSite.receiver), queue);
for (int i = 0; i < callSite.arguments.length; ++i) { }
writeValue(methodDep.getVariable(i + 1), program.variableAt(callSite.arguments[i]), for (int i = 0; i < callSite.arguments.length; ++i) {
queue); writeValue(methodDep.getVariable(i + 1), program.variableAt(callSite.arguments[i]),
} queue);
for (String type : methodDep.getThrown().getTypes()) { }
queue.add(new Task(callSite.block, -1, type)); for (String type : methodDep.getThrown().getTypes()) {
queue.add(new Task(callSite.block, -1, type));
}
} }
} }
} }
@ -252,9 +254,15 @@ public class ClassInference {
IntObjectMap<ValueType> variableCasts = casts.get(task.variable); IntObjectMap<ValueType> variableCasts = casts.get(task.variable);
if (variableCasts != null) { if (variableCasts != null) {
ValueType type = task.className.startsWith("[") ValueType type;
? ValueType.parse(task.className) if (task.className.startsWith("[")) {
: ValueType.object(task.className); type = ValueType.parseIfPossible(task.className);
if (type == null) {
type = ValueType.arrayOf(ValueType.object("java.lang.Object"));
}
} else {
type = ValueType.object(task.className);
}
for (int target : variableCasts.keys().toArray()) { for (int target : variableCasts.keys().toArray()) {
ValueType targetType = variableCasts.get(target); ValueType targetType = variableCasts.get(target);
if (classSource.isSuperType(targetType, type).orElse(false)) { if (classSource.isSuperType(targetType, type).orElse(false)) {

View File

@ -15,11 +15,15 @@
*/ */
package org.teavm.model.optimization; package org.teavm.model.optimization;
import com.carrotsearch.hppc.IntArrayList;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.Collections; import java.util.Collections;
import java.util.HashSet;
import java.util.List; import java.util.List;
import java.util.Set;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import org.teavm.dependency.DependencyInfo;
import org.teavm.model.BasicBlock; import org.teavm.model.BasicBlock;
import org.teavm.model.ClassReader; import org.teavm.model.ClassReader;
import org.teavm.model.ClassReaderSource; import org.teavm.model.ClassReaderSource;
@ -30,6 +34,7 @@ import org.teavm.model.MethodReference;
import org.teavm.model.Phi; import org.teavm.model.Phi;
import org.teavm.model.Program; import org.teavm.model.Program;
import org.teavm.model.TryCatchBlock; import org.teavm.model.TryCatchBlock;
import org.teavm.model.analysis.ClassInference;
import org.teavm.model.instructions.AssignInstruction; import org.teavm.model.instructions.AssignInstruction;
import org.teavm.model.instructions.BinaryBranchingInstruction; import org.teavm.model.instructions.BinaryBranchingInstruction;
import org.teavm.model.instructions.BranchingInstruction; import org.teavm.model.instructions.BranchingInstruction;
@ -48,13 +53,35 @@ import org.teavm.model.util.ProgramUtils;
public class Inlining { public class Inlining {
private static final int DEFAULT_THRESHOLD = 15; private static final int DEFAULT_THRESHOLD = 15;
private static final int MAX_DEPTH = 5; private static final int MAX_DEPTH = 5;
private IntArrayList depthsByBlock;
private Set<Instruction> instructionsToSkip;
public void apply(Program program, MethodReference method, ClassReaderSource classes,
DependencyInfo dependencyInfo) {
depthsByBlock = new IntArrayList(program.basicBlockCount());
for (int i = 0; i < program.basicBlockCount(); ++i) {
depthsByBlock.add(0);
}
instructionsToSkip = new HashSet<>();
while (applyOnce(program, classes)) {
devirtualize(program, method, dependencyInfo);
}
depthsByBlock = null;
instructionsToSkip = null;
public void apply(Program program, ClassReaderSource classSource) {
List<PlanEntry> plan = buildPlan(program, classSource, 0);
execPlan(program, plan, 0);
new UnreachableBasicBlockEliminator().optimize(program); new UnreachableBasicBlockEliminator().optimize(program);
} }
private boolean applyOnce(Program program, ClassReaderSource classSource) {
List<PlanEntry> plan = buildPlan(program, classSource, 0);
if (plan.isEmpty()) {
return false;
}
execPlan(program, plan, 0);
return true;
}
private void execPlan(Program program, List<PlanEntry> plan, int offset) { private void execPlan(Program program, List<PlanEntry> plan, int offset) {
for (PlanEntry entry : plan) { for (PlanEntry entry : plan) {
execPlanEntry(program, entry, offset); execPlanEntry(program, entry, offset);
@ -70,6 +97,9 @@ public class Inlining {
for (int i = 1; i < inlineProgram.basicBlockCount(); ++i) { for (int i = 1; i < inlineProgram.basicBlockCount(); ++i) {
program.createBasicBlock(); program.createBasicBlock();
} }
while (depthsByBlock.size() < program.basicBlockCount()) {
depthsByBlock.add(planEntry.depth);
}
int variableOffset = program.variableCount(); int variableOffset = program.variableCount();
for (int i = 0; i < inlineProgram.variableCount(); ++i) { for (int i = 0; i < inlineProgram.variableCount(); ++i) {
@ -186,12 +216,22 @@ public class Inlining {
} }
List<PlanEntry> plan = new ArrayList<>(); List<PlanEntry> plan = new ArrayList<>();
int ownComplexity = getComplexity(program); int ownComplexity = getComplexity(program);
int originalDepth = depth;
for (BasicBlock block : program.getBasicBlocks()) { for (BasicBlock block : program.getBasicBlocks()) {
if (!block.getTryCatchBlocks().isEmpty()) { if (!block.getTryCatchBlocks().isEmpty()) {
continue; continue;
} }
if (originalDepth == 0) {
depth = depthsByBlock.get(block.getIndex());
}
for (Instruction insn : block) { for (Instruction insn : block) {
if (instructionsToSkip.contains(insn)) {
continue;
}
if (!(insn instanceof InvokeInstruction)) { if (!(insn instanceof InvokeInstruction)) {
continue; continue;
} }
@ -203,15 +243,17 @@ public class Inlining {
MethodReader invokedMethod = getMethod(classSource, invoke.getMethod()); MethodReader invokedMethod = getMethod(classSource, invoke.getMethod());
if (invokedMethod == null || invokedMethod.getProgram() == null if (invokedMethod == null || invokedMethod.getProgram() == null
|| invokedMethod.getProgram().basicBlockCount() == 0) { || invokedMethod.getProgram().basicBlockCount() == 0) {
instructionsToSkip.add(insn);
continue; continue;
} }
Program invokedProgram = ProgramUtils.copy(invokedMethod.getProgram()); Program invokedProgram = ProgramUtils.copy(invokedMethod.getProgram());
int complexityThreshold = DEFAULT_THRESHOLD - depth * 2; int complexityThreshold = DEFAULT_THRESHOLD;
if (ownComplexity < DEFAULT_THRESHOLD) { if (ownComplexity < DEFAULT_THRESHOLD) {
complexityThreshold += DEFAULT_THRESHOLD; complexityThreshold += DEFAULT_THRESHOLD;
} }
if (getComplexity(invokedProgram) > complexityThreshold) { if (getComplexity(invokedProgram) > complexityThreshold) {
instructionsToSkip.add(insn);
continue; continue;
} }
@ -220,6 +262,7 @@ public class Inlining {
entry.targetInstruction = insn; entry.targetInstruction = insn;
entry.program = invokedProgram; entry.program = invokedProgram;
entry.innerPlan.addAll(buildPlan(invokedProgram, classSource, depth + 1)); entry.innerPlan.addAll(buildPlan(invokedProgram, classSource, depth + 1));
entry.depth = depth;
plan.add(entry); plan.add(entry);
} }
} }
@ -261,10 +304,42 @@ public class Inlining {
return complexity; return complexity;
} }
private void devirtualize(Program program, MethodReference method, DependencyInfo dependencyInfo) {
ClassInference inference = new ClassInference(dependencyInfo);
inference.infer(program, method);
for (BasicBlock block : program.getBasicBlocks()) {
for (Instruction instruction : block) {
if (!(instruction instanceof InvokeInstruction)) {
continue;
}
InvokeInstruction invoke = (InvokeInstruction) instruction;
if (invoke.getType() != InvocationType.VIRTUAL) {
continue;
}
Set<MethodReference> implementations = new HashSet<>();
for (String className : inference.classesOf(invoke.getInstance().getIndex())) {
MethodReference rawMethod = new MethodReference(className, invoke.getMethod().getDescriptor());
MethodReader resolvedMethod = dependencyInfo.getClassSource().resolve(rawMethod);
if (resolvedMethod != null) {
implementations.add(resolvedMethod.getReference());
}
}
if (implementations.size() == 1) {
invoke.setType(InvocationType.SPECIAL);
invoke.setMethod(implementations.iterator().next());
}
}
}
}
private class PlanEntry { private class PlanEntry {
int targetBlock; int targetBlock;
Instruction targetInstruction; Instruction targetInstruction;
Program program; Program program;
int depth;
final List<PlanEntry> innerPlan = new ArrayList<>(); final List<PlanEntry> innerPlan = new ArrayList<>();
} }
} }

View File

@ -1695,11 +1695,6 @@ public class ProgramParser {
break; break;
} }
case Opcodes.PUTSTATIC: { case Opcodes.PUTSTATIC: {
if (!owner.equals(currentClassName)) {
InitClassInstruction initInsn = new InitClassInstruction();
initInsn.setClassName(ownerCls);
addInstruction(initInsn);
}
int value = desc.equals("D") || desc.equals("J") ? popDouble() : popSingle(); int value = desc.equals("D") || desc.equals("J") ? popDouble() : popSingle();
PutFieldInstruction insn = new PutFieldInstruction(); PutFieldInstruction insn = new PutFieldInstruction();
insn.setField(referenceCache.getCached(new FieldReference(ownerCls, name))); insn.setField(referenceCache.getCached(new FieldReference(ownerCls, name)));

View File

@ -368,7 +368,7 @@ public class TeaVM implements TeaVMHost, ServiceRepository {
return; return;
} }
inline(classSet); inline(classSet, dependencyChecker);
if (wasCancelled()) { if (wasCancelled()) {
return; return;
} }
@ -435,24 +435,37 @@ public class TeaVM implements TeaVMHost, ServiceRepository {
} }
} }
private void inline(ListableClassHolderSource classes) { private void inline(ListableClassHolderSource classes, DependencyInfo dependencyInfo) {
if (optimizationLevel != TeaVMOptimizationLevel.FULL) { if (optimizationLevel != TeaVMOptimizationLevel.FULL) {
return; return;
} }
Map<MethodReference, Program> inlinedPrograms = new HashMap<>();
Inlining inlining = new Inlining(); Inlining inlining = new Inlining();
for (String className : classes.getClassNames()) { for (String className : classes.getClassNames()) {
ClassHolder cls = classes.get(className); ClassHolder cls = classes.get(className);
for (MethodHolder method : cls.getMethods()) { for (MethodHolder method : cls.getMethods()) {
if (method.getProgram() != null) { if (method.getProgram() != null) {
Program program = ProgramUtils.copy(method.getProgram());
MethodOptimizationContextImpl context = new MethodOptimizationContextImpl(method, classes); MethodOptimizationContextImpl context = new MethodOptimizationContextImpl(method, classes);
inlining.apply(method.getProgram(), classes); inlining.apply(program, method.getReference(), classes, dependencyInfo);
new UnusedVariableElimination().optimize(context, method.getProgram()); new UnusedVariableElimination().optimize(context, program);
inlinedPrograms.put(method.getReference(), program);
} }
} }
if (wasCancelled()) { if (wasCancelled()) {
return; return;
} }
} }
for (String className : classes.getClassNames()) {
ClassHolder cls = classes.get(className);
for (MethodHolder method : cls.getMethods()) {
if (method.getProgram() != null) {
method.setProgram(inlinedPrograms.get(method.getReference()));
}
}
}
} }
private void optimize(ListableClassHolderSource classSource) { private void optimize(ListableClassHolderSource classSource) {