WASM: improving algorithm that generates instructions to store variables in shadow stack

This commit is contained in:
Alexey Andreev 2016-09-21 00:47:55 +03:00
parent fcf0394214
commit fc3d36ec4c
7 changed files with 259 additions and 67 deletions

View File

@ -246,7 +246,6 @@ public final class WasmRuntime {
public static Address allocStack(int size) {
Address result = stack.add(4);
stack = result.add(size << 2);
fillZero(result, size << 2);
stack.putInt(size);
return result;
}

View File

@ -15,10 +15,6 @@
*/
package org.teavm.common;
/**
*
* @author Alexey Andreev
*/
public interface Graph {
int size();

View File

@ -21,10 +21,6 @@ import com.carrotsearch.hppc.cursors.IntCursor;
import java.util.ArrayList;
import java.util.List;
/**
*
* @author Alexey Andreev
*/
public class MutableDirectedGraph implements Graph {
private List<IntSet> successors = new ArrayList<>();
private List<IntSet> predecessors = new ArrayList<>();

View File

@ -15,12 +15,10 @@
*/
package org.teavm.common;
import com.carrotsearch.hppc.ObjectIntMap;
import com.carrotsearch.hppc.ObjectIntOpenHashMap;
import java.util.*;
/**
*
* @author Alexey Andreev
*/
public class MutableGraphNode {
private int tag;
final Map<MutableGraphNode, MutableGraphEdge> edges = new LinkedHashMap<>();
@ -77,4 +75,21 @@ public class MutableGraphNode {
}
return sb.toString();
}
public static Graph toGraph(List<MutableGraphNode> nodes) {
ObjectIntMap<MutableGraphNode> map = new ObjectIntOpenHashMap<>();
for (int i = 0; i < nodes.size(); ++i) {
map.put(nodes.get(i), i);
}
GraphBuilder builder = new GraphBuilder(nodes.size());
for (int i = 0; i < nodes.size(); ++i) {
for (MutableGraphEdge edge : nodes.get(i).getEdges()) {
int successor = map.get(edge.getSecond());
builder.addEdge(i, successor);
}
}
return builder.build();
}
}

View File

@ -15,13 +15,21 @@
*/
package org.teavm.model.lowlevel;
import com.carrotsearch.hppc.IntArrayList;
import com.carrotsearch.hppc.IntObjectMap;
import com.carrotsearch.hppc.IntObjectOpenHashMap;
import com.carrotsearch.hppc.cursors.ObjectCursor;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.BitSet;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import org.teavm.common.DominatorTree;
import org.teavm.common.Graph;
import org.teavm.common.GraphBuilder;
import org.teavm.common.GraphUtils;
import org.teavm.interop.NoGC;
import org.teavm.model.BasicBlock;
import org.teavm.model.ClassReader;
@ -44,6 +52,7 @@ import org.teavm.model.instructions.InvokeInstruction;
import org.teavm.model.instructions.JumpInstruction;
import org.teavm.model.instructions.RaiseInstruction;
import org.teavm.model.util.DefinitionExtractor;
import org.teavm.model.util.GraphColorer;
import org.teavm.model.util.LivenessAnalyzer;
import org.teavm.model.util.ProgramUtils;
import org.teavm.model.util.TypeInferer;
@ -59,14 +68,61 @@ public class GcRootMaintainingTransformer {
}
public void apply(Program program, MethodReader method) {
if (!requiresGc(method.getReference())) {
if (!requiresGC(method.getReference())) {
return;
}
List<IntObjectMap<BitSet>> liveInInformation = findCallSiteLiveIns(program, method);
int maxDepth = putLiveInGcRoots(program, liveInInformation);
if (maxDepth > 0) {
addStackAllocation(program, maxDepth);
addStackRelease(program, maxDepth);
Graph interferenceGraph = buildInterferenceGraph(liveInInformation, program);
boolean[] spilled = getAffectedVariables(liveInInformation, program);
int[] colors = new int[interferenceGraph.size()];
Arrays.fill(colors, -1);
new GraphColorer().colorize(interferenceGraph, colors);
int usedColors = 0;
for (int var = 0; var < colors.length; ++var) {
if (spilled[var]) {
usedColors = Math.max(usedColors, colors[var]);
colors[var]--;
}
}
if (usedColors == 0) {
return;
}
// If a variable is spilled to stack, then phi which input this variable also spilled to stack
// If all of phi inputs are spilled to stack, then we don't need to insert spilling instruction
// for this phi.
List<Set<Phi>> destinationPhis = getDestinationPhis(program);
int[] inputCount = getInputCount(program);
boolean[] autoSpilled = new boolean[spilled.length];
for (int i = 0; i < spilled.length; ++i) {
findAutoSpilledPhis(spilled, destinationPhis, inputCount, autoSpilled, i);
}
List<IntObjectMap<int[]>> liveInStores = reduceGcRootStores(program, usedColors, liveInInformation,
colors, autoSpilled);
putLiveInGcRoots(program, liveInStores);
addStackAllocation(program, usedColors);
addStackRelease(program, usedColors);
}
private void findAutoSpilledPhis(boolean[] spilled, List<Set<Phi>> destinationPhis, int[] inputCount,
boolean[] autoSpilled, int i) {
if (spilled[i]) {
Set<Phi> phis = destinationPhis.get(i);
if (phis != null) {
for (Phi phi : destinationPhis.get(i)) {
int destination = phi.getReceiver().getIndex();
autoSpilled[destination] = --inputCount[destination] == 0;
if (!spilled[destination]) {
spilled[destination] = true;
if (i > destination) {
findAutoSpilledPhis(spilled, destinationPhis, inputCount, autoSpilled, destination);
}
}
}
}
}
}
@ -102,7 +158,7 @@ public class GcRootMaintainingTransformer {
if (insn instanceof InvokeInstruction || insn instanceof InitClassInstruction
|| insn instanceof ConstructInstruction || insn instanceof ConstructArrayInstruction
|| insn instanceof CloneArrayInstruction || insn instanceof RaiseInstruction) {
if (insn instanceof InvokeInstruction && !requiresGc(((InvokeInstruction) insn).getMethod())) {
if (insn instanceof InvokeInstruction && !requiresGC(((InvokeInstruction) insn).getMethod())) {
continue;
}
@ -124,69 +180,202 @@ public class GcRootMaintainingTransformer {
return liveInInformation;
}
private int putLiveInGcRoots(Program program, List<IntObjectMap<BitSet>> liveInInformation) {
int maxDepth = 0;
for (IntObjectMap<BitSet> liveInsMap : liveInInformation) {
for (ObjectCursor<BitSet> liveIns : liveInsMap.values()) {
maxDepth = Math.max(maxDepth, liveIns.value.cardinality());
private Graph buildInterferenceGraph(List<IntObjectMap<BitSet>> liveInInformation, Program program) {
GraphBuilder builder = new GraphBuilder(program.variableCount());
for (IntObjectMap<BitSet> blockLiveIn : liveInInformation) {
for (ObjectCursor<BitSet> callSiteLiveIn : blockLiveIn.values()) {
BitSet liveVarsSet = callSiteLiveIn.value;
IntArrayList liveVars = new IntArrayList();
for (int i = liveVarsSet.nextSetBit(0); i >= 0; i = liveVarsSet.nextSetBit(i + 1)) {
liveVars.add(i);
}
int[] liveVarArray = liveVars.toArray();
for (int i = 0; i < liveVarArray.length - 1; ++i) {
for (int j = i + 1; j < liveVarArray.length; ++j) {
builder.addEdge(liveVarArray[i], liveVarArray[j]);
builder.addEdge(liveVarArray[j], liveVarArray[i]);
}
}
}
}
return builder.build();
}
private boolean[] getAffectedVariables(List<IntObjectMap<BitSet>> liveInInformation, Program program) {
boolean[] affectedVariables = new boolean[program.variableCount()];
for (IntObjectMap<BitSet> blockLiveIn : liveInInformation) {
for (ObjectCursor<BitSet> callSiteLiveIn : blockLiveIn.values()) {
BitSet liveVarsSet = callSiteLiveIn.value;
for (int i = liveVarsSet.nextSetBit(0); i >= 0; i = liveVarsSet.nextSetBit(i + 1)) {
affectedVariables[i] = true;
}
}
}
return affectedVariables;
}
private List<Set<Phi>> getDestinationPhis(Program program) {
List<Set<Phi>> destinationPhis = new ArrayList<>();
destinationPhis.addAll(Collections.nCopies(program.variableCount(), null));
for (int i = 0; i < program.basicBlockCount(); ++i) {
BasicBlock block = program.basicBlockAt(i);
List<Instruction> instructions = block.getInstructions();
IntObjectMap<BitSet> liveInsByIndex = liveInInformation.get(i);
for (int j = instructions.size() - 1; j >= 0; --j) {
BitSet liveIns = liveInsByIndex.get(j);
if (liveIns == null) {
continue;
for (Phi phi : block.getPhis()) {
for (Incoming incoming : phi.getIncomings()) {
Set<Phi> phis = destinationPhis.get(incoming.getValue().getIndex());
if (phis == null) {
phis = new HashSet<>();
destinationPhis.set(incoming.getValue().getIndex(), phis);
}
phis.add(phi);
}
storeLiveIns(block, j, liveIns, maxDepth);
}
}
return maxDepth;
return destinationPhis;
}
private void storeLiveIns(BasicBlock block, int index, BitSet liveIns, int maxDepth) {
private int[] getInputCount(Program program) {
int[] inputCount = new int[program.variableCount()];
for (int i = 0; i < program.basicBlockCount(); ++i) {
BasicBlock block = program.basicBlockAt(i);
for (Phi phi : block.getPhis()) {
inputCount[phi.getReceiver().getIndex()] = phi.getIncomings().size();
}
}
return inputCount;
}
private List<IntObjectMap<int[]>> reduceGcRootStores(Program program, int usedColors,
List<IntObjectMap<BitSet>> liveInInformation, int[] colors, boolean[] autoSpilled) {
class Step {
private final int node;
private final int[] slotStates = new int[usedColors];
private Step(int node) {
this.node = node;
}
}
List<IntObjectMap<int[]>> slotsToUpdate = new ArrayList<>();
for (int i = 0; i < program.basicBlockCount(); ++i) {
slotsToUpdate.add(new IntObjectOpenHashMap<>());
}
Graph cfg = ProgramUtils.buildControlFlowGraph(program);
DominatorTree dom = GraphUtils.buildDominatorTree(cfg);
Graph domGraph = GraphUtils.buildDominatorGraph(dom, cfg.size());
Step[] stack = new Step[program.basicBlockCount() * 2];
int head = 0;
Step start = new Step(0);
Arrays.fill(start.slotStates, usedColors);
stack[head++] = start;
while (head > 0) {
Step step = stack[--head];
int[] previousStates = step.slotStates;
int[] states = previousStates.clone();
IntObjectMap<BitSet> callSites = liveInInformation.get(step.node);
IntObjectMap<int[]> updatesByCallSite = slotsToUpdate.get(step.node);
int[] callSiteLocations = callSites.keys().toArray();
Arrays.sort(callSiteLocations);
for (int callSiteLocation : callSiteLocations) {
BitSet liveIns = callSites.get(callSiteLocation);
for (int liveVar = liveIns.nextSetBit(0); liveVar >= 0; liveVar = liveIns.nextSetBit(liveVar + 1)) {
int slot = colors[liveVar];
states[slot] = liveVar;
}
for (int slot = 0; slot < states.length; ++slot) {
if (states[slot] >= 0 && !liveIns.get(states[slot])) {
states[slot] = -1;
}
}
updatesByCallSite.put(callSiteLocation, compareStates(previousStates, states, autoSpilled));
previousStates = states;
states = states.clone();
}
for (int succ : domGraph.outgoingEdges(step.node)) {
Step next = new Step(succ);
System.arraycopy(states, 0, next.slotStates, 0, usedColors);
stack[head++] = next;
}
}
return slotsToUpdate;
}
private static int[] compareStates(int[] oldStates, int[] newStates, boolean[] autoSpilled) {
int[] comparison = new int[oldStates.length];
Arrays.fill(comparison, -2);
for (int i = 0; i < oldStates.length; ++i) {
if (oldStates[i] != newStates[i]) {
comparison[i] = newStates[i];
}
}
for (int i = 0; i < newStates.length; ++i) {
if (newStates[i] >= 0 && autoSpilled[newStates[i]]) {
comparison[i] = -2;
}
}
return comparison;
}
private void putLiveInGcRoots(Program program, List<IntObjectMap<int[]>> updateInformation) {
for (int i = 0; i < program.basicBlockCount(); ++i) {
BasicBlock block = program.basicBlockAt(i);
IntObjectMap<int[]> updatesByIndex = updateInformation.get(i);
int[] callSiteLocations = updatesByIndex.keys().toArray();
Arrays.sort(callSiteLocations);
for (int j = callSiteLocations.length - 1; j >= 0; --j) {
int callSiteLocation = callSiteLocations[j];
int[] updates = updatesByIndex.get(callSiteLocation);
storeLiveIns(block, callSiteLocation, updates);
}
}
}
private void storeLiveIns(BasicBlock block, int index, int[] updates) {
Program program = block.getProgram();
List<Instruction> instructions = block.getInstructions();
Instruction callInstruction = instructions.get(index);
List<Instruction> instructionsToAdd = new ArrayList<>();
int slot = 0;
for (int liveVar = liveIns.nextSetBit(0); liveVar >= 0; liveVar = liveIns.nextSetBit(liveVar + 1)) {
for (int slot = 0; slot < updates.length; ++slot) {
int var = updates[slot];
if (var == -2) {
continue;
}
Variable slotVar = program.createVariable();
IntegerConstantInstruction slotConstant = new IntegerConstantInstruction();
slotConstant.setReceiver(slotVar);
slotConstant.setConstant(slot++);
slotConstant.setConstant(slot);
slotConstant.setLocation(callInstruction.getLocation());
instructionsToAdd.add(slotConstant);
InvokeInstruction registerInvocation = new InvokeInstruction();
registerInvocation.setType(InvocationType.SPECIAL);
registerInvocation.setMethod(new MethodReference(Mutator.class, "registerGcRoot", int.class,
Object.class, void.class));
registerInvocation.getArguments().add(slotVar);
registerInvocation.getArguments().add(program.variableAt(liveVar));
if (var >= 0) {
registerInvocation.setMethod(new MethodReference(Mutator.class, "registerGcRoot", int.class,
Object.class, void.class));
registerInvocation.getArguments().add(program.variableAt(var));
} else {
registerInvocation.setMethod(new MethodReference(Mutator.class, "removeGcRoot", int.class,
void.class));
}
instructionsToAdd.add(registerInvocation);
}
while (slot < maxDepth) {
Variable slotVar = program.createVariable();
IntegerConstantInstruction slotConstant = new IntegerConstantInstruction();
slotConstant.setReceiver(slotVar);
slotConstant.setConstant(slot++);
slotConstant.setLocation(callInstruction.getLocation());
instructionsToAdd.add(slotConstant);
InvokeInstruction clearInvocation = new InvokeInstruction();
clearInvocation.setType(InvocationType.SPECIAL);
clearInvocation.setMethod(new MethodReference(Mutator.class, "removeGcRoot", int.class, void.class));
clearInvocation.getArguments().add(slotVar);
clearInvocation.setLocation(callInstruction.getLocation());
instructionsToAdd.add(clearInvocation);
}
instructions.addAll(index, instructionsToAdd);
}
@ -288,7 +477,7 @@ public class GcRootMaintainingTransformer {
}
}
private boolean requiresGc(MethodReference methodReference) {
private boolean requiresGC(MethodReference methodReference) {
ClassReader cls = classSource.get(methodReference.getClassName());
if (cls == null) {
return true;

View File

@ -18,16 +18,15 @@ package org.teavm.model.util;
import java.util.ArrayList;
import java.util.BitSet;
import java.util.List;
import org.teavm.common.Graph;
import org.teavm.common.IntegerArray;
import org.teavm.common.MutableGraphEdge;
import org.teavm.common.MutableGraphNode;
class GraphColorer {
public void colorize(List<MutableGraphNode> graph, int[] colors) {
public class GraphColorer {
public void colorize(Graph graph, int[] colors) {
colorize(graph, colors, new int[graph.size()], new String[graph.size()]);
}
public void colorize(List<MutableGraphNode> graph, int[] colors, int[] categories, String[] names) {
public void colorize(Graph graph, int[] colors, int[] categories, String[] names) {
IntegerArray colorCategories = new IntegerArray(graph.size());
List<String> colorNames = new ArrayList<>();
for (int i = 0; i < colors.length; ++i) {
@ -51,8 +50,7 @@ class GraphColorer {
}
usedColors.clear();
usedColors.set(0);
for (MutableGraphEdge edge : graph.get(v).getEdges()) {
int succ = edge.getSecond().getTag();
for (int succ : graph.outgoingEdges(v)) {
if (colors[succ] >= 0) {
usedColors.set(colors[succ]);
}
@ -82,7 +80,7 @@ class GraphColorer {
}
}
private int[] getOrdering(List<MutableGraphNode> graph) {
private int[] getOrdering(Graph graph) {
boolean[] visited = new boolean[graph.size()];
int[] ordering = new int[graph.size()];
int index = 0;
@ -104,8 +102,7 @@ class GraphColorer {
}
visited[v] = true;
ordering[index++] = v;
for (MutableGraphEdge edge : graph.get(v).getEdges()) {
int succ = edge.getSecond().getTag();
for (int succ : graph.outgoingEdges(v)) {
if (visited[succ]) {
continue;
}

View File

@ -68,7 +68,7 @@ public class RegisterAllocator {
}
int[] categories = getVariableCategories(program, method.getReference());
String[] names = getVariableNames(program);
colorer.colorize(interferenceGraph, colors, categories, names);
colorer.colorize(MutableGraphNode.toGraph(interferenceGraph), colors, categories, names);
int maxColor = 0;
for (int i = 0; i < colors.length; ++i) {