Improve performance of inliner in FULL optimization level

The former implementation suffered from class inference.
The reason was in many nodes having too many possible classes in them.
The new implementation does not handle full set of classes in each node.
Instead, it introduces concept of 'overflow', i.e. node having
more types than the given upper limit.
These nodes behave as if there were all possible classes in them,
which allows to apply certain optimization for these nodes and
omit heavy computations of large type sets.
This commit is contained in:
Alexey Andreev 2019-09-26 19:40:23 +03:00
parent 9314461fcf
commit a2a9dbcbe3
7 changed files with 542 additions and 94 deletions

View File

@ -434,6 +434,14 @@ public class DependencyNode implements ValueDependencyInfo {
return i == result.length ? result : Arrays.copyOf(result, i);
}
@Override
public boolean hasMoreTypesThan(int limit) {
if (typeSet == null) {
return false;
}
return typeSet.hasMoreTypesThan(limit, typeFilter != null ? getFilter()::match : null);
}
DependencyType[] getTypesInternal() {
if (typeSet == null) {
return new DependencyType[0];

View File

@ -23,6 +23,7 @@ import java.util.BitSet;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Set;
import java.util.function.Predicate;
class TypeSet {
private static final int SMALL_TYPES_THRESHOLD = 3;
@ -89,6 +90,41 @@ class TypeSet {
}
}
boolean hasMoreTypesThan(int limit, Predicate<DependencyType> filter) {
if (this.types != null) {
if (filter == null) {
return this.types.cardinality() > limit;
}
for (int index = this.types.nextSetBit(0); index >= 0; index = this.types.nextSetBit(index + 1)) {
DependencyType type = dependencyAnalyzer.types.get(index);
if (filter.test(type)) {
if (--limit < 0) {
return true;
}
}
}
return false;
} else if (this.smallTypes != null) {
if (this.smallTypes.length <= limit) {
return false;
}
if (filter == null) {
return true;
}
for (int i = 0; i < smallTypes.length; ++i) {
DependencyType type = dependencyAnalyzer.types.get(smallTypes[i]);
if (filter.test(type)) {
if (--limit < 0) {
return true;
}
}
}
return false;
} else {
return false;
}
}
DependencyType[] getTypesForNode(DependencyNode sourceNode, DependencyNode targetNode,
DependencyTypeFilter filter) {
int j = 0;

View File

@ -20,6 +20,8 @@ public interface ValueDependencyInfo {
boolean hasType(String type);
boolean hasMoreTypesThan(int limit);
boolean hasArrayType();
ValueDependencyInfo getArrayItem();

View File

@ -39,6 +39,7 @@ import org.teavm.model.ClassHierarchy;
import org.teavm.model.ClassReaderSource;
import org.teavm.model.Incoming;
import org.teavm.model.Instruction;
import org.teavm.model.MethodDescriptor;
import org.teavm.model.MethodReader;
import org.teavm.model.MethodReference;
import org.teavm.model.Phi;
@ -68,6 +69,7 @@ import org.teavm.model.instructions.UnwrapArrayInstruction;
public class ClassInference {
private DependencyInfo dependencyInfo;
private ClassHierarchy hierarchy;
private SubclassListProvider subclassListProvider;
private Graph assignmentGraph;
private Graph cloneGraph;
private Graph arrayGraph;
@ -76,11 +78,13 @@ public class ClassInference {
private ValueCast[] casts;
private int[] exceptions;
private VirtualCallSite[] virtualCallSites;
private int overflowLimit;
private int[] propagationPath;
private int[] nodeMapping;
private IntHashSet[] types;
private boolean[] overflowTypes;
private ObjectIntMap<String> typeMap = new ObjectIntHashMap<>();
private List<String> typeList = new ArrayList<>();
@ -90,9 +94,12 @@ public class ClassInference {
private static final int MAX_DEGREE = 3;
public ClassInference(DependencyInfo dependencyInfo, ClassHierarchy hierarchy) {
public ClassInference(DependencyInfo dependencyInfo, ClassHierarchy hierarchy,
Iterable<? extends String> classNames, int overflowLimit) {
this.dependencyInfo = dependencyInfo;
this.hierarchy = hierarchy;
this.overflowLimit = overflowLimit;
subclassListProvider = new SubclassListProvider(dependencyInfo.getClassSource(), classNames, overflowLimit);
}
public void infer(Program program, MethodReference methodReference) {
@ -112,6 +119,7 @@ public class ClassInference {
*/
types = new IntHashSet[program.variableCount() << 3];
overflowTypes = new boolean[program.variableCount() << 3];
nodeChanged = new boolean[types.length];
formerNodeChanged = new boolean[nodeChanged.length];
nodeMapping = new int[types.length];
@ -129,9 +137,7 @@ public class ClassInference {
if (paramDep != null) {
int degree = 0;
while (degree <= MAX_DEGREE) {
for (String paramType : paramDep.getTypes()) {
addType(i, degree, paramType);
}
addTypesFrom(i, degree, paramDep);
if (!paramDep.hasArrayType()) {
break;
}
@ -166,6 +172,14 @@ public class ClassInference {
nodeChanged = null;
}
public boolean isOverflow(int variableIndex) {
return overflowTypes[nodeMapping[packNodeAndDegree(variableIndex, 0)]];
}
public List<? extends MethodReference> getMethodImplementations(MethodDescriptor descriptor) {
return subclassListProvider.getMethods(descriptor);
}
public String[] classesOf(int variableIndex) {
IntHashSet typeSet = types[nodeMapping[packNodeAndDegree(variableIndex, 0)]];
if (typeSet == null) {
@ -210,7 +224,7 @@ public class ClassInference {
IntStack stack = new IntStack();
for (int i = 0; i < types.length; ++i) {
if (types[i] != null) {
if (types[i] != null || overflowTypes[i]) {
stack.push(i);
}
}
@ -286,8 +300,10 @@ public class ClassInference {
boolean[] nodeChangedBackup = nodeChanged.clone();
IntHashSet[] typesBackup = types.clone();
boolean[] overflowTypesBackup = overflowTypes.clone();
Arrays.fill(nodeChanged, false);
Arrays.fill(types, null);
Arrays.fill(overflowTypes, false);
GraphBuilder graphBuilder = new GraphBuilder(graph.size());
for (int i = 0; i < graph.size(); ++i) {
@ -300,8 +316,17 @@ public class ClassInference {
}
int node = nodeMapping[i];
if (typesBackup[i] != null) {
getNodeTypes(node).addAll(typesBackup[i]);
if (overflowTypesBackup[i]) {
overflowTypes[i] = true;
types[node] = null;
} else if (typesBackup[i] != null && !overflowTypes[i]) {
IntHashSet nodeTypes = getNodeTypes(node);
if (nodeTypes.addAll(typesBackup[i]) > 0) {
if (nodeTypes.size() > overflowLimit) {
types[node] = null;
overflowTypes[node] = true;
}
}
}
if (nodeChangedBackup[i]) {
@ -323,7 +348,7 @@ public class ClassInference {
IntStack stack = new IntStack();
for (int i = 0; i < graph.size(); ++i) {
if (graph.incomingEdgesCount(i) == 0 && types[i] != null) {
if (graph.incomingEdgesCount(i) == 0 && (overflowTypes[i] || types[i] != null)) {
stack.push(i);
}
}
@ -381,6 +406,10 @@ public class ClassInference {
private void propagateAlongDAG() {
for (int i = propagationPath.length - 1; i >= 0; --i) {
int node = propagationPath[i];
if (overflowTypes[node]) {
continue;
}
boolean predecessorsChanged = false;
for (int predecessor : graph.incomingEdges(node)) {
if (formerNodeChanged[predecessor] || nodeChanged[predecessor]) {
@ -395,9 +424,25 @@ public class ClassInference {
IntHashSet nodeTypes = getNodeTypes(node);
for (int predecessor : graph.incomingEdges(node)) {
if (formerNodeChanged[predecessor] || nodeChanged[predecessor]) {
if (nodeTypes.addAll(types[predecessor]) > 0) {
nodeChanged[node] = true;
changed = true;
if (overflowTypes[predecessor]) {
if (!overflowTypes[node]) {
overflowTypes[node] = true;
types[node] = null;
changed = true;
nodeChanged[node] = true;
break;
}
} else if (types[predecessor] != null && nodeTypes.addAll(types[predecessor]) > 0) {
if (nodeTypes.size() > overflowLimit) {
types[node] = null;
overflowTypes[node] = true;
nodeChanged[node] = true;
changed = true;
break;
} else {
nodeChanged[node] = true;
changed = true;
}
}
}
}
@ -406,34 +451,92 @@ public class ClassInference {
private void propagateAlongCasts() {
for (ValueCast cast : casts) {
int toNode = nodeMapping[packNodeAndDegree(cast.toVariable, 0)];
if (overflowTypes[toNode]) {
continue;
}
int fromNode = nodeMapping[packNodeAndDegree(cast.fromVariable, 0)];
if (!formerNodeChanged[fromNode] && !nodeChanged[fromNode]) {
continue;
}
int toNode = nodeMapping[packNodeAndDegree(cast.toVariable, 0)];
IntHashSet targetTypes = getNodeTypes(toNode);
for (IntCursor cursor : types[fromNode]) {
if (targetTypes.contains(cursor.value)) {
continue;
if (overflowTypes[fromNode]) {
int degree = 0;
ValueType targetType = cast.targetType;
while (targetType instanceof ValueType.Array) {
targetType = ((ValueType.Array) targetType).getItemType();
++degree;
}
String className = typeList.get(cursor.value);
ValueType type;
if (className.startsWith("[")) {
type = ValueType.parseIfPossible(className);
if (type == null) {
type = ValueType.arrayOf(ValueType.object("java.lang.Object"));
if (targetType instanceof ValueType.Object) {
String targetClassName = ((ValueType.Object) targetType).getClassName();
List<? extends String> subclasses = subclassListProvider.getSubclasses(
targetClassName, degree > 0);
if (subclasses == null) {
types[toNode] = null;
overflowTypes[toNode] = true;
nodeChanged[toNode] = true;
changed = true;
} else {
for (String subclass : subclasses) {
if (degree > 0) {
StringBuilder sb = new StringBuilder();
for (int i = 0; i < degree; ++i) {
sb.append('[');
}
subclass = sb.append('L').append(subclass.replace('.', '/')).append(';').toString();
}
int typeId = getTypeByName(subclass);
if (targetTypes.add(typeId)) {
changed = true;
nodeChanged[toNode] = true;
if (targetTypes.size() > overflowLimit) {
overflowTypes[toNode] = true;
types[toNode] = null;
break;
}
}
}
}
} else {
type = ValueType.object(className);
int typeId = getTypeByName(targetType.toString());
if (targetTypes.add(typeId)) {
changed = true;
nodeChanged[toNode] = true;
if (targetTypes.size() > overflowLimit) {
overflowTypes[toNode] = true;
types[toNode] = null;
}
}
}
} else {
for (IntCursor cursor : types[fromNode]) {
if (targetTypes.contains(cursor.value)) {
continue;
}
String className = typeList.get(cursor.value);
if (hierarchy.isSuperType(cast.targetType, type, false)) {
changed = true;
nodeChanged[toNode] = true;
targetTypes.add(cursor.value);
ValueType type;
if (className.startsWith("[")) {
type = ValueType.parseIfPossible(className);
if (type == null) {
type = ValueType.arrayOf(ValueType.object("java.lang.Object"));
}
} else {
type = ValueType.object(className);
}
if (hierarchy.isSuperType(cast.targetType, type, false) && targetTypes.add(cursor.value)) {
changed = true;
nodeChanged[toNode] = true;
if (targetTypes.size() > overflowLimit) {
overflowTypes[toNode] = true;
types[toNode] = null;
break;
}
}
}
}
}
@ -443,17 +546,49 @@ public class ClassInference {
ClassReaderSource classSource = dependencyInfo.getClassSource();
for (VirtualCallSite callSite : virtualCallSites) {
if (callSite.receiverOverflow) {
continue;
}
int instanceNode = nodeMapping[packNodeAndDegree(callSite.instance, 0)];
if (!formerNodeChanged[instanceNode] && !nodeChanged[instanceNode]) {
continue;
}
for (IntCursor type : types[instanceNode]) {
if (!callSite.knownClasses.add(type.value)) {
List<? extends String> receiverTypes;
if (overflowTypes[instanceNode]) {
callSite.receiverOverflow = true;
List<? extends String> subclasses = subclassListProvider.getSubclasses(
callSite.method.getClassName(), true);
if (subclasses != null) {
receiverTypes = subclasses;
} else {
List<? extends MethodReference> implementations = subclassListProvider.getMethods(
callSite.method.getDescriptor());
if (implementations != null) {
for (MethodReference methodReference : implementations) {
mountVirtualMethod(program, callSite, methodReference);
}
}
continue;
}
} else {
List<String> instanceNodeTypes = new ArrayList<>();
for (IntCursor type : types[instanceNode]) {
if (callSite.knownClasses.contains(type.value)) {
continue;
}
instanceNodeTypes.add(typeList.get(type.value));
}
receiverTypes = instanceNodeTypes;
}
for (String className : receiverTypes) {
int typeId = getTypeByName(className);
if (!callSite.knownClasses.add(typeId)) {
continue;
}
String className = typeList.get(type.value);
MethodReference rawMethod = new MethodReference(className, callSite.method.getDescriptor());
MethodReader resolvedMethod = classSource.resolveImplementation(rawMethod);
@ -462,28 +597,32 @@ public class ClassInference {
}
MethodReference resolvedMethodRef = resolvedMethod.getReference();
if (!callSite.resolvedMethods.add(resolvedMethodRef)) {
continue;
}
MethodDependencyInfo methodDep = dependencyInfo.getMethod(resolvedMethodRef);
if (methodDep == null) {
continue;
}
if (callSite.receiver >= 0) {
readValue(methodDep.getResult(), program.variableAt(callSite.receiver));
}
for (int i = 0; i < callSite.arguments.length; ++i) {
writeValue(methodDep.getVariable(i + 1), program.variableAt(callSite.arguments[i]));
}
for (String thrownTypeName : methodDep.getThrown().getTypes()) {
propagateException(thrownTypeName, program.basicBlockAt(callSite.block));
}
mountVirtualMethod(program, callSite, resolvedMethodRef);
}
}
}
private void mountVirtualMethod(Program program, VirtualCallSite callSite, MethodReference methodReference) {
if (!callSite.resolvedMethods.add(methodReference)) {
return;
}
MethodDependencyInfo methodDep = dependencyInfo.getMethod(methodReference);
if (methodDep == null) {
return;
}
if (callSite.receiver >= 0) {
readValue(methodDep.getResult(), program.variableAt(callSite.receiver));
}
for (int i = 0; i < callSite.arguments.length; ++i) {
writeValue(methodDep.getVariable(i + 1), program.variableAt(callSite.arguments[i]));
}
for (String thrownTypeName : methodDep.getThrown().getTypes()) {
propagateException(thrownTypeName, program.basicBlockAt(callSite.block));
}
}
private void propagateAlongExceptions(Program program) {
for (int i = 0; i < exceptions.length; i += 2) {
int variable = nodeMapping[packNodeAndDegree(exceptions[i], 0)];
@ -492,9 +631,13 @@ public class ClassInference {
}
BasicBlock block = program.basicBlockAt(exceptions[i + 1]);
for (IntCursor type : types[variable]) {
String typeName = typeList.get(type.value);
propagateException(typeName, block);
if (overflowTypes[variable]) {
propagateOverflowException(block);
} else {
for (IntCursor type : types[variable]) {
String typeName = typeList.get(type.value);
propagateException(typeName, block);
}
}
}
}
@ -508,10 +651,17 @@ public class ClassInference {
}
int exceptionNode = packNodeAndDegree(tryCatch.getHandler().getExceptionVariable().getIndex(), 0);
exceptionNode = nodeMapping[exceptionNode];
int thrownType = getTypeByName(thrownTypeName);
if (getNodeTypes(exceptionNode).add(thrownType)) {
nodeChanged[exceptionNode] = true;
changed = true;
if (!overflowTypes[exceptionNode]) {
int thrownType = getTypeByName(thrownTypeName);
IntHashSet nodeTypes = getNodeTypes(exceptionNode);
if (nodeTypes.add(thrownType)) {
nodeChanged[exceptionNode] = true;
changed = true;
if (nodeTypes.size() > overflowLimit) {
types[exceptionNode] = null;
overflowTypes[exceptionNode] = true;
}
}
}
break;
@ -519,6 +669,45 @@ public class ClassInference {
}
}
private void propagateOverflowException(BasicBlock block) {
for (TryCatchBlock tryCatch : block.getTryCatchBlocks()) {
Variable exceptionVar = tryCatch.getHandler().getExceptionVariable();
if (exceptionVar == null) {
continue;
}
int exceptionNode = packNodeAndDegree(exceptionVar.getIndex(), 0);
exceptionNode = nodeMapping[exceptionNode];
if (overflowTypes[exceptionNode]) {
continue;
}
String expectedType = tryCatch.getExceptionType();
List<? extends String> thrownTypes = subclassListProvider.getSubclasses(expectedType, false);
if (thrownTypes == null) {
if (!overflowTypes[exceptionNode]) {
overflowTypes[exceptionNode] = true;
changed = true;
nodeChanged[exceptionNode] = true;
}
} else {
IntHashSet nodeTypes = getNodeTypes(exceptionNode);
for (String thrownTypeName : thrownTypes) {
int thrownType = getTypeByName(thrownTypeName);
if (nodeTypes.add(thrownType)) {
nodeChanged[exceptionNode] = true;
changed = true;
if (nodeTypes.size() > overflowLimit) {
types[exceptionNode] = null;
overflowTypes[exceptionNode] = true;
break;
}
}
}
}
}
}
IntHashSet getNodeTypes(int node) {
IntHashSet result = types[node];
if (result == null) {
@ -651,9 +840,7 @@ public class ClassInference {
int resultDegree = 0;
while (resultDependency.hasArrayType() && resultDegree <= MAX_DEGREE) {
resultDependency = resultDependency.getArrayItem();
for (String paramType : resultDependency.getTypes()) {
addType(insn.getValueToReturn().getIndex(), resultDegree, paramType);
}
addTypesFrom(insn.getValueToReturn().getIndex(), resultDegree, resultDependency);
++resultDegree;
}
}
@ -706,6 +893,7 @@ public class ClassInference {
static class VirtualCallSite {
int instance;
boolean receiverOverflow;
IntSet knownClasses = new IntHashSet();
Set<MethodReference> resolvedMethods = new HashSet<>();
MethodReference method;
@ -718,9 +906,7 @@ public class ClassInference {
int depth = 0;
boolean hasArrayType;
do {
for (String type : valueDep.getTypes()) {
addType(receiver.getIndex(), depth, type);
}
addTypesFrom(receiver.getIndex(), depth, valueDep);
depth++;
hasArrayType = valueDep.hasArrayType();
valueDep = valueDep.getArrayItem();
@ -732,18 +918,58 @@ public class ClassInference {
while (valueDep.hasArrayType() && depth < MAX_DEGREE) {
depth++;
valueDep = valueDep.getArrayItem();
for (String type : valueDep.getTypes()) {
addType(source.getIndex(), depth, type);
addTypesFrom(source.getIndex(), depth, valueDep);
}
}
boolean addType(int variable, int degree, String typeName) {
return addTypeImpl(packNodeAndDegree(variable, degree), getTypeByName(typeName));
}
boolean addTypeImpl(int node, int typeId) {
if (overflowTypes[node]) {
return true;
}
IntHashSet nodeTypes = getNodeTypes(node);
if (nodeTypes.add(typeId)) {
nodeChanged[node] = true;
changed = true;
if (nodeTypes.size() > overflowLimit) {
types[node] = null;
overflowTypes[node] = true;
return true;
}
}
return false;
}
void addTypesFrom(int variable, int degree, ValueDependencyInfo dep) {
if (overflowTypes[packNodeAndDegree(variable, degree)]) {
return;
}
if (dep.hasMoreTypesThan(overflowLimit)) {
overflowType(variable, degree);
} else {
String[] types = dep.getTypes();
for (String type : dep.getTypes()) {
if (addType(variable, degree, type)) {
break;
}
}
}
}
void addType(int variable, int degree, String typeName) {
void overflowType(int variable, int degree) {
int entry = nodeMapping[packNodeAndDegree(variable, degree)];
if (getNodeTypes(entry).add(getTypeByName(typeName))) {
nodeChanged[entry] = true;
changed = true;
if (overflowTypes[entry]) {
return;
}
overflowTypes[entry] = true;
nodeChanged[entry] = true;
changed = true;
}
static int extractNode(int nodeAndDegree) {
@ -758,18 +984,6 @@ public class ClassInference {
return (node << 3) | degree;
}
static long packTwoIntegers(int a, int b) {
return ((long) a << 32) | b;
}
static int unpackFirst(long pair) {
return (int) (pair >>> 32);
}
static int unpackSecond(long pair) {
return (int) pair;
}
static final class ValueCast {
final int fromVariable;
final int toVariable;

View File

@ -0,0 +1,175 @@
/*
* Copyright 2019 konsoletyper.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.teavm.model.analysis;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.teavm.model.AccessLevel;
import org.teavm.model.ClassReader;
import org.teavm.model.ClassReaderSource;
import org.teavm.model.ElementModifier;
import org.teavm.model.MethodDescriptor;
import org.teavm.model.MethodReader;
import org.teavm.model.MethodReference;
class SubclassListProvider {
private Map<String, ClassInfo> classes = new HashMap<>();
private Map<MethodDescriptor, List<MethodReference>> methodImplementations = new HashMap<>();
private int limit;
SubclassListProvider(ClassReaderSource classSource, Iterable<? extends String> classNames, int limit) {
this.limit = limit;
for (String className : classNames) {
registerClass(classSource, className);
}
}
private ClassInfo registerClass(ClassReaderSource classSource, String className) {
ClassInfo classInfo = classes.get(className);
if (classInfo == null) {
classInfo = new ClassInfo();
classes.put(className, classInfo);
ClassReader cls = classSource.get(className);
if (cls != null) {
if (!cls.hasModifier(ElementModifier.INTERFACE) && !cls.hasModifier(ElementModifier.ABSTRACT)) {
classInfo.concrete = true;
}
increaseClassCount(classSource, className, new HashSet<>(), classInfo.concrete);
if (cls.getParent() != null) {
ClassInfo parentInfo = registerClass(classSource, cls.getParent());
if (parentInfo.directSubclasses == null) {
parentInfo.directSubclasses = new ArrayList<>();
}
parentInfo.directSubclasses.add(className);
}
for (String itf : cls.getInterfaces()) {
ClassInfo parentInfo = registerClass(classSource, itf);
if (parentInfo.directSubclasses == null) {
parentInfo.directSubclasses = new ArrayList<>();
}
parentInfo.directSubclasses.add(className);
}
for (MethodReader method : cls.getMethods()) {
if (method.hasModifier(ElementModifier.STATIC)
|| method.hasModifier(ElementModifier.ABSTRACT)
|| method.getLevel() == AccessLevel.PRIVATE) {
continue;
}
List<MethodReference> implementations = methodImplementations.get(method.getDescriptor());
if (implementations == null) {
implementations = new ArrayList<>();
methodImplementations.put(method.getDescriptor(), implementations);
}
implementations.add(method.getReference());
}
}
}
return classInfo;
}
private void increaseClassCount(ClassReaderSource classSource, String className, Set<String> visited,
boolean concrete) {
if (!visited.add(className)) {
return;
}
ClassInfo classInfo = registerClass(classSource, className);
if ((!concrete || classInfo.concreteCount > limit) && classInfo.count > limit) {
return;
}
classInfo.count++;
if (concrete) {
classInfo.concreteCount++;
}
ClassReader cls = classSource.get(className);
if (cls != null) {
if (cls.getParent() != null) {
increaseClassCount(classSource, cls.getParent(), visited, concrete);
}
for (String itf : cls.getInterfaces()) {
increaseClassCount(classSource, itf, visited, concrete);
}
}
}
List<? extends String> getSubclasses(String className, boolean includeAbstract) {
ClassInfo classInfo = classes.get(className);
if (classInfo == null) {
return null;
}
if (includeAbstract) {
if (classInfo.count > limit) {
return null;
}
} else {
if (classInfo.concreteCount > limit) {
return null;
}
}
String[] result = new String[includeAbstract ? classInfo.count : classInfo.concreteCount];
collectSubclasses(className, result, 0, new HashSet<>(), includeAbstract);
return Arrays.asList(result);
}
List<? extends MethodReference> getMethods(MethodDescriptor descriptor) {
return methodImplementations.get(descriptor);
}
private int collectSubclasses(String className, String[] consumer, int index, Set<String> visited,
boolean includeAbstract) {
if (!visited.add(className)) {
return index;
}
ClassInfo classInfo = classes.get(className);
if (classInfo == null) {
return index;
}
if (includeAbstract || classInfo.concrete) {
consumer[index++] = className;
}
if (classInfo.directSubclasses != null) {
for (String subclassName : classInfo.directSubclasses) {
index = collectSubclasses(subclassName, consumer, index, visited, includeAbstract);
}
}
return index;
}
static class ClassInfo {
int count;
int concreteCount;
boolean concrete;
List<String> directSubclasses;
}
}

View File

@ -69,6 +69,7 @@ public class Inlining {
private MethodUsageCounter usageCounter;
private Set<MethodReference> methodsUsedOnce = new HashSet<>();
private boolean devirtualization;
private ClassInference classInference;
public Inlining(ClassHierarchy hierarchy, DependencyInfo dependencyInfo, InliningStrategy strategy,
ListableClassReaderSource classes, Predicate<MethodReference> externalMethods,
@ -393,8 +394,10 @@ public class Inlining {
}
private void devirtualize(Program program, MethodReference method, DependencyInfo dependencyInfo) {
ClassInference inference = new ClassInference(dependencyInfo, hierarchy);
inference.infer(program, method);
if (classInference == null) {
classInference = new ClassInference(dependencyInfo, hierarchy, classes.getClassNames(), 30);
}
classInference.infer(program, method);
for (BasicBlock block : program.getBasicBlocks()) {
for (Instruction instruction : block) {
@ -407,11 +410,19 @@ public class Inlining {
}
Set<MethodReference> implementations = new HashSet<>();
for (String className : inference.classesOf(invoke.getInstance().getIndex())) {
MethodReference rawMethod = new MethodReference(className, invoke.getMethod().getDescriptor());
MethodReader resolvedMethod = dependencyInfo.getClassSource().resolveImplementation(rawMethod);
if (resolvedMethod != null) {
implementations.add(resolvedMethod.getReference());
if (classInference.isOverflow(invoke.getInstance().getIndex())) {
List<? extends MethodReference> knownImplementations = classInference.getMethodImplementations(
invoke.getMethod().getDescriptor());
if (knownImplementations != null) {
implementations.addAll(knownImplementations);
}
} else {
for (String className : classInference.classesOf(invoke.getInstance().getIndex())) {
MethodReference rawMethod = new MethodReference(className, invoke.getMethod().getDescriptor());
MethodReader resolvedMethod = dependencyInfo.getClassSource().resolveImplementation(rawMethod);
if (resolvedMethod != null) {
implementations.add(resolvedMethod.getReference());
}
}
}
@ -423,7 +434,7 @@ public class Inlining {
}
}
private class PlanEntry {
static class PlanEntry {
int targetBlock;
Instruction targetInstruction;
MethodReference method;

View File

@ -143,13 +143,13 @@ public class DependencyTest {
MethodHolder method = classSource.get(testMethod.getClassName()).getMethod(testMethod.getDescriptor());
List<Assertion> assertions = collectAssertions(method);
processAssertions(assertions, vm.getDependencyInfo().getMethod(testMethod), vm.getDependencyInfo(),
method.getProgram());
method.getProgram(), vm.getClasses());
}
private void processAssertions(List<Assertion> assertions, MethodDependencyInfo methodDep,
DependencyInfo dependencyInfo, Program program) {
DependencyInfo dependencyInfo, Program program, Iterable<? extends String> classNames) {
ClassInference classInference = new ClassInference(dependencyInfo, new ClassHierarchy(
dependencyInfo.getClassSource()));
dependencyInfo.getClassSource()), classNames, 10);
classInference.infer(program, methodDep.getReference());
for (Assertion assertion : assertions) {
@ -160,10 +160,12 @@ public class DependencyTest {
Arrays.sort(expectedTypes);
Assert.assertArrayEquals("Assertion at " + assertion.location, expectedTypes, actualTypes);
actualTypes = classInference.classesOf(assertion.value);
Arrays.sort(actualTypes);
Assert.assertArrayEquals("Assertion at " + assertion.location + " (class inference)",
expectedTypes, actualTypes);
if (!classInference.isOverflow(assertion.value)) {
Set<String> actualTypeSet = new HashSet<>(Arrays.asList(classInference.classesOf(assertion.value)));
Assert.assertTrue("Assertion at " + assertion.location + " (class inference), "
+ "expected: " + Arrays.toString(expectedTypes) + ", actual: " + actualTypeSet,
actualTypeSet.containsAll(Arrays.asList(expectedTypes)));
}
}
}