wasm gc: optimize null checks, casts and try/catch using branching instructions

This commit is contained in:
Alexey Andreev 2024-09-14 15:26:22 +02:00
parent 9aee15fa0f
commit a503333c1b
7 changed files with 93 additions and 39 deletions

View File

@ -143,7 +143,7 @@ public class DisassemblyCodeListener extends BaseDisassemblyListener implements
writer.write(" ");
writeType(sourceType);
writer.write(" ");
writeType(sourceType);
writeType(targetType);
writer.eol();
}

View File

@ -179,16 +179,6 @@ public class WasmGenerationVisitor extends BaseWasmGenerationVisitor {
return true;
}
@Override
protected void catchException(TextLocation location, List<WasmExpression> target, WasmLocal local,
String exceptionClass, WasmLocal exceptionVar) {
if (local != null) {
var save = new WasmSetLocal(local, new WasmGetLocal(exceptionVar));
save.setLocation(location);
target.add(save);
}
}
@Override
protected WasmType mapType(ValueType type) {
return WasmGeneratorUtil.mapType(type);

View File

@ -1300,12 +1300,16 @@ public abstract class BaseWasmGenerationVisitor implements StatementVisitor, Exp
for (int i = tryCatchStatements.size() - 1; i >= 0; --i) {
var tryCatch = tryCatchStatements.get(i);
var catchBlock = catchBlocks.get(i);
var blockType = mapType(tryCatch.getExceptionType() != null
? ValueType.object(tryCatch.getExceptionType())
: ValueType.object("java.lang.Throwable"));
currentBlock.setType(blockType);
if (tryCatch.getExceptionType() != null && !tryCatch.getExceptionType().equals(Throwable.class.getName())) {
var exceptionType = ValueType.object(tryCatch.getExceptionType());
var isMatched = generateInstanceOf(new WasmGetLocal(exceptionVar), exceptionType);
innerCatchBlock.getBody().add(new WasmBranch(isMatched, currentBlock));
checkExceptionType(tryCatch, exceptionVar, innerCatchBlock.getBody(), currentBlock);
} else {
innerCatchBlock.getBody().add(new WasmBreak(currentBlock));
var br = new WasmBreak(currentBlock);
br.setResult(new WasmGetLocal(exceptionVar));
innerCatchBlock.getBody().add(br);
catchesAll = true;
}
currentBlock = catchBlock;
@ -1320,12 +1324,16 @@ public abstract class BaseWasmGenerationVisitor implements StatementVisitor, Exp
for (int i = tryCatchStatements.size() - 1; i >= 0; --i) {
var tryCatch = tryCatchStatements.get(i);
var catchBlock = catchBlocks.get(i);
catchBlock.getBody().add(currentBlock);
var catchLocal = tryCatch.getExceptionVariable() != null
? localVar(tryCatch.getExceptionVariable())
: null;
catchException(null, catchBlock.getBody(), catchLocal, tryCatch.getExceptionType(), exceptionVar);
if (catchLocal != null) {
var save = new WasmSetLocal(localVar(tryCatch.getExceptionVariable()), currentBlock);
catchBlock.getBody().add(save);
} else {
catchBlock.getBody().add(new WasmDrop(currentBlock));
}
visitMany(tryCatch.getHandler(), catchBlock.getBody());
if (!catchBlock.isTerminating() && catchBlock != outerCatchBlock) {
catchBlock.getBody().add(new WasmBreak(outerCatchBlock));
@ -1337,8 +1345,12 @@ public abstract class BaseWasmGenerationVisitor implements StatementVisitor, Exp
tempVars.release(exceptionVar);
}
protected abstract void catchException(TextLocation location, List<WasmExpression> target, WasmLocal local,
String exceptionClass, WasmLocal exceptionVar);
protected void checkExceptionType(TryCatchStatement tryCatch, WasmLocal exceptionVar, List<WasmExpression> target,
WasmBlock targetBlock) {
var exceptionType = ValueType.object(tryCatch.getExceptionType());
var isMatched = generateInstanceOf(new WasmGetLocal(exceptionVar), exceptionType);
target.add(new WasmBranch(isMatched, targetBlock));
}
private void visitMany(List<Statement> statements, List<WasmExpression> target) {
var oldTarget = resultConsumer;

View File

@ -19,12 +19,15 @@ import java.util.List;
import java.util.function.Supplier;
import org.teavm.ast.ArrayType;
import org.teavm.ast.BinaryExpr;
import org.teavm.ast.CastExpr;
import org.teavm.ast.ConditionalExpr;
import org.teavm.ast.Expr;
import org.teavm.ast.InstanceOfExpr;
import org.teavm.ast.InvocationExpr;
import org.teavm.ast.InvocationType;
import org.teavm.ast.QualificationExpr;
import org.teavm.ast.SubscriptExpr;
import org.teavm.ast.TryCatchStatement;
import org.teavm.backend.wasm.BaseWasmFunctionRepository;
import org.teavm.backend.wasm.WasmFunctionTypes;
import org.teavm.backend.wasm.gc.PreciseTypeInference;
@ -50,6 +53,8 @@ import org.teavm.backend.wasm.model.expression.WasmBlock;
import org.teavm.backend.wasm.model.expression.WasmCall;
import org.teavm.backend.wasm.model.expression.WasmCallReference;
import org.teavm.backend.wasm.model.expression.WasmCast;
import org.teavm.backend.wasm.model.expression.WasmCastBranch;
import org.teavm.backend.wasm.model.expression.WasmCastCondition;
import org.teavm.backend.wasm.model.expression.WasmDrop;
import org.teavm.backend.wasm.model.expression.WasmExpression;
import org.teavm.backend.wasm.model.expression.WasmGetGlobal;
@ -68,9 +73,11 @@ import org.teavm.backend.wasm.model.expression.WasmSignedType;
import org.teavm.backend.wasm.model.expression.WasmStructGet;
import org.teavm.backend.wasm.model.expression.WasmStructNewDefault;
import org.teavm.backend.wasm.model.expression.WasmStructSet;
import org.teavm.backend.wasm.model.expression.WasmTest;
import org.teavm.backend.wasm.model.expression.WasmThrow;
import org.teavm.backend.wasm.model.expression.WasmUnreachable;
import org.teavm.model.ClassHierarchy;
import org.teavm.model.ElementModifier;
import org.teavm.model.FieldReference;
import org.teavm.model.MethodReference;
import org.teavm.model.TextLocation;
@ -245,18 +252,14 @@ public class WasmGCGenerationVisitor extends BaseWasmGenerationVisitor {
}
result.acceptVisitor(typeInference);
block.setType(typeInference.getResult());
var cachedValue = exprCache.create(result, typeInference.getResult(), location, block.getBody());
var check = new WasmNullBranch(WasmNullCondition.NOT_NULL, cachedValue.expr(), block);
check.setResult(cachedValue.expr());
block.getBody().add(new WasmDrop(check));
var check = new WasmNullBranch(WasmNullCondition.NOT_NULL, result, block);
block.getBody().add(check);
var callSiteId = generateCallSiteId(location);
callSiteId.generateRegister(block.getBody(), location);
generateThrowNPE(location, block.getBody());
callSiteId.generateThrow(block.getBody(), location);
cachedValue.release();
return block;
}
@ -382,6 +385,59 @@ public class WasmGCGenerationVisitor extends BaseWasmGenerationVisitor {
return supertypeCall;
}
@Override
public void visit(InstanceOfExpr expr) {
var type = expr.getType();
if (canCastNatively(type)) {
var wasmType = context.classInfoProvider().getClassInfo(type).getStructure().getNonNullReference();
acceptWithType(expr.getExpr(), type);
var wasmValue = result;
result.acceptVisitor(typeInference);
result = new WasmTest(wasmValue, wasmType);
result.setLocation(expr.getLocation());
} else {
super.visit(expr);
}
}
@Override
public void visit(CastExpr expr) {
var type = expr.getTarget();
if (!expr.isWeak() && canCastNatively(type)) {
var wasmType = context.classInfoProvider().getClassInfo(type).getType();
var block = new WasmBlock(false);
acceptWithType(expr.getValue(), type);
var wasmValue = result;
result.acceptVisitor(typeInference);
var sourceWasmType = (WasmType.Reference) typeInference.getResult();
if (sourceWasmType == null || !validateCastTypes(sourceWasmType, wasmType, expr.getLocation())) {
return;
}
block.setType(wasmType);
block.setLocation(expr.getLocation());
block.getBody().add(new WasmCastBranch(WasmCastCondition.SUCCESS, wasmValue, sourceWasmType,
wasmType, block));
generateThrowCCE(expr.getLocation(), block.getBody());
result = block;
} else {
super.visit(expr);
}
}
private boolean canCastNatively(ValueType type) {
if (type instanceof ValueType.Array) {
return true;
}
var className = ((ValueType.Object) type).getClassName();
var cls = context.classes().get(className);
if (cls == null) {
return false;
}
return !cls.hasModifier(ElementModifier.INTERFACE);
}
@Override
protected WasmExpression generateCast(WasmExpression value, WasmType targetType) {
return new WasmCast(value, (WasmType.Reference) targetType);
@ -435,18 +491,13 @@ public class WasmGCGenerationVisitor extends BaseWasmGenerationVisitor {
}
@Override
protected void catchException(TextLocation location, List<WasmExpression> target, WasmLocal local,
String exceptionClass, WasmLocal exceptionVar) {
if (local != null) {
WasmExpression exception = new WasmGetLocal(exceptionVar);
if (exceptionClass != null && !exceptionClass.equals("java.lang.Throwable")) {
exception = new WasmCast(exception, context.classInfoProvider().getClassInfo(exceptionClass)
.getStructure().getNonNullReference());
}
var save = new WasmSetLocal(local, exception);
save.setLocation(location);
target.add(save);
}
protected void checkExceptionType(TryCatchStatement tryCatch, WasmLocal exceptionVar, List<WasmExpression> target,
WasmBlock targetBlock) {
var wasmType = context.classInfoProvider().getClassInfo(tryCatch.getExceptionType()).getType();
var wasmSourceType = context.classInfoProvider().getClassInfo("java.lang.Throwable").getType();
var br = new WasmCastBranch(WasmCastCondition.SUCCESS, new WasmGetLocal(exceptionVar),
wasmSourceType, wasmType, targetBlock);
target.add(br);
}
@Override

View File

@ -30,6 +30,7 @@ public class WasmCastBranch extends WasmExpression {
WasmType.Reference type, WasmBlock target) {
this.condition = Objects.requireNonNull(condition);
this.value = Objects.requireNonNull(value);
this.sourceType = Objects.requireNonNull(sourceType);
this.type = Objects.requireNonNull(type);
this.target = Objects.requireNonNull(target);
}

View File

@ -816,9 +816,9 @@ public class CodeParser extends BaseSectionParser {
}
private void parseCastBranch(boolean success) {
var flags = reader.data[reader.ptr++];
var depth = readLEB();
var target = blockStack.get(blockStack.size() - depth - 1);
var flags = reader.data[reader.ptr++];
var sourceType = reader.readHeapType((flags & 1) != 0);
var targetType = reader.readHeapType((flags & 2) != 0);
codeListener.castBranch(success, depth, target.token, sourceType, targetType);

View File

@ -236,7 +236,6 @@ class WasmBinaryRenderingVisitor implements WasmExpressionVisitor {
writer.writeByte(25);
break;
}
writeLabel(expression.getTarget());
var flags = 0;
if (expression.getSourceType().isNullable()) {
flags |= 1;
@ -245,6 +244,7 @@ class WasmBinaryRenderingVisitor implements WasmExpressionVisitor {
flags |= 2;
}
writer.writeByte(flags);
writeLabel(expression.getTarget());
writer.writeHeapType(expression.getSourceType(), module);
writer.writeHeapType(expression.getType(), module);
popLocation();