Improve speed of class inference during inlining

This commit is contained in:
Alexey Andreev 2018-02-04 00:17:02 +03:00 committed by Alexey Andreev
parent f548fc964c
commit 6d68010416
5 changed files with 599 additions and 288 deletions

View File

@ -15,6 +15,7 @@
*/ */
package org.teavm.common; package org.teavm.common;
import com.carrotsearch.hppc.IntStack;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.List; import java.util.List;
@ -164,64 +165,98 @@ public final class GraphUtils {
/* /*
* Tarjan's algorithm * Tarjan's algorithm
* See pseudocode at https://en.wikipedia.org/wiki/Tarjan%27s_strongly_connected_components_algorithm
* This is a stackless version.
*/ */
public static int[][] findStronglyConnectedComponents(Graph graph, int[] start) { public static int[][] findStronglyConnectedComponents(Graph graph) {
List<int[]> components = new ArrayList<>(); List<int[]> components = new ArrayList<>();
int[] visitIndex = new int[graph.size()];
int[] headerIndex = new int[graph.size()];
IntegerStack currentComponent = new IntegerStack(1);
boolean[] inCurrentComponent = new boolean[graph.size()];
int lastIndex = 1;
IntegerStack stack = new IntegerStack(graph.size());
IntegerStack modeStack = new IntegerStack(graph.size());
int[] expectedMode = new int[graph.size()];
for (int startNode : start) { int index = 0;
stack.push(startNode); IntStack procStack = new IntStack();
modeStack.push(0); IntStack stack = new IntStack();
int[] nodeIndex = new int[graph.size()];
int[] nodeLowLink = new int[graph.size()];
boolean[] nodeOnStack = new boolean[graph.size()];
Arrays.fill(nodeIndex, -1);
Arrays.fill(nodeLowLink, -2);
for (int i = 0; i < graph.size(); ++i) {
procStack.push(i);
procStack.push(0);
} }
while (!stack.isEmpty()) { while (!procStack.isEmpty()) {
int node = stack.pop(); int state = procStack.pop();
int mode = modeStack.pop(); int v = procStack.pop();
if (expectedMode[node] != mode) {
continue;
}
expectedMode[node]++;
if (mode == 1) {
expectedMode[node] = 2;
int hdr = headerIndex[node];
for (int successor : graph.outgoingEdges(node)) {
if (visitIndex[node] < visitIndex[successor]) {
hdr = Math.min(hdr, headerIndex[successor]);
} else if (inCurrentComponent[successor]) {
hdr = Math.min(hdr, visitIndex[successor]);
}
}
headerIndex[node] = hdr;
if (hdr == visitIndex[node]) { switch (state) {
IntegerArray componentMembers = new IntegerArray(graph.size()); case 0: {
int componentMember; if (nodeIndex[v] >= 0) {
break;
}
nodeIndex[v] = index;
nodeLowLink[v] = index;
index++;
stack.push(v);
nodeOnStack[v] = true;
procStack.push(v);
procStack.push(3);
for (int w : graph.outgoingEdges(v)) {
procStack.push(w);
procStack.push(v);
procStack.push(1);
}
break;
}
case 1: {
int w = procStack.pop();
if (nodeIndex[w] < 0) {
procStack.push(w);
procStack.push(v);
procStack.push(2);
procStack.push(w);
procStack.push(0);
} else if (nodeOnStack[w]) {
nodeLowLink[v] = Math.min(nodeLowLink[v], nodeIndex[w]);
}
break;
}
case 2: {
int w = procStack.pop();
nodeLowLink[v] = Math.min(nodeLowLink[v], nodeLowLink[w]);
break;
}
case 3: {
if (nodeLowLink[v] == nodeIndex[v]) {
IntegerArray scc = new IntegerArray(4);
int w;
do { do {
componentMember = currentComponent.pop(); w = stack.pop();
inCurrentComponent[componentMember] = false; nodeOnStack[w] = false;
componentMembers.add(componentMember); scc.add(w);
} while (componentMember != node); } while (w != v);
components.add(componentMembers.getAll());
}
} else if (mode == 0) {
visitIndex[node] = lastIndex;
headerIndex[node] = lastIndex;
lastIndex++;
currentComponent.push(node); if (scc.size() > 1) {
inCurrentComponent[node] = true; components.add(scc.getAll());
stack.push(node); } else {
modeStack.push(1); for (int successor : graph.outgoingEdges(v)) {
for (int successor : graph.outgoingEdges(node)) { if (successor == v) {
stack.push(successor); components.add(scc.getAll());
modeStack.push(0); break;
}
}
}
}
break;
} }
} }
} }

View File

@ -70,7 +70,7 @@ class IrreducibleGraphConverter {
if (irreducible) { if (irreducible) {
DJGraphNodeFilter filter = new DJGraphNodeFilter(djGraph, level); DJGraphNodeFilter filter = new DJGraphNodeFilter(djGraph, level);
Graph graph = GraphUtils.subgraph(djGraph.getGraph(), filter); Graph graph = GraphUtils.subgraph(djGraph.getGraph(), filter);
int[][] sccs = GraphUtils.findStronglyConnectedComponents(graph, djGraph.level(level)); int[][] sccs = GraphUtils.findStronglyConnectedComponents(graph);
for (int[] scc : sccs) { for (int[] scc : sccs) {
if (scc.length > 1) { if (scc.length > 1) {
handleStronglyConnectedComponent(djGraph, scc, nodeMap); handleStronglyConnectedComponent(djGraph, scc, nodeMap);

View File

@ -15,21 +15,21 @@
*/ */
package org.teavm.model.analysis; package org.teavm.model.analysis;
import com.carrotsearch.hppc.IntObjectMap;
import com.carrotsearch.hppc.IntObjectOpenHashMap;
import com.carrotsearch.hppc.IntOpenHashSet; import com.carrotsearch.hppc.IntOpenHashSet;
import com.carrotsearch.hppc.IntSet; import com.carrotsearch.hppc.IntSet;
import java.util.ArrayDeque; import com.carrotsearch.hppc.IntStack;
import com.carrotsearch.hppc.ObjectIntMap;
import com.carrotsearch.hppc.ObjectIntOpenHashMap;
import com.carrotsearch.hppc.cursors.IntCursor;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Collection; import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet; import java.util.HashSet;
import java.util.LinkedHashSet;
import java.util.List; import java.util.List;
import java.util.Queue;
import java.util.Set; import java.util.Set;
import org.teavm.common.Graph; import org.teavm.common.Graph;
import org.teavm.common.GraphBuilder; import org.teavm.common.GraphBuilder;
import org.teavm.common.GraphUtils;
import org.teavm.common.IntegerArray;
import org.teavm.dependency.DependencyInfo; import org.teavm.dependency.DependencyInfo;
import org.teavm.dependency.FieldDependencyInfo; import org.teavm.dependency.FieldDependencyInfo;
import org.teavm.dependency.MethodDependencyInfo; import org.teavm.dependency.MethodDependencyInfo;
@ -70,28 +70,64 @@ public class ClassInference {
private Graph cloneGraph; private Graph cloneGraph;
private Graph arrayGraph; private Graph arrayGraph;
private Graph itemGraph; private Graph itemGraph;
private List<IntObjectMap<ValueType>> casts; private Graph graph;
private IntObjectMap<IntSet> exceptionMap; private ValueCast[] casts;
private VirtualCallSite[][] virtualCallSites; private int[] exceptions;
private List<Task> initialTasks; private VirtualCallSite[] virtualCallSites;
private List<List<Set<String>>> types;
private List<Set<String>> finalTypes; private int[] propagationPath;
private int[] nodeMapping;
private IntOpenHashSet[] types;
private ObjectIntMap<String> typeMap = new ObjectIntOpenHashMap<>();
private List<String> typeList = new ArrayList<>();
private boolean changed = true;
private boolean[] nodeChanged;
private boolean[] formerNodeChanged;
private static final int MAX_DEGREE = 3;
public ClassInference(DependencyInfo dependencyInfo) { public ClassInference(DependencyInfo dependencyInfo) {
this.dependencyInfo = dependencyInfo; this.dependencyInfo = dependencyInfo;
} }
public void infer(Program program, MethodReference methodReference) { public void infer(Program program, MethodReference methodReference) {
MethodDependencyInfo thisMethodDep = dependencyInfo.getMethod(methodReference); /*
buildGraphs(program, thisMethodDep); The idea behind this algorithm
1. Build preliminary graphs that represent different connection types between variables.
See `assignmentGraph`, `cloneGraph`, `arrayGraph`, `itemGraph`.
2. Build initial type sets where possible. See `types`.
3. Build additional info: casts, virtual invocations, exceptions.
See `casts`, `exceptions`, `virtualCallSites`.
4. Build graph from set of preliminary paths
5. Find strongly connected components, collapse then into one nodes.
6. Calculate topological order of the DAG (let it be propagation path).
Let resulting order be `propagationPath`.
7. Propagate types along calculated path; then propagate types using additional info.
8. Repeat 7 until it changes anything (i.e. calculate fixed point).
*/
types = new IntOpenHashSet[program.variableCount() << 3];
nodeChanged = new boolean[types.length];
formerNodeChanged = new boolean[nodeChanged.length];
nodeMapping = new int[types.length];
for (int i = 0; i < types.length; ++i) {
nodeMapping[i] = i;
}
// See 1, 2, 3
MethodDependencyInfo thisMethodDep = dependencyInfo.getMethod(methodReference);
buildPreliminaryGraphs(program, thisMethodDep);
// Augment (2) with input types of method
for (int i = 0; i <= methodReference.parameterCount(); ++i) { for (int i = 0; i <= methodReference.parameterCount(); ++i) {
ValueDependencyInfo paramDep = thisMethodDep.getVariable(i); ValueDependencyInfo paramDep = thisMethodDep.getVariable(i);
if (paramDep != null) { if (paramDep != null) {
int degree = 0; int degree = 0;
while (true) { while (degree <= MAX_DEGREE) {
for (String paramType : paramDep.getTypes()) { for (String paramType : paramDep.getTypes()) {
initialTasks.add(new Task(i, degree, paramType)); addType(i, degree, paramType);
} }
if (!paramDep.hasArrayType()) { if (!paramDep.hasArrayType()) {
break; break;
@ -102,37 +138,46 @@ public class ClassInference {
} }
} }
types = new ArrayList<>(program.variableCount()); // See 4
for (int i = 0; i < program.variableCount(); ++i) { buildPropagationGraph();
List<Set<String>> variableTypes = new ArrayList<>();
types.add(variableTypes);
for (int j = 0; j < 3; ++j) {
variableTypes.add(new LinkedHashSet<>());
}
}
// See 5
collapseSCCs();
// See 6
buildPropagationPath();
// See 7, 8
propagate(program); propagate(program);
// Cleanup
assignmentGraph = null; assignmentGraph = null;
graph = null;
cloneGraph = null; cloneGraph = null;
arrayGraph = null; arrayGraph = null;
itemGraph = null; itemGraph = null;
casts = null; casts = null;
exceptionMap = null; exceptions = null;
virtualCallSites = null; virtualCallSites = null;
propagationPath = null;
finalTypes = new ArrayList<>(program.variableCount()); nodeChanged = null;
for (int i = 0; i < program.variableCount(); ++i) {
finalTypes.add(types.get(i).get(0));
}
types = null;
} }
public String[] classesOf(int variableIndex) { public String[] classesOf(int variableIndex) {
return finalTypes.get(variableIndex).toArray(new String[0]); IntOpenHashSet typeSet = types[nodeMapping[packNodeAndDegree(variableIndex, 0)]];
if (typeSet == null) {
return new String[0];
} }
private void buildGraphs(Program program, MethodDependencyInfo thisMethodDep) { int[] typeIndexes = typeSet.toArray();
String[] types = new String[typeIndexes.length];
for (int i = 0; i < typeIndexes.length; ++i) {
types[i] = typeList.get(typeIndexes[i]);
}
return types;
}
private void buildPreliminaryGraphs(Program program, MethodDependencyInfo thisMethodDep) {
GraphBuildingVisitor visitor = new GraphBuildingVisitor(program.variableCount(), dependencyInfo); GraphBuildingVisitor visitor = new GraphBuildingVisitor(program.variableCount(), dependencyInfo);
visitor.thisMethodDep = thisMethodDep; visitor.thisMethodDep = thisMethodDep;
for (BasicBlock block : program.getBasicBlocks()) { for (BasicBlock block : program.getBasicBlocks()) {
@ -143,161 +188,367 @@ public class ClassInference {
for (Instruction insn : block) { for (Instruction insn : block) {
insn.acceptVisitor(visitor); insn.acceptVisitor(visitor);
} }
if (block.getExceptionVariable() != null) {
getNodeTypes(packNodeAndDegree(block.getExceptionVariable().getIndex(), 0));
}
} }
assignmentGraph = visitor.assignmentGraphBuilder.build(); assignmentGraph = visitor.assignmentGraphBuilder.build();
cloneGraph = visitor.cloneGraphBuilder.build(); cloneGraph = visitor.cloneGraphBuilder.build();
arrayGraph = visitor.arrayGraphBuilder.build(); arrayGraph = visitor.arrayGraphBuilder.build();
itemGraph = visitor.itemGraphBuilder.build(); itemGraph = visitor.itemGraphBuilder.build();
casts = visitor.casts; casts = visitor.casts.toArray(new ValueCast[0]);
exceptionMap = visitor.exceptionMap; exceptions = visitor.exceptions.getAll();
initialTasks = visitor.tasks; virtualCallSites = visitor.virtualCallSites.toArray(new VirtualCallSite[0]);
}
virtualCallSites = new VirtualCallSite[program.variableCount()][]; private void buildPropagationGraph() {
for (int i = 0; i < virtualCallSites.length; ++i) { IntStack stack = new IntStack();
List<VirtualCallSite> buildCallSites = visitor.virtualCallSites.get(i);
if (buildCallSites != null) { for (int i = 0; i < types.length; ++i) {
virtualCallSites[i] = buildCallSites.toArray(new VirtualCallSite[0]); if (types[i] != null) {
stack.push(i);
} }
} }
boolean[] visited = new boolean[types.length];
GraphBuilder graphBuilder = new GraphBuilder(types.length);
while (!stack.isEmpty()) {
int entry = stack.pop();
if (visited[entry]) {
continue;
}
visited[entry] = true;
int degree = extractDegree(entry);
int variable = extractNode(entry);
// Actually, successor nodes in resulting graph
IntSet nextEntries = new IntOpenHashSet();
// Start: calculating successor nodes in resulting DAG along different paths
//
for (int successor : assignmentGraph.outgoingEdges(variable)) {
nextEntries.add(packNodeAndDegree(successor, degree));
}
for (int successor : cloneGraph.outgoingEdges(variable)) {
nextEntries.add(packNodeAndDegree(successor, degree));
}
if (degree > 0) {
for (int predecessor : assignmentGraph.incomingEdges(variable)) {
nextEntries.add(packNodeAndDegree(predecessor, degree));
}
for (int successor : itemGraph.outgoingEdges(variable)) {
nextEntries.add(packNodeAndDegree(successor, degree - 1));
}
}
for (int successor : arrayGraph.outgoingEdges(variable)) {
nextEntries.add(packNodeAndDegree(successor, degree + 1));
}
//
// End: calculating successor nodes in resulting graph
for (IntCursor next : nextEntries) {
graphBuilder.addEdge(entry, next.value);
if (!visited[next.value]) {
stack.push(next.value);
}
}
}
graph = graphBuilder.build();
}
private void collapseSCCs() {
int[][] sccs = GraphUtils.findStronglyConnectedComponents(graph);
if (sccs.length == 0) {
return;
}
for (int[] scc : sccs) {
for (int i = 1; i < scc.length; ++i) {
nodeMapping[scc[i]] = scc[0];
}
}
boolean[] nodeChangedBackup = nodeChanged.clone();
IntOpenHashSet[] typesBackup = types.clone();
Arrays.fill(nodeChanged, false);
Arrays.fill(types, null);
GraphBuilder graphBuilder = new GraphBuilder(graph.size());
for (int i = 0; i < graph.size(); ++i) {
for (int j : graph.outgoingEdges(i)) {
int from = nodeMapping[i];
int to = nodeMapping[j];
if (from != to) {
graphBuilder.addEdge(from, to);
}
}
int node = nodeMapping[i];
if (typesBackup[i] != null) {
getNodeTypes(node).addAll(typesBackup[i]);
}
if (nodeChangedBackup[i]) {
nodeChanged[node] = true;
}
}
graph = graphBuilder.build();
}
private static final byte FRESH = 0;
private static final byte VISITING = 1;
private static final byte VISITED = 2;
private void buildPropagationPath() {
byte[] state = new byte[types.length];
int[] path = new int[types.length];
int pathSize = 0;
IntStack stack = new IntStack();
for (int i = 0; i < graph.size(); ++i) {
if (graph.incomingEdgesCount(i) == 0 && types[i] != null) {
stack.push(i);
}
}
while (!stack.isEmpty()) {
int node = stack.pop();
if (state[node] == FRESH) {
state[node] = VISITING;
stack.push(node);
for (int successor : graph.outgoingEdges(node)) {
if (state[successor] == FRESH) {
stack.push(successor);
}
}
} else if (state[node] == VISITING) {
path[pathSize++] = node;
state[node] = VISITED;
}
}
propagationPath = Arrays.copyOf(path, pathSize);
} }
private void propagate(Program program) { private void propagate(Program program) {
ClassReaderSource classSource = dependencyInfo.getClassSource(); changed = false;
Queue<Task> queue = new ArrayDeque<>(initialTasks); while (true) {
initialTasks = null; System.arraycopy(nodeChanged, 0, formerNodeChanged, 0, nodeChanged.length);
Arrays.fill(nodeChanged, false);
while (!queue.isEmpty()) { propagateAlongDAG();
Task task = queue.remove(); boolean outerChanged = changed;
if (task.degree < 0) { do {
BasicBlock block = program.basicBlockAt(task.variable); changed = false;
for (TryCatchBlock tryCatch : block.getTryCatchBlocks()) { propagateAlongCasts();
if (tryCatch.getExceptionType() == null propagateAlongVirtualCalls(program);
|| classSource.isSuperType(tryCatch.getExceptionType(), task.className).orElse(false)) { propagateAlongExceptions(program);
Variable exception = tryCatch.getHandler().getExceptionVariable(); if (changed) {
if (exception != null) { outerChanged = true;
queue.add(new Task(exception.getIndex(), 0, task.className));
} }
} while (changed);
if (!outerChanged) {
break;
}
changed = false;
}
}
private void propagateAlongDAG() {
for (int i = propagationPath.length - 1; i >= 0; --i) {
int node = propagationPath[i];
boolean predecessorsChanged = false;
for (int predecessor : graph.incomingEdges(node)) {
if (formerNodeChanged[predecessor] || nodeChanged[predecessor]) {
predecessorsChanged = true;
break; break;
} }
} }
if (!predecessorsChanged) {
continue; continue;
} }
List<Set<String>> variableTypes = types.get(task.variable); IntOpenHashSet nodeTypes = getNodeTypes(node);
if (task.degree >= variableTypes.size()) { for (int predecessor : graph.incomingEdges(node)) {
if (formerNodeChanged[predecessor] || nodeChanged[predecessor]) {
if (nodeTypes.addAll(types[predecessor]) > 0) {
nodeChanged[node] = true;
changed = true;
}
}
}
}
}
private void propagateAlongCasts() {
ClassReaderSource classSource = dependencyInfo.getClassSource();
for (ValueCast cast : casts) {
int fromNode = nodeMapping[packNodeAndDegree(cast.fromVariable, 0)];
if (!formerNodeChanged[fromNode] && !nodeChanged[fromNode]) {
continue; continue;
} }
Set<String> typeSet = variableTypes.get(task.degree); int toNode = nodeMapping[packNodeAndDegree(cast.toVariable, 0)];
if (!typeSet.add(task.className)) { IntOpenHashSet targetTypes = getNodeTypes(toNode);
for (IntCursor cursor : types[fromNode]) {
if (targetTypes.contains(cursor.value)) {
continue; continue;
} }
String className = typeList.get(cursor.value);
for (int successor : assignmentGraph.outgoingEdges(task.variable)) {
queue.add(new Task(successor, task.degree, task.className));
int itemDegree = task.degree + 1;
if (itemDegree < variableTypes.size()) {
for (String type : variableTypes.get(itemDegree)) {
queue.add(new Task(successor, itemDegree, type));
}
}
List<Set<String>> successorVariableTypes = types.get(successor);
if (itemDegree < successorVariableTypes.size()) {
for (String type : successorVariableTypes.get(itemDegree)) {
queue.add(new Task(task.variable, itemDegree, type));
}
}
}
if (task.degree > 0) {
for (int predecessor : assignmentGraph.incomingEdges(task.variable)) {
queue.add(new Task(predecessor, task.degree, task.className));
}
for (int successor : itemGraph.outgoingEdges(task.variable)) {
queue.add(new Task(successor, task.degree - 1, task.className));
}
} else {
for (int successor : cloneGraph.outgoingEdges(task.variable)) {
queue.add(new Task(successor, 0, task.className));
}
IntSet blocks = exceptionMap.get(task.variable);
if (blocks != null) {
for (int block : blocks.toArray()) {
queue.add(new Task(block, -1, task.className));
}
}
VirtualCallSite[] callSites = virtualCallSites[task.variable];
if (callSites != null) {
for (VirtualCallSite callSite : callSites) {
MethodReference rawMethod = new MethodReference(task.className,
callSite.method.getDescriptor());
MethodReader resolvedMethod = classSource.resolveImplementation(rawMethod);
if (resolvedMethod == null) {
continue;
}
MethodReference resolvedMethodRef = resolvedMethod.getReference();
if (callSite.resolvedMethods.add(resolvedMethodRef)) {
MethodDependencyInfo methodDep = dependencyInfo.getMethod(resolvedMethodRef);
if (methodDep != null) {
if (callSite.receiver >= 0) {
readValue(methodDep.getResult(), program.variableAt(callSite.receiver), queue);
}
for (int i = 0; i < callSite.arguments.length; ++i) {
writeValue(methodDep.getVariable(i + 1), program.variableAt(callSite.arguments[i]),
queue);
}
for (String type : methodDep.getThrown().getTypes()) {
queue.add(new Task(callSite.block, -1, type));
}
}
}
}
}
}
for (int successor : arrayGraph.outgoingEdges(task.variable)) {
queue.add(new Task(successor, task.degree + 1, task.className));
}
IntObjectMap<ValueType> variableCasts = casts.get(task.variable);
if (variableCasts != null) {
ValueType type; ValueType type;
if (task.className.startsWith("[")) { if (className.startsWith("[")) {
type = ValueType.parseIfPossible(task.className); type = ValueType.parseIfPossible(className);
if (type == null) { if (type == null) {
type = ValueType.arrayOf(ValueType.object("java.lang.Object")); type = ValueType.arrayOf(ValueType.object("java.lang.Object"));
} }
} else { } else {
type = ValueType.object(task.className); type = ValueType.object(className);
}
for (int target : variableCasts.keys().toArray()) {
ValueType targetType = variableCasts.get(target);
if (classSource.isSuperType(targetType, type).orElse(false)) {
queue.add(new Task(target, 0, task.className));
} }
if (classSource.isSuperType(cast.targetType, type).orElse(false)) {
changed = true;
nodeChanged[toNode] = true;
targetTypes.add(cursor.value);
} }
} }
} }
} }
static class GraphBuildingVisitor extends AbstractInstructionVisitor { private void propagateAlongVirtualCalls(Program program) {
ClassReaderSource classSource = dependencyInfo.getClassSource();
for (VirtualCallSite callSite : virtualCallSites) {
int instanceNode = nodeMapping[packNodeAndDegree(callSite.instance, 0)];
if (!formerNodeChanged[instanceNode] && !nodeChanged[instanceNode]) {
continue;
}
for (IntCursor type : types[instanceNode]) {
if (!callSite.knownClasses.add(type.value)) {
continue;
}
String className = typeList.get(type.value);
MethodReference rawMethod = new MethodReference(className, callSite.method.getDescriptor());
MethodReader resolvedMethod = classSource.resolveImplementation(rawMethod);
if (resolvedMethod == null) {
continue;
}
MethodReference resolvedMethodRef = resolvedMethod.getReference();
if (!callSite.resolvedMethods.add(resolvedMethodRef)) {
continue;
}
MethodDependencyInfo methodDep = dependencyInfo.getMethod(resolvedMethodRef);
if (methodDep == null) {
continue;
}
if (callSite.receiver >= 0) {
readValue(methodDep.getResult(), program.variableAt(callSite.receiver));
}
for (int i = 0; i < callSite.arguments.length; ++i) {
writeValue(methodDep.getVariable(i + 1), program.variableAt(callSite.arguments[i]));
}
for (String thrownTypeName : methodDep.getThrown().getTypes()) {
propagateException(thrownTypeName, program.basicBlockAt(callSite.block));
}
}
}
}
private void propagateAlongExceptions(Program program) {
for (int i = 0; i < exceptions.length; i += 2) {
int variable = nodeMapping[packNodeAndDegree(exceptions[i], 0)];
if (!formerNodeChanged[variable] && !nodeChanged[variable]) {
continue;
}
BasicBlock block = program.basicBlockAt(exceptions[i + 1]);
for (IntCursor type : types[variable]) {
String typeName = typeList.get(type.value);
propagateException(typeName, block);
}
}
}
private void propagateException(String thrownTypeName, BasicBlock block) {
ClassReaderSource classSource = dependencyInfo.getClassSource();
for (TryCatchBlock tryCatch : block.getTryCatchBlocks()) {
String expectedType = tryCatch.getExceptionType();
if (expectedType == null || classSource.isSuperType(expectedType, thrownTypeName).orElse(false)) {
if (tryCatch.getHandler().getExceptionVariable() == null) {
break;
}
int exceptionNode = packNodeAndDegree(tryCatch.getHandler().getExceptionVariable().getIndex(), 0);
exceptionNode = nodeMapping[exceptionNode];
int thrownType = getTypeByName(thrownTypeName);
if (getNodeTypes(exceptionNode).add(thrownType)) {
nodeChanged[exceptionNode] = true;
changed = true;
}
break;
}
}
}
IntOpenHashSet getNodeTypes(int node) {
IntOpenHashSet result = types[node];
if (result == null) {
result = new IntOpenHashSet();
types[node] = result;
}
return result;
}
int getTypeByName(String typeName) {
int type = typeMap.getOrDefault(typeName, -1);
if (type < 0) {
type = typeList.size();
typeMap.put(typeName, type);
typeList.add(typeName);
}
return type;
}
class GraphBuildingVisitor extends AbstractInstructionVisitor {
DependencyInfo dependencyInfo; DependencyInfo dependencyInfo;
GraphBuilder assignmentGraphBuilder; GraphBuilder assignmentGraphBuilder;
GraphBuilder cloneGraphBuilder; GraphBuilder cloneGraphBuilder;
GraphBuilder arrayGraphBuilder; GraphBuilder arrayGraphBuilder;
GraphBuilder itemGraphBuilder; GraphBuilder itemGraphBuilder;
MethodDependencyInfo thisMethodDep; MethodDependencyInfo thisMethodDep;
List<IntObjectMap<ValueType>> casts; List<ValueCast> casts = new ArrayList<>();
IntObjectMap<IntSet> exceptionMap = new IntObjectOpenHashMap<>(); IntegerArray exceptions = new IntegerArray(2);
List<Task> tasks = new ArrayList<>(); List<VirtualCallSite> virtualCallSites = new ArrayList<>();
List<List<VirtualCallSite>> virtualCallSites;
BasicBlock currentBlock; BasicBlock currentBlock;
GraphBuildingVisitor(int variableCount, DependencyInfo dependencyInfo) { GraphBuildingVisitor(int variableCount, DependencyInfo dependencyInfo) {
@ -306,12 +557,6 @@ public class ClassInference {
cloneGraphBuilder = new GraphBuilder(variableCount); cloneGraphBuilder = new GraphBuilder(variableCount);
arrayGraphBuilder = new GraphBuilder(variableCount); arrayGraphBuilder = new GraphBuilder(variableCount);
itemGraphBuilder = new GraphBuilder(variableCount); itemGraphBuilder = new GraphBuilder(variableCount);
casts = new ArrayList<>(variableCount);
for (int i = 0; i < variableCount; ++i) {
casts.add(new IntObjectOpenHashMap<>());
}
virtualCallSites = new ArrayList<>(Collections.nCopies(variableCount, null));
} }
public void visit(Phi phi) { public void visit(Phi phi) {
@ -322,12 +567,12 @@ public class ClassInference {
@Override @Override
public void visit(ClassConstantInstruction insn) { public void visit(ClassConstantInstruction insn) {
tasks.add(new Task(insn.getReceiver().getIndex(), 0, "java.lang.Class")); addType(insn.getReceiver().getIndex(), 0, "java.lang.Class");
} }
@Override @Override
public void visit(StringConstantInstruction insn) { public void visit(StringConstantInstruction insn) {
tasks.add(new Task(insn.getReceiver().getIndex(), 0, "java.lang.String")); addType(insn.getReceiver().getIndex(), 0, "java.lang.String");
} }
@Override @Override
@ -337,27 +582,24 @@ public class ClassInference {
@Override @Override
public void visit(CastInstruction insn) { public void visit(CastInstruction insn) {
casts.get(insn.getValue().getIndex()).put(insn.getReceiver().getIndex(), insn.getTargetType()); casts.add(new ValueCast(insn.getValue().getIndex(), insn.getReceiver().getIndex(), insn.getTargetType()));
getNodeTypes(packNodeAndDegree(insn.getReceiver().getIndex(), 0));
} }
@Override @Override
public void visit(RaiseInstruction insn) { public void visit(RaiseInstruction insn) {
IntSet blockIndexes = exceptionMap.get(insn.getException().getIndex()); exceptions.add(insn.getException().getIndex());
if (blockIndexes == null) { exceptions.add(currentBlock.getIndex());
blockIndexes = new IntOpenHashSet();
exceptionMap.put(insn.getException().getIndex(), blockIndexes);
}
blockIndexes.add(currentBlock.getIndex());
} }
@Override @Override
public void visit(ConstructArrayInstruction insn) { public void visit(ConstructArrayInstruction insn) {
tasks.add(new Task(insn.getReceiver().getIndex(), 0, ValueType.arrayOf(insn.getItemType()).toString())); addType(insn.getReceiver().getIndex(), 0, ValueType.arrayOf(insn.getItemType()).toString());
} }
@Override @Override
public void visit(ConstructInstruction insn) { public void visit(ConstructInstruction insn) {
tasks.add(new Task(insn.getReceiver().getIndex(), 0, insn.getType())); addType(insn.getReceiver().getIndex(), 0, insn.getType());
} }
@Override @Override
@ -366,21 +608,21 @@ public class ClassInference {
for (int i = 0; i < insn.getDimensions().size(); ++i) { for (int i = 0; i < insn.getDimensions().size(); ++i) {
type = ValueType.arrayOf(type); type = ValueType.arrayOf(type);
} }
tasks.add(new Task(insn.getReceiver().getIndex(), 0, type.toString())); addType(insn.getReceiver().getIndex(), 0, type.toString());
} }
@Override @Override
public void visit(GetFieldInstruction insn) { public void visit(GetFieldInstruction insn) {
FieldDependencyInfo fieldDep = dependencyInfo.getField(insn.getField()); FieldDependencyInfo fieldDep = dependencyInfo.getField(insn.getField());
ValueDependencyInfo valueDep = fieldDep.getValue(); ValueDependencyInfo valueDep = fieldDep.getValue();
readValue(valueDep, insn.getReceiver(), tasks); readValue(valueDep, insn.getReceiver());
} }
@Override @Override
public void visit(PutFieldInstruction insn) { public void visit(PutFieldInstruction insn) {
FieldDependencyInfo fieldDep = dependencyInfo.getField(insn.getField()); FieldDependencyInfo fieldDep = dependencyInfo.getField(insn.getField());
ValueDependencyInfo valueDep = fieldDep.getValue(); ValueDependencyInfo valueDep = fieldDep.getValue();
writeValue(valueDep, insn.getValue(), tasks); writeValue(valueDep, insn.getValue());
} }
@Override @Override
@ -408,11 +650,12 @@ public class ClassInference {
if (insn.getValueToReturn() != null) { if (insn.getValueToReturn() != null) {
ValueDependencyInfo resultDependency = thisMethodDep.getResult(); ValueDependencyInfo resultDependency = thisMethodDep.getResult();
int resultDegree = 0; int resultDegree = 0;
while (resultDependency.hasArrayType()) { while (resultDependency.hasArrayType() && resultDegree <= MAX_DEGREE) {
resultDependency = resultDependency.getArrayItem(); resultDependency = resultDependency.getArrayItem();
for (String paramType : resultDependency.getTypes()) { for (String paramType : resultDependency.getTypes()) {
tasks.add(new Task(insn.getValueToReturn().getIndex(), ++resultDegree, paramType)); addType(insn.getValueToReturn().getIndex(), resultDegree, paramType);
} }
++resultDegree;
} }
} }
} }
@ -421,21 +664,26 @@ public class ClassInference {
public void visit(InvokeInstruction insn) { public void visit(InvokeInstruction insn) {
if (insn.getType() == InvocationType.VIRTUAL) { if (insn.getType() == InvocationType.VIRTUAL) {
int instance = insn.getInstance().getIndex(); int instance = insn.getInstance().getIndex();
List<VirtualCallSite> callSites = virtualCallSites.get(instance);
if (callSites == null) {
callSites = new ArrayList<>();
virtualCallSites.set(instance, callSites);
}
VirtualCallSite callSite = new VirtualCallSite(); VirtualCallSite callSite = new VirtualCallSite();
callSite.instance = instance;
callSite.method = insn.getMethod(); callSite.method = insn.getMethod();
callSite.arguments = new int[insn.getArguments().size()]; callSite.arguments = new int[insn.getArguments().size()];
for (int i = 0; i < insn.getArguments().size(); ++i) { for (int i = 0; i < insn.getArguments().size(); ++i) {
callSite.arguments[i] = insn.getArguments().get(i).getIndex(); callSite.arguments[i] = insn.getArguments().get(i).getIndex();
for (int j = 0; j <= MAX_DEGREE; ++j) {
getNodeTypes(packNodeAndDegree(callSite.arguments[i], j));
}
} }
callSite.receiver = insn.getReceiver() != null ? insn.getReceiver().getIndex() : -1; callSite.receiver = insn.getReceiver() != null ? insn.getReceiver().getIndex() : -1;
callSite.block = currentBlock.getIndex(); callSite.block = currentBlock.getIndex();
callSites.add(callSite); virtualCallSites.add(callSite);
if (insn.getReceiver() != null) {
for (int j = 1; j <= MAX_DEGREE; ++j) {
getNodeTypes(packNodeAndDegree(callSite.receiver, j));
}
}
return; return;
} }
@ -443,61 +691,95 @@ public class ClassInference {
MethodDependencyInfo methodDep = dependencyInfo.getMethod(insn.getMethod()); MethodDependencyInfo methodDep = dependencyInfo.getMethod(insn.getMethod());
if (methodDep != null) { if (methodDep != null) {
if (insn.getReceiver() != null) { if (insn.getReceiver() != null) {
readValue(methodDep.getResult(), insn.getReceiver(), tasks); readValue(methodDep.getResult(), insn.getReceiver());
} }
for (int i = 0; i < insn.getArguments().size(); ++i) { for (int i = 0; i < insn.getArguments().size(); ++i) {
writeValue(methodDep.getVariable(i + 1), insn.getArguments().get(i), tasks); writeValue(methodDep.getVariable(i + 1), insn.getArguments().get(i));
} }
for (String type : methodDep.getThrown().getTypes()) { for (String type : methodDep.getThrown().getTypes()) {
tasks.add(new Task(currentBlock.getIndex(), -1, type)); propagateException(type, currentBlock);
} }
} }
} }
} }
private static void readValue(ValueDependencyInfo valueDep, Variable receiver, Collection<Task> tasks) {
int depth = 0;
boolean hasArrayType;
do {
for (String type : valueDep.getTypes()) {
tasks.add(new Task(receiver.getIndex(), depth, type));
}
depth++;
hasArrayType = valueDep.hasArrayType();
valueDep = valueDep.getArrayItem();
} while (hasArrayType);
}
private static void writeValue(ValueDependencyInfo valueDep, Variable source, Collection<Task> tasks) {
int depth = 0;
while (valueDep.hasArrayType()) {
depth++;
valueDep = valueDep.getArrayItem();
for (String type : valueDep.getTypes()) {
tasks.add(new Task(source.getIndex(), depth, type));
}
}
}
static class Task {
int variable;
int degree;
String className;
Task(int variable, int degree, String className) {
this.variable = variable;
this.degree = degree;
this.className = className;
}
}
static class VirtualCallSite { static class VirtualCallSite {
int instance;
IntSet knownClasses = new IntOpenHashSet();
Set<MethodReference> resolvedMethods = new HashSet<>(); Set<MethodReference> resolvedMethods = new HashSet<>();
MethodReference method; MethodReference method;
int[] arguments; int[] arguments;
int receiver; int receiver;
int block; int block;
} }
void readValue(ValueDependencyInfo valueDep, Variable receiver) {
int depth = 0;
boolean hasArrayType;
do {
for (String type : valueDep.getTypes()) {
addType(receiver.getIndex(), depth, type);
}
depth++;
hasArrayType = valueDep.hasArrayType();
valueDep = valueDep.getArrayItem();
} while (hasArrayType && depth <= MAX_DEGREE);
}
void writeValue(ValueDependencyInfo valueDep, Variable source) {
int depth = 0;
while (valueDep.hasArrayType() && depth < MAX_DEGREE) {
depth++;
valueDep = valueDep.getArrayItem();
for (String type : valueDep.getTypes()) {
addType(source.getIndex(), depth, type);
}
}
}
void addType(int variable, int degree, String typeName) {
int entry = nodeMapping[packNodeAndDegree(variable, degree)];
if (getNodeTypes(entry).add(getTypeByName(typeName))) {
nodeChanged[entry] = true;
changed = true;
}
}
static int extractNode(int nodeAndDegree) {
return nodeAndDegree >>> 3;
}
static int extractDegree(int nodeAndDegree) {
return nodeAndDegree & 7;
}
static int packNodeAndDegree(int node, int degree) {
return (node << 3) | degree;
}
static long packTwoIntegers(int a, int b) {
return ((long) a << 32) | b;
}
static int unpackFirst(long pair) {
return (int) (pair >>> 32);
}
static int unpackSecond(long pair) {
return (int) pair;
}
static final class ValueCast {
final int fromVariable;
final int toVariable;
final ValueType targetType;
ValueCast(int fromVariable, int toVariable, ValueType targetType) {
this.fromVariable = fromVariable;
this.toVariable = toVariable;
this.targetType = targetType;
}
}
} }

View File

@ -137,7 +137,7 @@ class NullnessInformationBuilder {
sccIndexes = new int[program.variableCount()]; sccIndexes = new int[program.variableCount()];
if (assignmentGraph.size() > 0) { if (assignmentGraph.size() > 0) {
int[][] sccs = GraphUtils.findStronglyConnectedComponents(assignmentGraph, new int[]{0}); int[][] sccs = GraphUtils.findStronglyConnectedComponents(assignmentGraph);
for (int i = 0; i < sccs.length; ++i) { for (int i = 0; i < sccs.length; ++i) {
for (int sccNode : sccs[i]) { for (int sccNode : sccs[i]) {
sccIndexes[sccNode] = i + 1; sccIndexes[sccNode] = i + 1;

View File

@ -49,16 +49,12 @@ public class GraphTest {
builder.addEdge(12, 13); builder.addEdge(12, 13);
Graph graph = builder.build(); Graph graph = builder.build();
int[][] sccs = GraphUtils.findStronglyConnectedComponents(graph, new int[] { 0 }); int[][] sccs = GraphUtils.findStronglyConnectedComponents(graph);
sortSccs(sccs); sortSccs(sccs);
assertThat(sccs.length, is(6)); assertThat(sccs.length, is(2));
assertThat(sccs[0], is(new int[] { 0 })); assertThat(sccs[0], is(new int[] { 1, 2, 3, 4, 5, 6, 7, 8 }));
assertThat(sccs[1], is(new int[] { 1, 2, 3, 4, 5, 6, 7, 8 })); assertThat(sccs[1], is(new int[] { 11, 12 }));
assertThat(sccs[2], is(new int[] { 9 }));
assertThat(sccs[3], is(new int[] { 10 }));
assertThat(sccs[4], is(new int[] { 11, 12 }));
assertThat(sccs[5], is(new int[] { 13 }));
} }
@Test @Test
@ -77,7 +73,7 @@ public class GraphTest {
Graph graph = builder.build(); Graph graph = builder.build();
graph = GraphUtils.subgraph(graph, node -> node != 0); graph = GraphUtils.subgraph(graph, node -> node != 0);
int[][] sccs = GraphUtils.findStronglyConnectedComponents(graph, new int[] { 1, 2, 3 }); int[][] sccs = GraphUtils.findStronglyConnectedComponents(graph);
sortSccs(sccs); sortSccs(sccs);
assertThat(sccs.length, is(1)); assertThat(sccs.length, is(1));
@ -95,13 +91,11 @@ public class GraphTest {
Graph graph = builder.build(); Graph graph = builder.build();
graph = GraphUtils.subgraph(graph, filter); graph = GraphUtils.subgraph(graph, filter);
int[][] sccs = GraphUtils.findStronglyConnectedComponents(graph, new int[] { 0 }); int[][] sccs = GraphUtils.findStronglyConnectedComponents(graph);
sortSccs(sccs); sortSccs(sccs);
assertThat(sccs.length, is(3)); assertThat(sccs.length, is(1));
assertThat(sccs[0], is(new int[] { 0 })); assertThat(sccs[0], is(new int[] { 1, 3 }));
assertThat(sccs[1], is(new int[] { 1, 3 }));
assertThat(sccs[2], is(new int[] { 2 }));
} }
@Test @Test
@ -127,7 +121,7 @@ public class GraphTest {
Graph graph = builder.build(); Graph graph = builder.build();
graph = GraphUtils.subgraph(graph, node -> node != 0); graph = GraphUtils.subgraph(graph, node -> node != 0);
int[][] sccs = GraphUtils.findStronglyConnectedComponents(graph, new int[] { 1, 2, 3, 4 }); int[][] sccs = GraphUtils.findStronglyConnectedComponents(graph);
sortSccs(sccs); sortSccs(sccs);
assertThat(sccs.length, is(2)); assertThat(sccs.length, is(2));