From a503333c1bd592be0c7d4e8b14a1d9bc803fffac Mon Sep 17 00:00:00 2001 From: Alexey Andreev Date: Sat, 14 Sep 2024 15:26:22 +0200 Subject: [PATCH] wasm gc: optimize null checks, casts and try/catch using branching instructions --- .../wasm/disasm/DisassemblyCodeListener.java | 2 +- .../wasm/generate/WasmGenerationVisitor.java | 10 --- .../methods/BaseWasmGenerationVisitor.java | 28 ++++-- .../gc/methods/WasmGCGenerationVisitor.java | 87 +++++++++++++++---- .../wasm/model/expression/WasmCastBranch.java | 1 + .../teavm/backend/wasm/parser/CodeParser.java | 2 +- .../render/WasmBinaryRenderingVisitor.java | 2 +- 7 files changed, 93 insertions(+), 39 deletions(-) diff --git a/core/src/main/java/org/teavm/backend/wasm/disasm/DisassemblyCodeListener.java b/core/src/main/java/org/teavm/backend/wasm/disasm/DisassemblyCodeListener.java index ecfff78c4..f5cda47cf 100644 --- a/core/src/main/java/org/teavm/backend/wasm/disasm/DisassemblyCodeListener.java +++ b/core/src/main/java/org/teavm/backend/wasm/disasm/DisassemblyCodeListener.java @@ -143,7 +143,7 @@ public class DisassemblyCodeListener extends BaseDisassemblyListener implements writer.write(" "); writeType(sourceType); writer.write(" "); - writeType(sourceType); + writeType(targetType); writer.eol(); } diff --git a/core/src/main/java/org/teavm/backend/wasm/generate/WasmGenerationVisitor.java b/core/src/main/java/org/teavm/backend/wasm/generate/WasmGenerationVisitor.java index ddefe322e..7181ecac9 100644 --- a/core/src/main/java/org/teavm/backend/wasm/generate/WasmGenerationVisitor.java +++ b/core/src/main/java/org/teavm/backend/wasm/generate/WasmGenerationVisitor.java @@ -179,16 +179,6 @@ public class WasmGenerationVisitor extends BaseWasmGenerationVisitor { return true; } - @Override - protected void catchException(TextLocation location, List target, WasmLocal local, - String exceptionClass, WasmLocal exceptionVar) { - if (local != null) { - var save = new WasmSetLocal(local, new WasmGetLocal(exceptionVar)); - save.setLocation(location); - target.add(save); - } - } - @Override protected WasmType mapType(ValueType type) { return WasmGeneratorUtil.mapType(type); diff --git a/core/src/main/java/org/teavm/backend/wasm/generate/common/methods/BaseWasmGenerationVisitor.java b/core/src/main/java/org/teavm/backend/wasm/generate/common/methods/BaseWasmGenerationVisitor.java index 8042200b8..9a439d4e8 100644 --- a/core/src/main/java/org/teavm/backend/wasm/generate/common/methods/BaseWasmGenerationVisitor.java +++ b/core/src/main/java/org/teavm/backend/wasm/generate/common/methods/BaseWasmGenerationVisitor.java @@ -1300,12 +1300,16 @@ public abstract class BaseWasmGenerationVisitor implements StatementVisitor, Exp for (int i = tryCatchStatements.size() - 1; i >= 0; --i) { var tryCatch = tryCatchStatements.get(i); var catchBlock = catchBlocks.get(i); + var blockType = mapType(tryCatch.getExceptionType() != null + ? ValueType.object(tryCatch.getExceptionType()) + : ValueType.object("java.lang.Throwable")); + currentBlock.setType(blockType); if (tryCatch.getExceptionType() != null && !tryCatch.getExceptionType().equals(Throwable.class.getName())) { - var exceptionType = ValueType.object(tryCatch.getExceptionType()); - var isMatched = generateInstanceOf(new WasmGetLocal(exceptionVar), exceptionType); - innerCatchBlock.getBody().add(new WasmBranch(isMatched, currentBlock)); + checkExceptionType(tryCatch, exceptionVar, innerCatchBlock.getBody(), currentBlock); } else { - innerCatchBlock.getBody().add(new WasmBreak(currentBlock)); + var br = new WasmBreak(currentBlock); + br.setResult(new WasmGetLocal(exceptionVar)); + innerCatchBlock.getBody().add(br); catchesAll = true; } currentBlock = catchBlock; @@ -1320,12 +1324,16 @@ public abstract class BaseWasmGenerationVisitor implements StatementVisitor, Exp for (int i = tryCatchStatements.size() - 1; i >= 0; --i) { var tryCatch = tryCatchStatements.get(i); var catchBlock = catchBlocks.get(i); - catchBlock.getBody().add(currentBlock); var catchLocal = tryCatch.getExceptionVariable() != null ? localVar(tryCatch.getExceptionVariable()) : null; - catchException(null, catchBlock.getBody(), catchLocal, tryCatch.getExceptionType(), exceptionVar); + if (catchLocal != null) { + var save = new WasmSetLocal(localVar(tryCatch.getExceptionVariable()), currentBlock); + catchBlock.getBody().add(save); + } else { + catchBlock.getBody().add(new WasmDrop(currentBlock)); + } visitMany(tryCatch.getHandler(), catchBlock.getBody()); if (!catchBlock.isTerminating() && catchBlock != outerCatchBlock) { catchBlock.getBody().add(new WasmBreak(outerCatchBlock)); @@ -1337,8 +1345,12 @@ public abstract class BaseWasmGenerationVisitor implements StatementVisitor, Exp tempVars.release(exceptionVar); } - protected abstract void catchException(TextLocation location, List target, WasmLocal local, - String exceptionClass, WasmLocal exceptionVar); + protected void checkExceptionType(TryCatchStatement tryCatch, WasmLocal exceptionVar, List target, + WasmBlock targetBlock) { + var exceptionType = ValueType.object(tryCatch.getExceptionType()); + var isMatched = generateInstanceOf(new WasmGetLocal(exceptionVar), exceptionType); + target.add(new WasmBranch(isMatched, targetBlock)); + } private void visitMany(List statements, List target) { var oldTarget = resultConsumer; diff --git a/core/src/main/java/org/teavm/backend/wasm/generate/gc/methods/WasmGCGenerationVisitor.java b/core/src/main/java/org/teavm/backend/wasm/generate/gc/methods/WasmGCGenerationVisitor.java index 6299abb2c..76116fca0 100644 --- a/core/src/main/java/org/teavm/backend/wasm/generate/gc/methods/WasmGCGenerationVisitor.java +++ b/core/src/main/java/org/teavm/backend/wasm/generate/gc/methods/WasmGCGenerationVisitor.java @@ -19,12 +19,15 @@ import java.util.List; import java.util.function.Supplier; import org.teavm.ast.ArrayType; import org.teavm.ast.BinaryExpr; +import org.teavm.ast.CastExpr; import org.teavm.ast.ConditionalExpr; import org.teavm.ast.Expr; +import org.teavm.ast.InstanceOfExpr; import org.teavm.ast.InvocationExpr; import org.teavm.ast.InvocationType; import org.teavm.ast.QualificationExpr; import org.teavm.ast.SubscriptExpr; +import org.teavm.ast.TryCatchStatement; import org.teavm.backend.wasm.BaseWasmFunctionRepository; import org.teavm.backend.wasm.WasmFunctionTypes; import org.teavm.backend.wasm.gc.PreciseTypeInference; @@ -50,6 +53,8 @@ import org.teavm.backend.wasm.model.expression.WasmBlock; import org.teavm.backend.wasm.model.expression.WasmCall; import org.teavm.backend.wasm.model.expression.WasmCallReference; import org.teavm.backend.wasm.model.expression.WasmCast; +import org.teavm.backend.wasm.model.expression.WasmCastBranch; +import org.teavm.backend.wasm.model.expression.WasmCastCondition; import org.teavm.backend.wasm.model.expression.WasmDrop; import org.teavm.backend.wasm.model.expression.WasmExpression; import org.teavm.backend.wasm.model.expression.WasmGetGlobal; @@ -68,9 +73,11 @@ import org.teavm.backend.wasm.model.expression.WasmSignedType; import org.teavm.backend.wasm.model.expression.WasmStructGet; import org.teavm.backend.wasm.model.expression.WasmStructNewDefault; import org.teavm.backend.wasm.model.expression.WasmStructSet; +import org.teavm.backend.wasm.model.expression.WasmTest; import org.teavm.backend.wasm.model.expression.WasmThrow; import org.teavm.backend.wasm.model.expression.WasmUnreachable; import org.teavm.model.ClassHierarchy; +import org.teavm.model.ElementModifier; import org.teavm.model.FieldReference; import org.teavm.model.MethodReference; import org.teavm.model.TextLocation; @@ -245,18 +252,14 @@ public class WasmGCGenerationVisitor extends BaseWasmGenerationVisitor { } result.acceptVisitor(typeInference); block.setType(typeInference.getResult()); - var cachedValue = exprCache.create(result, typeInference.getResult(), location, block.getBody()); - - var check = new WasmNullBranch(WasmNullCondition.NOT_NULL, cachedValue.expr(), block); - check.setResult(cachedValue.expr()); - block.getBody().add(new WasmDrop(check)); + var check = new WasmNullBranch(WasmNullCondition.NOT_NULL, result, block); + block.getBody().add(check); var callSiteId = generateCallSiteId(location); callSiteId.generateRegister(block.getBody(), location); generateThrowNPE(location, block.getBody()); callSiteId.generateThrow(block.getBody(), location); - cachedValue.release(); return block; } @@ -382,6 +385,59 @@ public class WasmGCGenerationVisitor extends BaseWasmGenerationVisitor { return supertypeCall; } + @Override + public void visit(InstanceOfExpr expr) { + var type = expr.getType(); + if (canCastNatively(type)) { + var wasmType = context.classInfoProvider().getClassInfo(type).getStructure().getNonNullReference(); + acceptWithType(expr.getExpr(), type); + var wasmValue = result; + result.acceptVisitor(typeInference); + + result = new WasmTest(wasmValue, wasmType); + result.setLocation(expr.getLocation()); + } else { + super.visit(expr); + } + } + + @Override + public void visit(CastExpr expr) { + var type = expr.getTarget(); + if (!expr.isWeak() && canCastNatively(type)) { + var wasmType = context.classInfoProvider().getClassInfo(type).getType(); + var block = new WasmBlock(false); + acceptWithType(expr.getValue(), type); + var wasmValue = result; + result.acceptVisitor(typeInference); + var sourceWasmType = (WasmType.Reference) typeInference.getResult(); + if (sourceWasmType == null || !validateCastTypes(sourceWasmType, wasmType, expr.getLocation())) { + return; + } + + block.setType(wasmType); + block.setLocation(expr.getLocation()); + block.getBody().add(new WasmCastBranch(WasmCastCondition.SUCCESS, wasmValue, sourceWasmType, + wasmType, block)); + generateThrowCCE(expr.getLocation(), block.getBody()); + result = block; + } else { + super.visit(expr); + } + } + + private boolean canCastNatively(ValueType type) { + if (type instanceof ValueType.Array) { + return true; + } + var className = ((ValueType.Object) type).getClassName(); + var cls = context.classes().get(className); + if (cls == null) { + return false; + } + return !cls.hasModifier(ElementModifier.INTERFACE); + } + @Override protected WasmExpression generateCast(WasmExpression value, WasmType targetType) { return new WasmCast(value, (WasmType.Reference) targetType); @@ -435,18 +491,13 @@ public class WasmGCGenerationVisitor extends BaseWasmGenerationVisitor { } @Override - protected void catchException(TextLocation location, List target, WasmLocal local, - String exceptionClass, WasmLocal exceptionVar) { - if (local != null) { - WasmExpression exception = new WasmGetLocal(exceptionVar); - if (exceptionClass != null && !exceptionClass.equals("java.lang.Throwable")) { - exception = new WasmCast(exception, context.classInfoProvider().getClassInfo(exceptionClass) - .getStructure().getNonNullReference()); - } - var save = new WasmSetLocal(local, exception); - save.setLocation(location); - target.add(save); - } + protected void checkExceptionType(TryCatchStatement tryCatch, WasmLocal exceptionVar, List target, + WasmBlock targetBlock) { + var wasmType = context.classInfoProvider().getClassInfo(tryCatch.getExceptionType()).getType(); + var wasmSourceType = context.classInfoProvider().getClassInfo("java.lang.Throwable").getType(); + var br = new WasmCastBranch(WasmCastCondition.SUCCESS, new WasmGetLocal(exceptionVar), + wasmSourceType, wasmType, targetBlock); + target.add(br); } @Override diff --git a/core/src/main/java/org/teavm/backend/wasm/model/expression/WasmCastBranch.java b/core/src/main/java/org/teavm/backend/wasm/model/expression/WasmCastBranch.java index 91d40c108..5bf537cb6 100644 --- a/core/src/main/java/org/teavm/backend/wasm/model/expression/WasmCastBranch.java +++ b/core/src/main/java/org/teavm/backend/wasm/model/expression/WasmCastBranch.java @@ -30,6 +30,7 @@ public class WasmCastBranch extends WasmExpression { WasmType.Reference type, WasmBlock target) { this.condition = Objects.requireNonNull(condition); this.value = Objects.requireNonNull(value); + this.sourceType = Objects.requireNonNull(sourceType); this.type = Objects.requireNonNull(type); this.target = Objects.requireNonNull(target); } diff --git a/core/src/main/java/org/teavm/backend/wasm/parser/CodeParser.java b/core/src/main/java/org/teavm/backend/wasm/parser/CodeParser.java index 70c62deac..b97a76029 100644 --- a/core/src/main/java/org/teavm/backend/wasm/parser/CodeParser.java +++ b/core/src/main/java/org/teavm/backend/wasm/parser/CodeParser.java @@ -816,9 +816,9 @@ public class CodeParser extends BaseSectionParser { } private void parseCastBranch(boolean success) { + var flags = reader.data[reader.ptr++]; var depth = readLEB(); var target = blockStack.get(blockStack.size() - depth - 1); - var flags = reader.data[reader.ptr++]; var sourceType = reader.readHeapType((flags & 1) != 0); var targetType = reader.readHeapType((flags & 2) != 0); codeListener.castBranch(success, depth, target.token, sourceType, targetType); diff --git a/core/src/main/java/org/teavm/backend/wasm/render/WasmBinaryRenderingVisitor.java b/core/src/main/java/org/teavm/backend/wasm/render/WasmBinaryRenderingVisitor.java index 2647b68d1..5ce43a9a7 100644 --- a/core/src/main/java/org/teavm/backend/wasm/render/WasmBinaryRenderingVisitor.java +++ b/core/src/main/java/org/teavm/backend/wasm/render/WasmBinaryRenderingVisitor.java @@ -236,7 +236,6 @@ class WasmBinaryRenderingVisitor implements WasmExpressionVisitor { writer.writeByte(25); break; } - writeLabel(expression.getTarget()); var flags = 0; if (expression.getSourceType().isNullable()) { flags |= 1; @@ -245,6 +244,7 @@ class WasmBinaryRenderingVisitor implements WasmExpressionVisitor { flags |= 2; } writer.writeByte(flags); + writeLabel(expression.getTarget()); writer.writeHeapType(expression.getSourceType(), module); writer.writeHeapType(expression.getType(), module); popLocation();