From a2a9dbcbe35f2118ea58815113db0c456669fb0a Mon Sep 17 00:00:00 2001 From: Alexey Andreev Date: Thu, 26 Sep 2019 19:40:23 +0300 Subject: [PATCH] Improve performance of inliner in FULL optimization level The former implementation suffered from class inference. The reason was in many nodes having too many possible classes in them. The new implementation does not handle full set of classes in each node. Instead, it introduces concept of 'overflow', i.e. node having more types than the given upper limit. These nodes behave as if there were all possible classes in them, which allows to apply certain optimization for these nodes and omit heavy computations of large type sets. --- .../org/teavm/dependency/DependencyNode.java | 8 + .../java/org/teavm/dependency/TypeSet.java | 36 ++ .../teavm/dependency/ValueDependencyInfo.java | 2 + .../teavm/model/analysis/ClassInference.java | 372 ++++++++++++++---- .../model/analysis/SubclassListProvider.java | 175 ++++++++ .../teavm/model/optimization/Inlining.java | 27 +- .../org/teavm/dependency/DependencyTest.java | 16 +- 7 files changed, 542 insertions(+), 94 deletions(-) create mode 100644 core/src/main/java/org/teavm/model/analysis/SubclassListProvider.java diff --git a/core/src/main/java/org/teavm/dependency/DependencyNode.java b/core/src/main/java/org/teavm/dependency/DependencyNode.java index 30683cc3f..d5e645c31 100644 --- a/core/src/main/java/org/teavm/dependency/DependencyNode.java +++ b/core/src/main/java/org/teavm/dependency/DependencyNode.java @@ -434,6 +434,14 @@ public class DependencyNode implements ValueDependencyInfo { return i == result.length ? result : Arrays.copyOf(result, i); } + @Override + public boolean hasMoreTypesThan(int limit) { + if (typeSet == null) { + return false; + } + return typeSet.hasMoreTypesThan(limit, typeFilter != null ? getFilter()::match : null); + } + DependencyType[] getTypesInternal() { if (typeSet == null) { return new DependencyType[0]; diff --git a/core/src/main/java/org/teavm/dependency/TypeSet.java b/core/src/main/java/org/teavm/dependency/TypeSet.java index ef83804c0..34e9c1bef 100644 --- a/core/src/main/java/org/teavm/dependency/TypeSet.java +++ b/core/src/main/java/org/teavm/dependency/TypeSet.java @@ -23,6 +23,7 @@ import java.util.BitSet; import java.util.LinkedHashSet; import java.util.List; import java.util.Set; +import java.util.function.Predicate; class TypeSet { private static final int SMALL_TYPES_THRESHOLD = 3; @@ -89,6 +90,41 @@ class TypeSet { } } + boolean hasMoreTypesThan(int limit, Predicate filter) { + if (this.types != null) { + if (filter == null) { + return this.types.cardinality() > limit; + } + for (int index = this.types.nextSetBit(0); index >= 0; index = this.types.nextSetBit(index + 1)) { + DependencyType type = dependencyAnalyzer.types.get(index); + if (filter.test(type)) { + if (--limit < 0) { + return true; + } + } + } + return false; + } else if (this.smallTypes != null) { + if (this.smallTypes.length <= limit) { + return false; + } + if (filter == null) { + return true; + } + for (int i = 0; i < smallTypes.length; ++i) { + DependencyType type = dependencyAnalyzer.types.get(smallTypes[i]); + if (filter.test(type)) { + if (--limit < 0) { + return true; + } + } + } + return false; + } else { + return false; + } + } + DependencyType[] getTypesForNode(DependencyNode sourceNode, DependencyNode targetNode, DependencyTypeFilter filter) { int j = 0; diff --git a/core/src/main/java/org/teavm/dependency/ValueDependencyInfo.java b/core/src/main/java/org/teavm/dependency/ValueDependencyInfo.java index 6dfdab0a8..5aa96ecd3 100644 --- a/core/src/main/java/org/teavm/dependency/ValueDependencyInfo.java +++ b/core/src/main/java/org/teavm/dependency/ValueDependencyInfo.java @@ -20,6 +20,8 @@ public interface ValueDependencyInfo { boolean hasType(String type); + boolean hasMoreTypesThan(int limit); + boolean hasArrayType(); ValueDependencyInfo getArrayItem(); diff --git a/core/src/main/java/org/teavm/model/analysis/ClassInference.java b/core/src/main/java/org/teavm/model/analysis/ClassInference.java index 8d250dded..6f02869b2 100644 --- a/core/src/main/java/org/teavm/model/analysis/ClassInference.java +++ b/core/src/main/java/org/teavm/model/analysis/ClassInference.java @@ -39,6 +39,7 @@ import org.teavm.model.ClassHierarchy; import org.teavm.model.ClassReaderSource; import org.teavm.model.Incoming; import org.teavm.model.Instruction; +import org.teavm.model.MethodDescriptor; import org.teavm.model.MethodReader; import org.teavm.model.MethodReference; import org.teavm.model.Phi; @@ -68,6 +69,7 @@ import org.teavm.model.instructions.UnwrapArrayInstruction; public class ClassInference { private DependencyInfo dependencyInfo; private ClassHierarchy hierarchy; + private SubclassListProvider subclassListProvider; private Graph assignmentGraph; private Graph cloneGraph; private Graph arrayGraph; @@ -76,11 +78,13 @@ public class ClassInference { private ValueCast[] casts; private int[] exceptions; private VirtualCallSite[] virtualCallSites; + private int overflowLimit; private int[] propagationPath; private int[] nodeMapping; private IntHashSet[] types; + private boolean[] overflowTypes; private ObjectIntMap typeMap = new ObjectIntHashMap<>(); private List typeList = new ArrayList<>(); @@ -90,9 +94,12 @@ public class ClassInference { private static final int MAX_DEGREE = 3; - public ClassInference(DependencyInfo dependencyInfo, ClassHierarchy hierarchy) { + public ClassInference(DependencyInfo dependencyInfo, ClassHierarchy hierarchy, + Iterable classNames, int overflowLimit) { this.dependencyInfo = dependencyInfo; this.hierarchy = hierarchy; + this.overflowLimit = overflowLimit; + subclassListProvider = new SubclassListProvider(dependencyInfo.getClassSource(), classNames, overflowLimit); } public void infer(Program program, MethodReference methodReference) { @@ -112,6 +119,7 @@ public class ClassInference { */ types = new IntHashSet[program.variableCount() << 3]; + overflowTypes = new boolean[program.variableCount() << 3]; nodeChanged = new boolean[types.length]; formerNodeChanged = new boolean[nodeChanged.length]; nodeMapping = new int[types.length]; @@ -129,9 +137,7 @@ public class ClassInference { if (paramDep != null) { int degree = 0; while (degree <= MAX_DEGREE) { - for (String paramType : paramDep.getTypes()) { - addType(i, degree, paramType); - } + addTypesFrom(i, degree, paramDep); if (!paramDep.hasArrayType()) { break; } @@ -166,6 +172,14 @@ public class ClassInference { nodeChanged = null; } + public boolean isOverflow(int variableIndex) { + return overflowTypes[nodeMapping[packNodeAndDegree(variableIndex, 0)]]; + } + + public List getMethodImplementations(MethodDescriptor descriptor) { + return subclassListProvider.getMethods(descriptor); + } + public String[] classesOf(int variableIndex) { IntHashSet typeSet = types[nodeMapping[packNodeAndDegree(variableIndex, 0)]]; if (typeSet == null) { @@ -210,7 +224,7 @@ public class ClassInference { IntStack stack = new IntStack(); for (int i = 0; i < types.length; ++i) { - if (types[i] != null) { + if (types[i] != null || overflowTypes[i]) { stack.push(i); } } @@ -286,8 +300,10 @@ public class ClassInference { boolean[] nodeChangedBackup = nodeChanged.clone(); IntHashSet[] typesBackup = types.clone(); + boolean[] overflowTypesBackup = overflowTypes.clone(); Arrays.fill(nodeChanged, false); Arrays.fill(types, null); + Arrays.fill(overflowTypes, false); GraphBuilder graphBuilder = new GraphBuilder(graph.size()); for (int i = 0; i < graph.size(); ++i) { @@ -300,8 +316,17 @@ public class ClassInference { } int node = nodeMapping[i]; - if (typesBackup[i] != null) { - getNodeTypes(node).addAll(typesBackup[i]); + if (overflowTypesBackup[i]) { + overflowTypes[i] = true; + types[node] = null; + } else if (typesBackup[i] != null && !overflowTypes[i]) { + IntHashSet nodeTypes = getNodeTypes(node); + if (nodeTypes.addAll(typesBackup[i]) > 0) { + if (nodeTypes.size() > overflowLimit) { + types[node] = null; + overflowTypes[node] = true; + } + } } if (nodeChangedBackup[i]) { @@ -323,7 +348,7 @@ public class ClassInference { IntStack stack = new IntStack(); for (int i = 0; i < graph.size(); ++i) { - if (graph.incomingEdgesCount(i) == 0 && types[i] != null) { + if (graph.incomingEdgesCount(i) == 0 && (overflowTypes[i] || types[i] != null)) { stack.push(i); } } @@ -381,6 +406,10 @@ public class ClassInference { private void propagateAlongDAG() { for (int i = propagationPath.length - 1; i >= 0; --i) { int node = propagationPath[i]; + if (overflowTypes[node]) { + continue; + } + boolean predecessorsChanged = false; for (int predecessor : graph.incomingEdges(node)) { if (formerNodeChanged[predecessor] || nodeChanged[predecessor]) { @@ -395,9 +424,25 @@ public class ClassInference { IntHashSet nodeTypes = getNodeTypes(node); for (int predecessor : graph.incomingEdges(node)) { if (formerNodeChanged[predecessor] || nodeChanged[predecessor]) { - if (nodeTypes.addAll(types[predecessor]) > 0) { - nodeChanged[node] = true; - changed = true; + if (overflowTypes[predecessor]) { + if (!overflowTypes[node]) { + overflowTypes[node] = true; + types[node] = null; + changed = true; + nodeChanged[node] = true; + break; + } + } else if (types[predecessor] != null && nodeTypes.addAll(types[predecessor]) > 0) { + if (nodeTypes.size() > overflowLimit) { + types[node] = null; + overflowTypes[node] = true; + nodeChanged[node] = true; + changed = true; + break; + } else { + nodeChanged[node] = true; + changed = true; + } } } } @@ -406,34 +451,92 @@ public class ClassInference { private void propagateAlongCasts() { for (ValueCast cast : casts) { + int toNode = nodeMapping[packNodeAndDegree(cast.toVariable, 0)]; + if (overflowTypes[toNode]) { + continue; + } + int fromNode = nodeMapping[packNodeAndDegree(cast.fromVariable, 0)]; if (!formerNodeChanged[fromNode] && !nodeChanged[fromNode]) { continue; } - int toNode = nodeMapping[packNodeAndDegree(cast.toVariable, 0)]; IntHashSet targetTypes = getNodeTypes(toNode); - for (IntCursor cursor : types[fromNode]) { - if (targetTypes.contains(cursor.value)) { - continue; + if (overflowTypes[fromNode]) { + int degree = 0; + ValueType targetType = cast.targetType; + while (targetType instanceof ValueType.Array) { + targetType = ((ValueType.Array) targetType).getItemType(); + ++degree; } - String className = typeList.get(cursor.value); - - ValueType type; - if (className.startsWith("[")) { - type = ValueType.parseIfPossible(className); - if (type == null) { - type = ValueType.arrayOf(ValueType.object("java.lang.Object")); + if (targetType instanceof ValueType.Object) { + String targetClassName = ((ValueType.Object) targetType).getClassName(); + List subclasses = subclassListProvider.getSubclasses( + targetClassName, degree > 0); + if (subclasses == null) { + types[toNode] = null; + overflowTypes[toNode] = true; + nodeChanged[toNode] = true; + changed = true; + } else { + for (String subclass : subclasses) { + if (degree > 0) { + StringBuilder sb = new StringBuilder(); + for (int i = 0; i < degree; ++i) { + sb.append('['); + } + subclass = sb.append('L').append(subclass.replace('.', '/')).append(';').toString(); + } + int typeId = getTypeByName(subclass); + if (targetTypes.add(typeId)) { + changed = true; + nodeChanged[toNode] = true; + if (targetTypes.size() > overflowLimit) { + overflowTypes[toNode] = true; + types[toNode] = null; + break; + } + } + } } } else { - type = ValueType.object(className); + int typeId = getTypeByName(targetType.toString()); + if (targetTypes.add(typeId)) { + changed = true; + nodeChanged[toNode] = true; + if (targetTypes.size() > overflowLimit) { + overflowTypes[toNode] = true; + types[toNode] = null; + } + } } + } else { + for (IntCursor cursor : types[fromNode]) { + if (targetTypes.contains(cursor.value)) { + continue; + } + String className = typeList.get(cursor.value); - if (hierarchy.isSuperType(cast.targetType, type, false)) { - changed = true; - nodeChanged[toNode] = true; - targetTypes.add(cursor.value); + ValueType type; + if (className.startsWith("[")) { + type = ValueType.parseIfPossible(className); + if (type == null) { + type = ValueType.arrayOf(ValueType.object("java.lang.Object")); + } + } else { + type = ValueType.object(className); + } + + if (hierarchy.isSuperType(cast.targetType, type, false) && targetTypes.add(cursor.value)) { + changed = true; + nodeChanged[toNode] = true; + if (targetTypes.size() > overflowLimit) { + overflowTypes[toNode] = true; + types[toNode] = null; + break; + } + } } } } @@ -443,17 +546,49 @@ public class ClassInference { ClassReaderSource classSource = dependencyInfo.getClassSource(); for (VirtualCallSite callSite : virtualCallSites) { + if (callSite.receiverOverflow) { + continue; + } + int instanceNode = nodeMapping[packNodeAndDegree(callSite.instance, 0)]; if (!formerNodeChanged[instanceNode] && !nodeChanged[instanceNode]) { continue; } - for (IntCursor type : types[instanceNode]) { - if (!callSite.knownClasses.add(type.value)) { + List receiverTypes; + if (overflowTypes[instanceNode]) { + callSite.receiverOverflow = true; + List subclasses = subclassListProvider.getSubclasses( + callSite.method.getClassName(), true); + if (subclasses != null) { + receiverTypes = subclasses; + } else { + List implementations = subclassListProvider.getMethods( + callSite.method.getDescriptor()); + if (implementations != null) { + for (MethodReference methodReference : implementations) { + mountVirtualMethod(program, callSite, methodReference); + } + } + continue; + } + } else { + List instanceNodeTypes = new ArrayList<>(); + for (IntCursor type : types[instanceNode]) { + if (callSite.knownClasses.contains(type.value)) { + continue; + } + instanceNodeTypes.add(typeList.get(type.value)); + } + receiverTypes = instanceNodeTypes; + } + + for (String className : receiverTypes) { + int typeId = getTypeByName(className); + if (!callSite.knownClasses.add(typeId)) { continue; } - String className = typeList.get(type.value); MethodReference rawMethod = new MethodReference(className, callSite.method.getDescriptor()); MethodReader resolvedMethod = classSource.resolveImplementation(rawMethod); @@ -462,28 +597,32 @@ public class ClassInference { } MethodReference resolvedMethodRef = resolvedMethod.getReference(); - if (!callSite.resolvedMethods.add(resolvedMethodRef)) { - continue; - } - - MethodDependencyInfo methodDep = dependencyInfo.getMethod(resolvedMethodRef); - if (methodDep == null) { - continue; - } - if (callSite.receiver >= 0) { - readValue(methodDep.getResult(), program.variableAt(callSite.receiver)); - } - for (int i = 0; i < callSite.arguments.length; ++i) { - writeValue(methodDep.getVariable(i + 1), program.variableAt(callSite.arguments[i])); - } - - for (String thrownTypeName : methodDep.getThrown().getTypes()) { - propagateException(thrownTypeName, program.basicBlockAt(callSite.block)); - } + mountVirtualMethod(program, callSite, resolvedMethodRef); } } } + private void mountVirtualMethod(Program program, VirtualCallSite callSite, MethodReference methodReference) { + if (!callSite.resolvedMethods.add(methodReference)) { + return; + } + + MethodDependencyInfo methodDep = dependencyInfo.getMethod(methodReference); + if (methodDep == null) { + return; + } + if (callSite.receiver >= 0) { + readValue(methodDep.getResult(), program.variableAt(callSite.receiver)); + } + for (int i = 0; i < callSite.arguments.length; ++i) { + writeValue(methodDep.getVariable(i + 1), program.variableAt(callSite.arguments[i])); + } + + for (String thrownTypeName : methodDep.getThrown().getTypes()) { + propagateException(thrownTypeName, program.basicBlockAt(callSite.block)); + } + } + private void propagateAlongExceptions(Program program) { for (int i = 0; i < exceptions.length; i += 2) { int variable = nodeMapping[packNodeAndDegree(exceptions[i], 0)]; @@ -492,9 +631,13 @@ public class ClassInference { } BasicBlock block = program.basicBlockAt(exceptions[i + 1]); - for (IntCursor type : types[variable]) { - String typeName = typeList.get(type.value); - propagateException(typeName, block); + if (overflowTypes[variable]) { + propagateOverflowException(block); + } else { + for (IntCursor type : types[variable]) { + String typeName = typeList.get(type.value); + propagateException(typeName, block); + } } } } @@ -508,10 +651,17 @@ public class ClassInference { } int exceptionNode = packNodeAndDegree(tryCatch.getHandler().getExceptionVariable().getIndex(), 0); exceptionNode = nodeMapping[exceptionNode]; - int thrownType = getTypeByName(thrownTypeName); - if (getNodeTypes(exceptionNode).add(thrownType)) { - nodeChanged[exceptionNode] = true; - changed = true; + if (!overflowTypes[exceptionNode]) { + int thrownType = getTypeByName(thrownTypeName); + IntHashSet nodeTypes = getNodeTypes(exceptionNode); + if (nodeTypes.add(thrownType)) { + nodeChanged[exceptionNode] = true; + changed = true; + if (nodeTypes.size() > overflowLimit) { + types[exceptionNode] = null; + overflowTypes[exceptionNode] = true; + } + } } break; @@ -519,6 +669,45 @@ public class ClassInference { } } + private void propagateOverflowException(BasicBlock block) { + for (TryCatchBlock tryCatch : block.getTryCatchBlocks()) { + Variable exceptionVar = tryCatch.getHandler().getExceptionVariable(); + if (exceptionVar == null) { + continue; + } + + int exceptionNode = packNodeAndDegree(exceptionVar.getIndex(), 0); + exceptionNode = nodeMapping[exceptionNode]; + if (overflowTypes[exceptionNode]) { + continue; + } + + String expectedType = tryCatch.getExceptionType(); + List thrownTypes = subclassListProvider.getSubclasses(expectedType, false); + if (thrownTypes == null) { + if (!overflowTypes[exceptionNode]) { + overflowTypes[exceptionNode] = true; + changed = true; + nodeChanged[exceptionNode] = true; + } + } else { + IntHashSet nodeTypes = getNodeTypes(exceptionNode); + for (String thrownTypeName : thrownTypes) { + int thrownType = getTypeByName(thrownTypeName); + if (nodeTypes.add(thrownType)) { + nodeChanged[exceptionNode] = true; + changed = true; + if (nodeTypes.size() > overflowLimit) { + types[exceptionNode] = null; + overflowTypes[exceptionNode] = true; + break; + } + } + } + } + } + } + IntHashSet getNodeTypes(int node) { IntHashSet result = types[node]; if (result == null) { @@ -651,9 +840,7 @@ public class ClassInference { int resultDegree = 0; while (resultDependency.hasArrayType() && resultDegree <= MAX_DEGREE) { resultDependency = resultDependency.getArrayItem(); - for (String paramType : resultDependency.getTypes()) { - addType(insn.getValueToReturn().getIndex(), resultDegree, paramType); - } + addTypesFrom(insn.getValueToReturn().getIndex(), resultDegree, resultDependency); ++resultDegree; } } @@ -706,6 +893,7 @@ public class ClassInference { static class VirtualCallSite { int instance; + boolean receiverOverflow; IntSet knownClasses = new IntHashSet(); Set resolvedMethods = new HashSet<>(); MethodReference method; @@ -718,9 +906,7 @@ public class ClassInference { int depth = 0; boolean hasArrayType; do { - for (String type : valueDep.getTypes()) { - addType(receiver.getIndex(), depth, type); - } + addTypesFrom(receiver.getIndex(), depth, valueDep); depth++; hasArrayType = valueDep.hasArrayType(); valueDep = valueDep.getArrayItem(); @@ -732,18 +918,58 @@ public class ClassInference { while (valueDep.hasArrayType() && depth < MAX_DEGREE) { depth++; valueDep = valueDep.getArrayItem(); - for (String type : valueDep.getTypes()) { - addType(source.getIndex(), depth, type); + addTypesFrom(source.getIndex(), depth, valueDep); + } + } + + boolean addType(int variable, int degree, String typeName) { + return addTypeImpl(packNodeAndDegree(variable, degree), getTypeByName(typeName)); + } + + boolean addTypeImpl(int node, int typeId) { + if (overflowTypes[node]) { + return true; + } + + IntHashSet nodeTypes = getNodeTypes(node); + if (nodeTypes.add(typeId)) { + nodeChanged[node] = true; + changed = true; + if (nodeTypes.size() > overflowLimit) { + types[node] = null; + overflowTypes[node] = true; + return true; + } + } + + return false; + } + + void addTypesFrom(int variable, int degree, ValueDependencyInfo dep) { + if (overflowTypes[packNodeAndDegree(variable, degree)]) { + return; + } + if (dep.hasMoreTypesThan(overflowLimit)) { + overflowType(variable, degree); + } else { + String[] types = dep.getTypes(); + for (String type : dep.getTypes()) { + if (addType(variable, degree, type)) { + break; + } } } } - void addType(int variable, int degree, String typeName) { + void overflowType(int variable, int degree) { int entry = nodeMapping[packNodeAndDegree(variable, degree)]; - if (getNodeTypes(entry).add(getTypeByName(typeName))) { - nodeChanged[entry] = true; - changed = true; + if (overflowTypes[entry]) { + return; } + + overflowTypes[entry] = true; + nodeChanged[entry] = true; + changed = true; } static int extractNode(int nodeAndDegree) { @@ -758,18 +984,6 @@ public class ClassInference { return (node << 3) | degree; } - static long packTwoIntegers(int a, int b) { - return ((long) a << 32) | b; - } - - static int unpackFirst(long pair) { - return (int) (pair >>> 32); - } - - static int unpackSecond(long pair) { - return (int) pair; - } - static final class ValueCast { final int fromVariable; final int toVariable; diff --git a/core/src/main/java/org/teavm/model/analysis/SubclassListProvider.java b/core/src/main/java/org/teavm/model/analysis/SubclassListProvider.java new file mode 100644 index 000000000..3b6772cea --- /dev/null +++ b/core/src/main/java/org/teavm/model/analysis/SubclassListProvider.java @@ -0,0 +1,175 @@ +/* + * Copyright 2019 konsoletyper. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.teavm.model.analysis; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import org.teavm.model.AccessLevel; +import org.teavm.model.ClassReader; +import org.teavm.model.ClassReaderSource; +import org.teavm.model.ElementModifier; +import org.teavm.model.MethodDescriptor; +import org.teavm.model.MethodReader; +import org.teavm.model.MethodReference; + +class SubclassListProvider { + private Map classes = new HashMap<>(); + private Map> methodImplementations = new HashMap<>(); + private int limit; + + SubclassListProvider(ClassReaderSource classSource, Iterable classNames, int limit) { + this.limit = limit; + + for (String className : classNames) { + registerClass(classSource, className); + } + } + + private ClassInfo registerClass(ClassReaderSource classSource, String className) { + ClassInfo classInfo = classes.get(className); + if (classInfo == null) { + classInfo = new ClassInfo(); + classes.put(className, classInfo); + + ClassReader cls = classSource.get(className); + if (cls != null) { + if (!cls.hasModifier(ElementModifier.INTERFACE) && !cls.hasModifier(ElementModifier.ABSTRACT)) { + classInfo.concrete = true; + } + increaseClassCount(classSource, className, new HashSet<>(), classInfo.concrete); + + if (cls.getParent() != null) { + ClassInfo parentInfo = registerClass(classSource, cls.getParent()); + if (parentInfo.directSubclasses == null) { + parentInfo.directSubclasses = new ArrayList<>(); + } + parentInfo.directSubclasses.add(className); + } + + for (String itf : cls.getInterfaces()) { + ClassInfo parentInfo = registerClass(classSource, itf); + if (parentInfo.directSubclasses == null) { + parentInfo.directSubclasses = new ArrayList<>(); + } + parentInfo.directSubclasses.add(className); + } + + for (MethodReader method : cls.getMethods()) { + if (method.hasModifier(ElementModifier.STATIC) + || method.hasModifier(ElementModifier.ABSTRACT) + || method.getLevel() == AccessLevel.PRIVATE) { + continue; + } + List implementations = methodImplementations.get(method.getDescriptor()); + if (implementations == null) { + implementations = new ArrayList<>(); + methodImplementations.put(method.getDescriptor(), implementations); + } + implementations.add(method.getReference()); + } + } + } + + return classInfo; + } + + private void increaseClassCount(ClassReaderSource classSource, String className, Set visited, + boolean concrete) { + if (!visited.add(className)) { + return; + } + + ClassInfo classInfo = registerClass(classSource, className); + if ((!concrete || classInfo.concreteCount > limit) && classInfo.count > limit) { + return; + } + classInfo.count++; + if (concrete) { + classInfo.concreteCount++; + } + + ClassReader cls = classSource.get(className); + if (cls != null) { + if (cls.getParent() != null) { + increaseClassCount(classSource, cls.getParent(), visited, concrete); + } + for (String itf : cls.getInterfaces()) { + increaseClassCount(classSource, itf, visited, concrete); + } + } + } + + List getSubclasses(String className, boolean includeAbstract) { + ClassInfo classInfo = classes.get(className); + if (classInfo == null) { + return null; + } + + if (includeAbstract) { + if (classInfo.count > limit) { + return null; + } + } else { + if (classInfo.concreteCount > limit) { + return null; + } + } + + + String[] result = new String[includeAbstract ? classInfo.count : classInfo.concreteCount]; + collectSubclasses(className, result, 0, new HashSet<>(), includeAbstract); + return Arrays.asList(result); + } + + List getMethods(MethodDescriptor descriptor) { + return methodImplementations.get(descriptor); + } + + private int collectSubclasses(String className, String[] consumer, int index, Set visited, + boolean includeAbstract) { + if (!visited.add(className)) { + return index; + } + + ClassInfo classInfo = classes.get(className); + if (classInfo == null) { + return index; + } + + if (includeAbstract || classInfo.concrete) { + consumer[index++] = className; + } + if (classInfo.directSubclasses != null) { + for (String subclassName : classInfo.directSubclasses) { + index = collectSubclasses(subclassName, consumer, index, visited, includeAbstract); + } + } + + return index; + } + + static class ClassInfo { + int count; + int concreteCount; + boolean concrete; + List directSubclasses; + } +} diff --git a/core/src/main/java/org/teavm/model/optimization/Inlining.java b/core/src/main/java/org/teavm/model/optimization/Inlining.java index f284571a5..f696ede36 100644 --- a/core/src/main/java/org/teavm/model/optimization/Inlining.java +++ b/core/src/main/java/org/teavm/model/optimization/Inlining.java @@ -69,6 +69,7 @@ public class Inlining { private MethodUsageCounter usageCounter; private Set methodsUsedOnce = new HashSet<>(); private boolean devirtualization; + private ClassInference classInference; public Inlining(ClassHierarchy hierarchy, DependencyInfo dependencyInfo, InliningStrategy strategy, ListableClassReaderSource classes, Predicate externalMethods, @@ -393,8 +394,10 @@ public class Inlining { } private void devirtualize(Program program, MethodReference method, DependencyInfo dependencyInfo) { - ClassInference inference = new ClassInference(dependencyInfo, hierarchy); - inference.infer(program, method); + if (classInference == null) { + classInference = new ClassInference(dependencyInfo, hierarchy, classes.getClassNames(), 30); + } + classInference.infer(program, method); for (BasicBlock block : program.getBasicBlocks()) { for (Instruction instruction : block) { @@ -407,11 +410,19 @@ public class Inlining { } Set implementations = new HashSet<>(); - for (String className : inference.classesOf(invoke.getInstance().getIndex())) { - MethodReference rawMethod = new MethodReference(className, invoke.getMethod().getDescriptor()); - MethodReader resolvedMethod = dependencyInfo.getClassSource().resolveImplementation(rawMethod); - if (resolvedMethod != null) { - implementations.add(resolvedMethod.getReference()); + if (classInference.isOverflow(invoke.getInstance().getIndex())) { + List knownImplementations = classInference.getMethodImplementations( + invoke.getMethod().getDescriptor()); + if (knownImplementations != null) { + implementations.addAll(knownImplementations); + } + } else { + for (String className : classInference.classesOf(invoke.getInstance().getIndex())) { + MethodReference rawMethod = new MethodReference(className, invoke.getMethod().getDescriptor()); + MethodReader resolvedMethod = dependencyInfo.getClassSource().resolveImplementation(rawMethod); + if (resolvedMethod != null) { + implementations.add(resolvedMethod.getReference()); + } } } @@ -423,7 +434,7 @@ public class Inlining { } } - private class PlanEntry { + static class PlanEntry { int targetBlock; Instruction targetInstruction; MethodReference method; diff --git a/tests/src/test/java/org/teavm/dependency/DependencyTest.java b/tests/src/test/java/org/teavm/dependency/DependencyTest.java index 8419e8883..611266198 100644 --- a/tests/src/test/java/org/teavm/dependency/DependencyTest.java +++ b/tests/src/test/java/org/teavm/dependency/DependencyTest.java @@ -143,13 +143,13 @@ public class DependencyTest { MethodHolder method = classSource.get(testMethod.getClassName()).getMethod(testMethod.getDescriptor()); List assertions = collectAssertions(method); processAssertions(assertions, vm.getDependencyInfo().getMethod(testMethod), vm.getDependencyInfo(), - method.getProgram()); + method.getProgram(), vm.getClasses()); } private void processAssertions(List assertions, MethodDependencyInfo methodDep, - DependencyInfo dependencyInfo, Program program) { + DependencyInfo dependencyInfo, Program program, Iterable classNames) { ClassInference classInference = new ClassInference(dependencyInfo, new ClassHierarchy( - dependencyInfo.getClassSource())); + dependencyInfo.getClassSource()), classNames, 10); classInference.infer(program, methodDep.getReference()); for (Assertion assertion : assertions) { @@ -160,10 +160,12 @@ public class DependencyTest { Arrays.sort(expectedTypes); Assert.assertArrayEquals("Assertion at " + assertion.location, expectedTypes, actualTypes); - actualTypes = classInference.classesOf(assertion.value); - Arrays.sort(actualTypes); - Assert.assertArrayEquals("Assertion at " + assertion.location + " (class inference)", - expectedTypes, actualTypes); + if (!classInference.isOverflow(assertion.value)) { + Set actualTypeSet = new HashSet<>(Arrays.asList(classInference.classesOf(assertion.value))); + Assert.assertTrue("Assertion at " + assertion.location + " (class inference), " + + "expected: " + Arrays.toString(expectedTypes) + ", actual: " + actualTypeSet, + actualTypeSet.containsAll(Arrays.asList(expectedTypes))); + } } }