diff --git a/core/src/main/java/org/teavm/dependency/DependencyNode.java b/core/src/main/java/org/teavm/dependency/DependencyNode.java index 30683cc3f..d5e645c31 100644 --- a/core/src/main/java/org/teavm/dependency/DependencyNode.java +++ b/core/src/main/java/org/teavm/dependency/DependencyNode.java @@ -434,6 +434,14 @@ public class DependencyNode implements ValueDependencyInfo { return i == result.length ? result : Arrays.copyOf(result, i); } + @Override + public boolean hasMoreTypesThan(int limit) { + if (typeSet == null) { + return false; + } + return typeSet.hasMoreTypesThan(limit, typeFilter != null ? getFilter()::match : null); + } + DependencyType[] getTypesInternal() { if (typeSet == null) { return new DependencyType[0]; diff --git a/core/src/main/java/org/teavm/dependency/TypeSet.java b/core/src/main/java/org/teavm/dependency/TypeSet.java index ef83804c0..34e9c1bef 100644 --- a/core/src/main/java/org/teavm/dependency/TypeSet.java +++ b/core/src/main/java/org/teavm/dependency/TypeSet.java @@ -23,6 +23,7 @@ import java.util.BitSet; import java.util.LinkedHashSet; import java.util.List; import java.util.Set; +import java.util.function.Predicate; class TypeSet { private static final int SMALL_TYPES_THRESHOLD = 3; @@ -89,6 +90,41 @@ class TypeSet { } } + boolean hasMoreTypesThan(int limit, Predicate filter) { + if (this.types != null) { + if (filter == null) { + return this.types.cardinality() > limit; + } + for (int index = this.types.nextSetBit(0); index >= 0; index = this.types.nextSetBit(index + 1)) { + DependencyType type = dependencyAnalyzer.types.get(index); + if (filter.test(type)) { + if (--limit < 0) { + return true; + } + } + } + return false; + } else if (this.smallTypes != null) { + if (this.smallTypes.length <= limit) { + return false; + } + if (filter == null) { + return true; + } + for (int i = 0; i < smallTypes.length; ++i) { + DependencyType type = dependencyAnalyzer.types.get(smallTypes[i]); + if (filter.test(type)) { + if (--limit < 0) { + return true; + } + } + } + return false; + } else { + return false; + } + } + DependencyType[] getTypesForNode(DependencyNode sourceNode, DependencyNode targetNode, DependencyTypeFilter filter) { int j = 0; diff --git a/core/src/main/java/org/teavm/dependency/ValueDependencyInfo.java b/core/src/main/java/org/teavm/dependency/ValueDependencyInfo.java index 6dfdab0a8..5aa96ecd3 100644 --- a/core/src/main/java/org/teavm/dependency/ValueDependencyInfo.java +++ b/core/src/main/java/org/teavm/dependency/ValueDependencyInfo.java @@ -20,6 +20,8 @@ public interface ValueDependencyInfo { boolean hasType(String type); + boolean hasMoreTypesThan(int limit); + boolean hasArrayType(); ValueDependencyInfo getArrayItem(); diff --git a/core/src/main/java/org/teavm/model/analysis/ClassInference.java b/core/src/main/java/org/teavm/model/analysis/ClassInference.java index 8d250dded..6f02869b2 100644 --- a/core/src/main/java/org/teavm/model/analysis/ClassInference.java +++ b/core/src/main/java/org/teavm/model/analysis/ClassInference.java @@ -39,6 +39,7 @@ import org.teavm.model.ClassHierarchy; import org.teavm.model.ClassReaderSource; import org.teavm.model.Incoming; import org.teavm.model.Instruction; +import org.teavm.model.MethodDescriptor; import org.teavm.model.MethodReader; import org.teavm.model.MethodReference; import org.teavm.model.Phi; @@ -68,6 +69,7 @@ import org.teavm.model.instructions.UnwrapArrayInstruction; public class ClassInference { private DependencyInfo dependencyInfo; private ClassHierarchy hierarchy; + private SubclassListProvider subclassListProvider; private Graph assignmentGraph; private Graph cloneGraph; private Graph arrayGraph; @@ -76,11 +78,13 @@ public class ClassInference { private ValueCast[] casts; private int[] exceptions; private VirtualCallSite[] virtualCallSites; + private int overflowLimit; private int[] propagationPath; private int[] nodeMapping; private IntHashSet[] types; + private boolean[] overflowTypes; private ObjectIntMap typeMap = new ObjectIntHashMap<>(); private List typeList = new ArrayList<>(); @@ -90,9 +94,12 @@ public class ClassInference { private static final int MAX_DEGREE = 3; - public ClassInference(DependencyInfo dependencyInfo, ClassHierarchy hierarchy) { + public ClassInference(DependencyInfo dependencyInfo, ClassHierarchy hierarchy, + Iterable classNames, int overflowLimit) { this.dependencyInfo = dependencyInfo; this.hierarchy = hierarchy; + this.overflowLimit = overflowLimit; + subclassListProvider = new SubclassListProvider(dependencyInfo.getClassSource(), classNames, overflowLimit); } public void infer(Program program, MethodReference methodReference) { @@ -112,6 +119,7 @@ public class ClassInference { */ types = new IntHashSet[program.variableCount() << 3]; + overflowTypes = new boolean[program.variableCount() << 3]; nodeChanged = new boolean[types.length]; formerNodeChanged = new boolean[nodeChanged.length]; nodeMapping = new int[types.length]; @@ -129,9 +137,7 @@ public class ClassInference { if (paramDep != null) { int degree = 0; while (degree <= MAX_DEGREE) { - for (String paramType : paramDep.getTypes()) { - addType(i, degree, paramType); - } + addTypesFrom(i, degree, paramDep); if (!paramDep.hasArrayType()) { break; } @@ -166,6 +172,14 @@ public class ClassInference { nodeChanged = null; } + public boolean isOverflow(int variableIndex) { + return overflowTypes[nodeMapping[packNodeAndDegree(variableIndex, 0)]]; + } + + public List getMethodImplementations(MethodDescriptor descriptor) { + return subclassListProvider.getMethods(descriptor); + } + public String[] classesOf(int variableIndex) { IntHashSet typeSet = types[nodeMapping[packNodeAndDegree(variableIndex, 0)]]; if (typeSet == null) { @@ -210,7 +224,7 @@ public class ClassInference { IntStack stack = new IntStack(); for (int i = 0; i < types.length; ++i) { - if (types[i] != null) { + if (types[i] != null || overflowTypes[i]) { stack.push(i); } } @@ -286,8 +300,10 @@ public class ClassInference { boolean[] nodeChangedBackup = nodeChanged.clone(); IntHashSet[] typesBackup = types.clone(); + boolean[] overflowTypesBackup = overflowTypes.clone(); Arrays.fill(nodeChanged, false); Arrays.fill(types, null); + Arrays.fill(overflowTypes, false); GraphBuilder graphBuilder = new GraphBuilder(graph.size()); for (int i = 0; i < graph.size(); ++i) { @@ -300,8 +316,17 @@ public class ClassInference { } int node = nodeMapping[i]; - if (typesBackup[i] != null) { - getNodeTypes(node).addAll(typesBackup[i]); + if (overflowTypesBackup[i]) { + overflowTypes[i] = true; + types[node] = null; + } else if (typesBackup[i] != null && !overflowTypes[i]) { + IntHashSet nodeTypes = getNodeTypes(node); + if (nodeTypes.addAll(typesBackup[i]) > 0) { + if (nodeTypes.size() > overflowLimit) { + types[node] = null; + overflowTypes[node] = true; + } + } } if (nodeChangedBackup[i]) { @@ -323,7 +348,7 @@ public class ClassInference { IntStack stack = new IntStack(); for (int i = 0; i < graph.size(); ++i) { - if (graph.incomingEdgesCount(i) == 0 && types[i] != null) { + if (graph.incomingEdgesCount(i) == 0 && (overflowTypes[i] || types[i] != null)) { stack.push(i); } } @@ -381,6 +406,10 @@ public class ClassInference { private void propagateAlongDAG() { for (int i = propagationPath.length - 1; i >= 0; --i) { int node = propagationPath[i]; + if (overflowTypes[node]) { + continue; + } + boolean predecessorsChanged = false; for (int predecessor : graph.incomingEdges(node)) { if (formerNodeChanged[predecessor] || nodeChanged[predecessor]) { @@ -395,9 +424,25 @@ public class ClassInference { IntHashSet nodeTypes = getNodeTypes(node); for (int predecessor : graph.incomingEdges(node)) { if (formerNodeChanged[predecessor] || nodeChanged[predecessor]) { - if (nodeTypes.addAll(types[predecessor]) > 0) { - nodeChanged[node] = true; - changed = true; + if (overflowTypes[predecessor]) { + if (!overflowTypes[node]) { + overflowTypes[node] = true; + types[node] = null; + changed = true; + nodeChanged[node] = true; + break; + } + } else if (types[predecessor] != null && nodeTypes.addAll(types[predecessor]) > 0) { + if (nodeTypes.size() > overflowLimit) { + types[node] = null; + overflowTypes[node] = true; + nodeChanged[node] = true; + changed = true; + break; + } else { + nodeChanged[node] = true; + changed = true; + } } } } @@ -406,34 +451,92 @@ public class ClassInference { private void propagateAlongCasts() { for (ValueCast cast : casts) { + int toNode = nodeMapping[packNodeAndDegree(cast.toVariable, 0)]; + if (overflowTypes[toNode]) { + continue; + } + int fromNode = nodeMapping[packNodeAndDegree(cast.fromVariable, 0)]; if (!formerNodeChanged[fromNode] && !nodeChanged[fromNode]) { continue; } - int toNode = nodeMapping[packNodeAndDegree(cast.toVariable, 0)]; IntHashSet targetTypes = getNodeTypes(toNode); - for (IntCursor cursor : types[fromNode]) { - if (targetTypes.contains(cursor.value)) { - continue; + if (overflowTypes[fromNode]) { + int degree = 0; + ValueType targetType = cast.targetType; + while (targetType instanceof ValueType.Array) { + targetType = ((ValueType.Array) targetType).getItemType(); + ++degree; } - String className = typeList.get(cursor.value); - - ValueType type; - if (className.startsWith("[")) { - type = ValueType.parseIfPossible(className); - if (type == null) { - type = ValueType.arrayOf(ValueType.object("java.lang.Object")); + if (targetType instanceof ValueType.Object) { + String targetClassName = ((ValueType.Object) targetType).getClassName(); + List subclasses = subclassListProvider.getSubclasses( + targetClassName, degree > 0); + if (subclasses == null) { + types[toNode] = null; + overflowTypes[toNode] = true; + nodeChanged[toNode] = true; + changed = true; + } else { + for (String subclass : subclasses) { + if (degree > 0) { + StringBuilder sb = new StringBuilder(); + for (int i = 0; i < degree; ++i) { + sb.append('['); + } + subclass = sb.append('L').append(subclass.replace('.', '/')).append(';').toString(); + } + int typeId = getTypeByName(subclass); + if (targetTypes.add(typeId)) { + changed = true; + nodeChanged[toNode] = true; + if (targetTypes.size() > overflowLimit) { + overflowTypes[toNode] = true; + types[toNode] = null; + break; + } + } + } } } else { - type = ValueType.object(className); + int typeId = getTypeByName(targetType.toString()); + if (targetTypes.add(typeId)) { + changed = true; + nodeChanged[toNode] = true; + if (targetTypes.size() > overflowLimit) { + overflowTypes[toNode] = true; + types[toNode] = null; + } + } } + } else { + for (IntCursor cursor : types[fromNode]) { + if (targetTypes.contains(cursor.value)) { + continue; + } + String className = typeList.get(cursor.value); - if (hierarchy.isSuperType(cast.targetType, type, false)) { - changed = true; - nodeChanged[toNode] = true; - targetTypes.add(cursor.value); + ValueType type; + if (className.startsWith("[")) { + type = ValueType.parseIfPossible(className); + if (type == null) { + type = ValueType.arrayOf(ValueType.object("java.lang.Object")); + } + } else { + type = ValueType.object(className); + } + + if (hierarchy.isSuperType(cast.targetType, type, false) && targetTypes.add(cursor.value)) { + changed = true; + nodeChanged[toNode] = true; + if (targetTypes.size() > overflowLimit) { + overflowTypes[toNode] = true; + types[toNode] = null; + break; + } + } } } } @@ -443,17 +546,49 @@ public class ClassInference { ClassReaderSource classSource = dependencyInfo.getClassSource(); for (VirtualCallSite callSite : virtualCallSites) { + if (callSite.receiverOverflow) { + continue; + } + int instanceNode = nodeMapping[packNodeAndDegree(callSite.instance, 0)]; if (!formerNodeChanged[instanceNode] && !nodeChanged[instanceNode]) { continue; } - for (IntCursor type : types[instanceNode]) { - if (!callSite.knownClasses.add(type.value)) { + List receiverTypes; + if (overflowTypes[instanceNode]) { + callSite.receiverOverflow = true; + List subclasses = subclassListProvider.getSubclasses( + callSite.method.getClassName(), true); + if (subclasses != null) { + receiverTypes = subclasses; + } else { + List implementations = subclassListProvider.getMethods( + callSite.method.getDescriptor()); + if (implementations != null) { + for (MethodReference methodReference : implementations) { + mountVirtualMethod(program, callSite, methodReference); + } + } + continue; + } + } else { + List instanceNodeTypes = new ArrayList<>(); + for (IntCursor type : types[instanceNode]) { + if (callSite.knownClasses.contains(type.value)) { + continue; + } + instanceNodeTypes.add(typeList.get(type.value)); + } + receiverTypes = instanceNodeTypes; + } + + for (String className : receiverTypes) { + int typeId = getTypeByName(className); + if (!callSite.knownClasses.add(typeId)) { continue; } - String className = typeList.get(type.value); MethodReference rawMethod = new MethodReference(className, callSite.method.getDescriptor()); MethodReader resolvedMethod = classSource.resolveImplementation(rawMethod); @@ -462,28 +597,32 @@ public class ClassInference { } MethodReference resolvedMethodRef = resolvedMethod.getReference(); - if (!callSite.resolvedMethods.add(resolvedMethodRef)) { - continue; - } - - MethodDependencyInfo methodDep = dependencyInfo.getMethod(resolvedMethodRef); - if (methodDep == null) { - continue; - } - if (callSite.receiver >= 0) { - readValue(methodDep.getResult(), program.variableAt(callSite.receiver)); - } - for (int i = 0; i < callSite.arguments.length; ++i) { - writeValue(methodDep.getVariable(i + 1), program.variableAt(callSite.arguments[i])); - } - - for (String thrownTypeName : methodDep.getThrown().getTypes()) { - propagateException(thrownTypeName, program.basicBlockAt(callSite.block)); - } + mountVirtualMethod(program, callSite, resolvedMethodRef); } } } + private void mountVirtualMethod(Program program, VirtualCallSite callSite, MethodReference methodReference) { + if (!callSite.resolvedMethods.add(methodReference)) { + return; + } + + MethodDependencyInfo methodDep = dependencyInfo.getMethod(methodReference); + if (methodDep == null) { + return; + } + if (callSite.receiver >= 0) { + readValue(methodDep.getResult(), program.variableAt(callSite.receiver)); + } + for (int i = 0; i < callSite.arguments.length; ++i) { + writeValue(methodDep.getVariable(i + 1), program.variableAt(callSite.arguments[i])); + } + + for (String thrownTypeName : methodDep.getThrown().getTypes()) { + propagateException(thrownTypeName, program.basicBlockAt(callSite.block)); + } + } + private void propagateAlongExceptions(Program program) { for (int i = 0; i < exceptions.length; i += 2) { int variable = nodeMapping[packNodeAndDegree(exceptions[i], 0)]; @@ -492,9 +631,13 @@ public class ClassInference { } BasicBlock block = program.basicBlockAt(exceptions[i + 1]); - for (IntCursor type : types[variable]) { - String typeName = typeList.get(type.value); - propagateException(typeName, block); + if (overflowTypes[variable]) { + propagateOverflowException(block); + } else { + for (IntCursor type : types[variable]) { + String typeName = typeList.get(type.value); + propagateException(typeName, block); + } } } } @@ -508,10 +651,17 @@ public class ClassInference { } int exceptionNode = packNodeAndDegree(tryCatch.getHandler().getExceptionVariable().getIndex(), 0); exceptionNode = nodeMapping[exceptionNode]; - int thrownType = getTypeByName(thrownTypeName); - if (getNodeTypes(exceptionNode).add(thrownType)) { - nodeChanged[exceptionNode] = true; - changed = true; + if (!overflowTypes[exceptionNode]) { + int thrownType = getTypeByName(thrownTypeName); + IntHashSet nodeTypes = getNodeTypes(exceptionNode); + if (nodeTypes.add(thrownType)) { + nodeChanged[exceptionNode] = true; + changed = true; + if (nodeTypes.size() > overflowLimit) { + types[exceptionNode] = null; + overflowTypes[exceptionNode] = true; + } + } } break; @@ -519,6 +669,45 @@ public class ClassInference { } } + private void propagateOverflowException(BasicBlock block) { + for (TryCatchBlock tryCatch : block.getTryCatchBlocks()) { + Variable exceptionVar = tryCatch.getHandler().getExceptionVariable(); + if (exceptionVar == null) { + continue; + } + + int exceptionNode = packNodeAndDegree(exceptionVar.getIndex(), 0); + exceptionNode = nodeMapping[exceptionNode]; + if (overflowTypes[exceptionNode]) { + continue; + } + + String expectedType = tryCatch.getExceptionType(); + List thrownTypes = subclassListProvider.getSubclasses(expectedType, false); + if (thrownTypes == null) { + if (!overflowTypes[exceptionNode]) { + overflowTypes[exceptionNode] = true; + changed = true; + nodeChanged[exceptionNode] = true; + } + } else { + IntHashSet nodeTypes = getNodeTypes(exceptionNode); + for (String thrownTypeName : thrownTypes) { + int thrownType = getTypeByName(thrownTypeName); + if (nodeTypes.add(thrownType)) { + nodeChanged[exceptionNode] = true; + changed = true; + if (nodeTypes.size() > overflowLimit) { + types[exceptionNode] = null; + overflowTypes[exceptionNode] = true; + break; + } + } + } + } + } + } + IntHashSet getNodeTypes(int node) { IntHashSet result = types[node]; if (result == null) { @@ -651,9 +840,7 @@ public class ClassInference { int resultDegree = 0; while (resultDependency.hasArrayType() && resultDegree <= MAX_DEGREE) { resultDependency = resultDependency.getArrayItem(); - for (String paramType : resultDependency.getTypes()) { - addType(insn.getValueToReturn().getIndex(), resultDegree, paramType); - } + addTypesFrom(insn.getValueToReturn().getIndex(), resultDegree, resultDependency); ++resultDegree; } } @@ -706,6 +893,7 @@ public class ClassInference { static class VirtualCallSite { int instance; + boolean receiverOverflow; IntSet knownClasses = new IntHashSet(); Set resolvedMethods = new HashSet<>(); MethodReference method; @@ -718,9 +906,7 @@ public class ClassInference { int depth = 0; boolean hasArrayType; do { - for (String type : valueDep.getTypes()) { - addType(receiver.getIndex(), depth, type); - } + addTypesFrom(receiver.getIndex(), depth, valueDep); depth++; hasArrayType = valueDep.hasArrayType(); valueDep = valueDep.getArrayItem(); @@ -732,18 +918,58 @@ public class ClassInference { while (valueDep.hasArrayType() && depth < MAX_DEGREE) { depth++; valueDep = valueDep.getArrayItem(); - for (String type : valueDep.getTypes()) { - addType(source.getIndex(), depth, type); + addTypesFrom(source.getIndex(), depth, valueDep); + } + } + + boolean addType(int variable, int degree, String typeName) { + return addTypeImpl(packNodeAndDegree(variable, degree), getTypeByName(typeName)); + } + + boolean addTypeImpl(int node, int typeId) { + if (overflowTypes[node]) { + return true; + } + + IntHashSet nodeTypes = getNodeTypes(node); + if (nodeTypes.add(typeId)) { + nodeChanged[node] = true; + changed = true; + if (nodeTypes.size() > overflowLimit) { + types[node] = null; + overflowTypes[node] = true; + return true; + } + } + + return false; + } + + void addTypesFrom(int variable, int degree, ValueDependencyInfo dep) { + if (overflowTypes[packNodeAndDegree(variable, degree)]) { + return; + } + if (dep.hasMoreTypesThan(overflowLimit)) { + overflowType(variable, degree); + } else { + String[] types = dep.getTypes(); + for (String type : dep.getTypes()) { + if (addType(variable, degree, type)) { + break; + } } } } - void addType(int variable, int degree, String typeName) { + void overflowType(int variable, int degree) { int entry = nodeMapping[packNodeAndDegree(variable, degree)]; - if (getNodeTypes(entry).add(getTypeByName(typeName))) { - nodeChanged[entry] = true; - changed = true; + if (overflowTypes[entry]) { + return; } + + overflowTypes[entry] = true; + nodeChanged[entry] = true; + changed = true; } static int extractNode(int nodeAndDegree) { @@ -758,18 +984,6 @@ public class ClassInference { return (node << 3) | degree; } - static long packTwoIntegers(int a, int b) { - return ((long) a << 32) | b; - } - - static int unpackFirst(long pair) { - return (int) (pair >>> 32); - } - - static int unpackSecond(long pair) { - return (int) pair; - } - static final class ValueCast { final int fromVariable; final int toVariable; diff --git a/core/src/main/java/org/teavm/model/analysis/SubclassListProvider.java b/core/src/main/java/org/teavm/model/analysis/SubclassListProvider.java new file mode 100644 index 000000000..3b6772cea --- /dev/null +++ b/core/src/main/java/org/teavm/model/analysis/SubclassListProvider.java @@ -0,0 +1,175 @@ +/* + * Copyright 2019 konsoletyper. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.teavm.model.analysis; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import org.teavm.model.AccessLevel; +import org.teavm.model.ClassReader; +import org.teavm.model.ClassReaderSource; +import org.teavm.model.ElementModifier; +import org.teavm.model.MethodDescriptor; +import org.teavm.model.MethodReader; +import org.teavm.model.MethodReference; + +class SubclassListProvider { + private Map classes = new HashMap<>(); + private Map> methodImplementations = new HashMap<>(); + private int limit; + + SubclassListProvider(ClassReaderSource classSource, Iterable classNames, int limit) { + this.limit = limit; + + for (String className : classNames) { + registerClass(classSource, className); + } + } + + private ClassInfo registerClass(ClassReaderSource classSource, String className) { + ClassInfo classInfo = classes.get(className); + if (classInfo == null) { + classInfo = new ClassInfo(); + classes.put(className, classInfo); + + ClassReader cls = classSource.get(className); + if (cls != null) { + if (!cls.hasModifier(ElementModifier.INTERFACE) && !cls.hasModifier(ElementModifier.ABSTRACT)) { + classInfo.concrete = true; + } + increaseClassCount(classSource, className, new HashSet<>(), classInfo.concrete); + + if (cls.getParent() != null) { + ClassInfo parentInfo = registerClass(classSource, cls.getParent()); + if (parentInfo.directSubclasses == null) { + parentInfo.directSubclasses = new ArrayList<>(); + } + parentInfo.directSubclasses.add(className); + } + + for (String itf : cls.getInterfaces()) { + ClassInfo parentInfo = registerClass(classSource, itf); + if (parentInfo.directSubclasses == null) { + parentInfo.directSubclasses = new ArrayList<>(); + } + parentInfo.directSubclasses.add(className); + } + + for (MethodReader method : cls.getMethods()) { + if (method.hasModifier(ElementModifier.STATIC) + || method.hasModifier(ElementModifier.ABSTRACT) + || method.getLevel() == AccessLevel.PRIVATE) { + continue; + } + List implementations = methodImplementations.get(method.getDescriptor()); + if (implementations == null) { + implementations = new ArrayList<>(); + methodImplementations.put(method.getDescriptor(), implementations); + } + implementations.add(method.getReference()); + } + } + } + + return classInfo; + } + + private void increaseClassCount(ClassReaderSource classSource, String className, Set visited, + boolean concrete) { + if (!visited.add(className)) { + return; + } + + ClassInfo classInfo = registerClass(classSource, className); + if ((!concrete || classInfo.concreteCount > limit) && classInfo.count > limit) { + return; + } + classInfo.count++; + if (concrete) { + classInfo.concreteCount++; + } + + ClassReader cls = classSource.get(className); + if (cls != null) { + if (cls.getParent() != null) { + increaseClassCount(classSource, cls.getParent(), visited, concrete); + } + for (String itf : cls.getInterfaces()) { + increaseClassCount(classSource, itf, visited, concrete); + } + } + } + + List getSubclasses(String className, boolean includeAbstract) { + ClassInfo classInfo = classes.get(className); + if (classInfo == null) { + return null; + } + + if (includeAbstract) { + if (classInfo.count > limit) { + return null; + } + } else { + if (classInfo.concreteCount > limit) { + return null; + } + } + + + String[] result = new String[includeAbstract ? classInfo.count : classInfo.concreteCount]; + collectSubclasses(className, result, 0, new HashSet<>(), includeAbstract); + return Arrays.asList(result); + } + + List getMethods(MethodDescriptor descriptor) { + return methodImplementations.get(descriptor); + } + + private int collectSubclasses(String className, String[] consumer, int index, Set visited, + boolean includeAbstract) { + if (!visited.add(className)) { + return index; + } + + ClassInfo classInfo = classes.get(className); + if (classInfo == null) { + return index; + } + + if (includeAbstract || classInfo.concrete) { + consumer[index++] = className; + } + if (classInfo.directSubclasses != null) { + for (String subclassName : classInfo.directSubclasses) { + index = collectSubclasses(subclassName, consumer, index, visited, includeAbstract); + } + } + + return index; + } + + static class ClassInfo { + int count; + int concreteCount; + boolean concrete; + List directSubclasses; + } +} diff --git a/core/src/main/java/org/teavm/model/optimization/Inlining.java b/core/src/main/java/org/teavm/model/optimization/Inlining.java index f284571a5..f696ede36 100644 --- a/core/src/main/java/org/teavm/model/optimization/Inlining.java +++ b/core/src/main/java/org/teavm/model/optimization/Inlining.java @@ -69,6 +69,7 @@ public class Inlining { private MethodUsageCounter usageCounter; private Set methodsUsedOnce = new HashSet<>(); private boolean devirtualization; + private ClassInference classInference; public Inlining(ClassHierarchy hierarchy, DependencyInfo dependencyInfo, InliningStrategy strategy, ListableClassReaderSource classes, Predicate externalMethods, @@ -393,8 +394,10 @@ public class Inlining { } private void devirtualize(Program program, MethodReference method, DependencyInfo dependencyInfo) { - ClassInference inference = new ClassInference(dependencyInfo, hierarchy); - inference.infer(program, method); + if (classInference == null) { + classInference = new ClassInference(dependencyInfo, hierarchy, classes.getClassNames(), 30); + } + classInference.infer(program, method); for (BasicBlock block : program.getBasicBlocks()) { for (Instruction instruction : block) { @@ -407,11 +410,19 @@ public class Inlining { } Set implementations = new HashSet<>(); - for (String className : inference.classesOf(invoke.getInstance().getIndex())) { - MethodReference rawMethod = new MethodReference(className, invoke.getMethod().getDescriptor()); - MethodReader resolvedMethod = dependencyInfo.getClassSource().resolveImplementation(rawMethod); - if (resolvedMethod != null) { - implementations.add(resolvedMethod.getReference()); + if (classInference.isOverflow(invoke.getInstance().getIndex())) { + List knownImplementations = classInference.getMethodImplementations( + invoke.getMethod().getDescriptor()); + if (knownImplementations != null) { + implementations.addAll(knownImplementations); + } + } else { + for (String className : classInference.classesOf(invoke.getInstance().getIndex())) { + MethodReference rawMethod = new MethodReference(className, invoke.getMethod().getDescriptor()); + MethodReader resolvedMethod = dependencyInfo.getClassSource().resolveImplementation(rawMethod); + if (resolvedMethod != null) { + implementations.add(resolvedMethod.getReference()); + } } } @@ -423,7 +434,7 @@ public class Inlining { } } - private class PlanEntry { + static class PlanEntry { int targetBlock; Instruction targetInstruction; MethodReference method; diff --git a/tests/src/test/java/org/teavm/dependency/DependencyTest.java b/tests/src/test/java/org/teavm/dependency/DependencyTest.java index 8419e8883..611266198 100644 --- a/tests/src/test/java/org/teavm/dependency/DependencyTest.java +++ b/tests/src/test/java/org/teavm/dependency/DependencyTest.java @@ -143,13 +143,13 @@ public class DependencyTest { MethodHolder method = classSource.get(testMethod.getClassName()).getMethod(testMethod.getDescriptor()); List assertions = collectAssertions(method); processAssertions(assertions, vm.getDependencyInfo().getMethod(testMethod), vm.getDependencyInfo(), - method.getProgram()); + method.getProgram(), vm.getClasses()); } private void processAssertions(List assertions, MethodDependencyInfo methodDep, - DependencyInfo dependencyInfo, Program program) { + DependencyInfo dependencyInfo, Program program, Iterable classNames) { ClassInference classInference = new ClassInference(dependencyInfo, new ClassHierarchy( - dependencyInfo.getClassSource())); + dependencyInfo.getClassSource()), classNames, 10); classInference.infer(program, methodDep.getReference()); for (Assertion assertion : assertions) { @@ -160,10 +160,12 @@ public class DependencyTest { Arrays.sort(expectedTypes); Assert.assertArrayEquals("Assertion at " + assertion.location, expectedTypes, actualTypes); - actualTypes = classInference.classesOf(assertion.value); - Arrays.sort(actualTypes); - Assert.assertArrayEquals("Assertion at " + assertion.location + " (class inference)", - expectedTypes, actualTypes); + if (!classInference.isOverflow(assertion.value)) { + Set actualTypeSet = new HashSet<>(Arrays.asList(classInference.classesOf(assertion.value))); + Assert.assertTrue("Assertion at " + assertion.location + " (class inference), " + + "expected: " + Arrays.toString(expectedTypes) + ", actual: " + actualTypeSet, + actualTypeSet.containsAll(Arrays.asList(expectedTypes))); + } } }