wasm: fix issue in exception handling transformer

This commit is contained in:
Alexey Andreev 2023-09-24 20:15:44 +02:00
parent 603d7f1f88
commit c2c1d78f13
3 changed files with 40 additions and 13 deletions

View File

@ -163,6 +163,7 @@ import org.teavm.model.transformation.BoundCheckInsertion;
import org.teavm.model.transformation.ClassPatch; import org.teavm.model.transformation.ClassPatch;
import org.teavm.model.transformation.NullCheckInsertion; import org.teavm.model.transformation.NullCheckInsertion;
import org.teavm.model.util.AsyncMethodFinder; import org.teavm.model.util.AsyncMethodFinder;
import org.teavm.model.util.TransitionExtractor;
import org.teavm.runtime.Allocator; import org.teavm.runtime.Allocator;
import org.teavm.runtime.EventQueue; import org.teavm.runtime.EventQueue;
import org.teavm.runtime.ExceptionHandling; import org.teavm.runtime.ExceptionHandling;
@ -440,9 +441,25 @@ public class WasmTarget implements TeaVMTarget, TeaVMWasmHost {
.apply(program, method.getReference()); .apply(program, method.getReference());
checkTransformation.apply(program, method.getResultType()); checkTransformation.apply(program, method.getResultType());
shadowStackTransformer.apply(program, method); shadowStackTransformer.apply(program, method);
checkPhis(program, method);
writeBarrierInsertion.apply(program); writeBarrierInsertion.apply(program);
} }
private void checkPhis(Program program, MethodReader method) {
var transitionExtractor = new TransitionExtractor();
for (var block : program.getBasicBlocks()) {
for (var phi : block.getPhis()) {
for (var incoming : phi.getIncomings()) {
incoming.getSource().getLastInstruction().acceptVisitor(transitionExtractor);
if (!Arrays.asList(transitionExtractor.getTargets()).contains(block)) {
throw new RuntimeException("Method " + method.getReference() + ", block "
+ block.getIndex() + ", from " + incoming.getSource().getIndex());
}
}
}
}
}
@Override @Override
public void emit(ListableClassHolderSource classes, BuildTarget buildTarget, String outputName) public void emit(ListableClassHolderSource classes, BuildTarget buildTarget, String outputName)
throws IOException { throws IOException {

View File

@ -17,7 +17,6 @@ package org.teavm.model.lowlevel;
import com.carrotsearch.hppc.IntHashSet; import com.carrotsearch.hppc.IntHashSet;
import com.carrotsearch.hppc.IntObjectHashMap; import com.carrotsearch.hppc.IntObjectHashMap;
import com.carrotsearch.hppc.IntObjectMap;
import com.carrotsearch.hppc.IntSet; import com.carrotsearch.hppc.IntSet;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
@ -151,16 +150,16 @@ public class ExceptionHandlingShadowStackContributor {
private int contributeToBasicBlock(BasicBlock block) { private int contributeToBasicBlock(BasicBlock block) {
int[] currentJointSources = new int[program.variableCount()]; int[] currentJointSources = new int[program.variableCount()];
IntObjectMap<int[]> jointReceiverMaps = new IntObjectHashMap<>(); var jointReceiverMaps = new IntObjectHashMap<int[]>();
Arrays.fill(currentJointSources, -1); Arrays.fill(currentJointSources, -1);
IntSet variablesDefinedHere = new IntHashSet(); IntSet variablesDefinedHere = new IntHashSet();
for (TryCatchBlock tryCatch : block.getTryCatchBlocks()) { for (var tryCatch : block.getTryCatchBlocks()) {
int[] jointReceiverMap = new int[program.variableCount()]; int[] jointReceiverMap = new int[program.variableCount()];
Arrays.fill(jointReceiverMap, -1); Arrays.fill(jointReceiverMap, -1);
for (Phi phi : tryCatch.getHandler().getPhis()) { for (Phi phi : tryCatch.getHandler().getPhis()) {
List<Variable> sourceVariables = phi.getIncomings().stream() var sourceVariables = phi.getIncomings().stream()
.filter(incoming -> incoming.getSource() == tryCatch.getProtectedBlock()) .filter(incoming -> incoming.getSource() == tryCatch.getProtectedBlock())
.map(incoming -> incoming.getValue()) .map(incoming -> incoming.getValue())
.collect(Collectors.toList()); .collect(Collectors.toList());
@ -198,7 +197,7 @@ public class ExceptionHandlingShadowStackContributor {
variablesDefinedHere.add(definedVar.getIndex()); variablesDefinedHere.add(definedVar.getIndex());
} }
DefinitionExtractor defExtractor = new DefinitionExtractor(); var defExtractor = new DefinitionExtractor();
List<BasicBlock> blocksToClearHandlers = new ArrayList<>(); List<BasicBlock> blocksToClearHandlers = new ArrayList<>();
blocksToClearHandlers.add(block); blocksToClearHandlers.add(block);
BasicBlock initialBlock = block; BasicBlock initialBlock = block;
@ -252,11 +251,11 @@ public class ExceptionHandlingShadowStackContributor {
} }
} }
CallSiteLocation[] locations = CallSiteLocation.fromTextLocation(insn.getLocation(), method); var locations = CallSiteLocation.fromTextLocation(insn.getLocation(), method);
CallSiteDescriptor callSite = new CallSiteDescriptor(callSiteIdGen++, locations); var callSite = new CallSiteDescriptor(callSiteIdGen++, locations);
callSites.add(callSite); callSites.add(callSite);
List<Instruction> pre = setLocation(getInstructionsBeforeCallSite(callSite), insn.getLocation()); var pre = setLocation(getInstructionsBeforeCallSite(callSite), insn.getLocation());
List<Instruction> post = getInstructionsAfterCallSite(initialBlock, block, next, callSite, var post = getInstructionsAfterCallSite(initialBlock, block, next, callSite,
currentJointSources, variablesDefinedHere); currentJointSources, variablesDefinedHere);
post = setLocation(post, insn.getLocation()); post = setLocation(post, insn.getLocation());
block.getLastInstruction().insertPreviousAll(pre); block.getLastInstruction().insertPreviousAll(pre);
@ -271,7 +270,7 @@ public class ExceptionHandlingShadowStackContributor {
} }
} }
fixOutgoingPhis(initialBlock, block, currentJointSources, variablesDefinedHere); removeOutgoingPhis(block);
for (BasicBlock blockToClear : blocksToClearHandlers) { for (BasicBlock blockToClear : blocksToClearHandlers) {
blockToClear.getTryCatchBlocks().clear(); blockToClear.getTryCatchBlocks().clear();
} }
@ -422,7 +421,7 @@ public class ExceptionHandlingShadowStackContributor {
List<Incoming> additionalIncomings = new ArrayList<>(); List<Incoming> additionalIncomings = new ArrayList<>();
for (int i = 0; i < phi.getIncomings().size(); i++) { for (int i = 0; i < phi.getIncomings().size(); i++) {
Incoming incoming = phi.getIncomings().get(i); Incoming incoming = phi.getIncomings().get(i);
if (incoming.getSource() != block || incoming.getSource() == newBlock) { if (incoming.getSource() != block) {
continue; continue;
} }
if (incoming.getValue().getIndex() == value) { if (incoming.getValue().getIndex() == value) {
@ -442,6 +441,18 @@ public class ExceptionHandlingShadowStackContributor {
} }
} }
private void removeOutgoingPhis(BasicBlock block) {
for (var tryCatch : block.getTryCatchBlocks()) {
for (var iterator = tryCatch.getHandler().getPhis().iterator(); iterator.hasNext();) {
var phi = iterator.next();
phi.getIncomings().removeIf(incoming -> incoming.getSource() == block);
if (phi.getIncomings().isEmpty()) {
iterator.remove();
}
}
}
}
private BasicBlock getDefaultExceptionHandler() { private BasicBlock getDefaultExceptionHandler() {
if (defaultExceptionHandler == null) { if (defaultExceptionHandler == null) {
defaultExceptionHandler = program.createBasicBlock(); defaultExceptionHandler = program.createBasicBlock();

View File

@ -55,8 +55,7 @@ public class ShadowStackTransformer {
boolean exceptions; boolean exceptions;
if (exceptionHandling) { if (exceptionHandling) {
List<CallSiteDescriptor> callSites = new ArrayList<>(); List<CallSiteDescriptor> callSites = new ArrayList<>();
ExceptionHandlingShadowStackContributor exceptionContributor = var exceptionContributor = new ExceptionHandlingShadowStackContributor(characteristics, callSites,
new ExceptionHandlingShadowStackContributor(characteristics, callSites,
method.getReference(), program); method.getReference(), program);
exceptionContributor.callSiteIdGen = callSiteIdGen; exceptionContributor.callSiteIdGen = callSiteIdGen;
exceptions = exceptionContributor.contribute(); exceptions = exceptionContributor.contribute();