From 7784969bb8c25f6c6658729505b0476723a7d4ec Mon Sep 17 00:00:00 2001
From: Alexey Andreev <konsoletyper@gmail.com>
Date: Sun, 15 Sep 2024 20:35:07 +0200
Subject: [PATCH] wasm gc: fix issues with casts

---
 .../org/teavm/backend/wasm/WasmGCTarget.java  |   8 +-
 .../wasm/disasm/DisassemblyCodeListener.java  |   2 +-
 .../wasm/generate/WasmGenerationVisitor.java  |  15 ---
 .../methods/BaseWasmGenerationVisitor.java    |  17 +--
 .../gc/classes/WasmGCClassGenerator.java      |  54 ++++----
 .../generate/gc/classes/WasmGCClassInfo.java  |   2 +
 .../WasmGCNewArrayFunctionGenerator.java      |  25 ++--
 .../gc/methods/WasmGCGenerationVisitor.java   | 119 ++++++++++--------
 8 files changed, 125 insertions(+), 117 deletions(-)

diff --git a/core/src/main/java/org/teavm/backend/wasm/WasmGCTarget.java b/core/src/main/java/org/teavm/backend/wasm/WasmGCTarget.java
index d04b29e11..88b281251 100644
--- a/core/src/main/java/org/teavm/backend/wasm/WasmGCTarget.java
+++ b/core/src/main/java/org/teavm/backend/wasm/WasmGCTarget.java
@@ -121,9 +121,15 @@ public class WasmGCTarget implements TeaVMTarget, TeaVMWasmGCHost {
     }
 
     @Override
-    public void beforeOptimizations(Program program, MethodReader method) {
+    public void beforeInlining(Program program, MethodReader method) {
         if (strict) {
             nullCheckInsertion.transformProgram(program, method.getReference());
+        }
+    }
+
+    @Override
+    public void beforeOptimizations(Program program, MethodReader method) {
+        if (strict) {
             boundCheckInsertion.transformProgram(program, method.getReference());
         }
     }
diff --git a/core/src/main/java/org/teavm/backend/wasm/disasm/DisassemblyCodeListener.java b/core/src/main/java/org/teavm/backend/wasm/disasm/DisassemblyCodeListener.java
index f5cda47cf..44f17a793 100644
--- a/core/src/main/java/org/teavm/backend/wasm/disasm/DisassemblyCodeListener.java
+++ b/core/src/main/java/org/teavm/backend/wasm/disasm/DisassemblyCodeListener.java
@@ -782,7 +782,7 @@ public class DisassemblyCodeListener extends BaseDisassemblyListener implements
 
     @Override
     public void cast(WasmHollowType.Reference type) {
-        writer.address().write("ref.cast (ref ");
+        writer.address().write("ref.cast ");
         writeType(type);
         writer.eol();
     }
diff --git a/core/src/main/java/org/teavm/backend/wasm/generate/WasmGenerationVisitor.java b/core/src/main/java/org/teavm/backend/wasm/generate/WasmGenerationVisitor.java
index 7181ecac9..61cb640ef 100644
--- a/core/src/main/java/org/teavm/backend/wasm/generate/WasmGenerationVisitor.java
+++ b/core/src/main/java/org/teavm/backend/wasm/generate/WasmGenerationVisitor.java
@@ -164,21 +164,6 @@ public class WasmGenerationVisitor extends BaseWasmGenerationVisitor {
         super.visit(expr);
     }
 
-    @Override
-    protected WasmExpression generateCast(WasmExpression value, WasmType targetType) {
-        return value;
-    }
-
-    @Override
-    protected WasmType mapCastSourceType(WasmType type) {
-        return type;
-    }
-
-    @Override
-    protected boolean validateCastTypes(WasmType sourceType, WasmType targetType, TextLocation location) {
-        return true;
-    }
-
     @Override
     protected WasmType mapType(ValueType type) {
         return WasmGeneratorUtil.mapType(type);
diff --git a/core/src/main/java/org/teavm/backend/wasm/generate/common/methods/BaseWasmGenerationVisitor.java b/core/src/main/java/org/teavm/backend/wasm/generate/common/methods/BaseWasmGenerationVisitor.java
index 9a439d4e8..d73e7e073 100644
--- a/core/src/main/java/org/teavm/backend/wasm/generate/common/methods/BaseWasmGenerationVisitor.java
+++ b/core/src/main/java/org/teavm/backend/wasm/generate/common/methods/BaseWasmGenerationVisitor.java
@@ -1186,19 +1186,13 @@ public abstract class BaseWasmGenerationVisitor implements StatementVisitor, Exp
     public void visit(CastExpr expr) {
         var wasmTargetType = mapType(expr.getTarget());
         acceptWithType(expr.getValue(), expr.getTarget());
+        result.acceptVisitor(typeInference);
+        var wasmSourceType = typeInference.getResult();
         if (!expr.isWeak()) {
-            result.acceptVisitor(typeInference);
-            var wasmSourceType = typeInference.getResult();
             if (wasmSourceType == null) {
                 return;
             }
 
-            wasmSourceType = mapCastSourceType(wasmSourceType);
-
-            if (!validateCastTypes(wasmSourceType, wasmTargetType, expr.getLocation())) {
-                return;
-            }
-
             var block = new WasmBlock(false);
             block.setType(wasmSourceType);
             block.setLocation(expr.getLocation());
@@ -1223,16 +1217,9 @@ public abstract class BaseWasmGenerationVisitor implements StatementVisitor, Exp
             valueToCast.release();
             result = block;
         }
-        result = generateCast(result, wasmTargetType);
         result.setLocation(expr.getLocation());
     }
 
-    protected abstract WasmType mapCastSourceType(WasmType type);
-
-    protected abstract boolean validateCastTypes(WasmType sourceType, WasmType targetType, TextLocation location);
-
-    protected abstract WasmExpression generateCast(WasmExpression value, WasmType targetType);
-
     @Override
     public void visit(InitClassStatement statement) {
         if (needsClassInitializer(statement.getClassName())) {
diff --git a/core/src/main/java/org/teavm/backend/wasm/generate/gc/classes/WasmGCClassGenerator.java b/core/src/main/java/org/teavm/backend/wasm/generate/gc/classes/WasmGCClassGenerator.java
index 3b7e64364..f740a41f9 100644
--- a/core/src/main/java/org/teavm/backend/wasm/generate/gc/classes/WasmGCClassGenerator.java
+++ b/core/src/main/java/org/teavm/backend/wasm/generate/gc/classes/WasmGCClassGenerator.java
@@ -108,7 +108,7 @@ public class WasmGCClassGenerator implements WasmGCClassInfoProvider, WasmGCInit
     private WasmGCVirtualTableProvider virtualTables;
     private BaseWasmFunctionRepository functionProvider;
     private Map<ValueType, WasmGCClassInfo> classInfoMap = new LinkedHashMap<>();
-    private Queue<WasmGCClassInfo> classInfoQueue = new ArrayDeque<>();
+    private Queue<Runnable> queue = new ArrayDeque<>();
     private ObjectIntMap<FieldReference> fieldIndexes = new ObjectIntHashMap<>();
     private Map<FieldReference, WasmGlobal> staticFieldLocations = new HashMap<>();
     private List<Consumer<WasmFunction>> staticFieldInitializers = new ArrayList<>();
@@ -171,7 +171,7 @@ public class WasmGCClassGenerator implements WasmGCClassInfoProvider, WasmGCInit
         standardClasses = new WasmGCStandardClasses(this);
         strings = new WasmGCStringPool(standardClasses, module, functionProvider, names, functionTypes);
         supertypeGenerator = new WasmGCSupertypeFunctionGenerator(module, this, names, tagRegistry, functionTypes);
-        newArrayGenerator = new WasmGCNewArrayFunctionGenerator(module, functionTypes, this, names);
+        newArrayGenerator = new WasmGCNewArrayFunctionGenerator(module, functionTypes, this, names, queue);
         typeMapper = new WasmGCTypeMapper(classSource, this, functionTypes, module);
         var customTypeMapperFactoryContext = customTypeMapperFactoryContext();
         typeMapper.setCustomTypeMappers(customTypeMapperFactories.stream()
@@ -213,13 +213,12 @@ public class WasmGCClassGenerator implements WasmGCClassInfoProvider, WasmGCInit
     }
 
     public boolean process() {
-        if (classInfoQueue.isEmpty()) {
+        if (queue.isEmpty()) {
             return false;
         }
-        while (!classInfoQueue.isEmpty()) {
-            var classInfo = classInfoQueue.remove();
-            classInfo.initializer.accept(initializerFunctionStatements);
-            classInfo.initializer = null;
+        while (!queue.isEmpty()) {
+            var action = queue.remove();
+            action.run();
             initStructures();
         }
         return true;
@@ -245,20 +244,13 @@ public class WasmGCClassGenerator implements WasmGCClassInfoProvider, WasmGCInit
         function.getBody().addAll(initializerFunctionStatements);
         initializerFunctionStatements.clear();
         for (var classInfo : classInfoMap.values()) {
-            var req = metadataRequirements.getInfo(classInfo.getValueType());
-            if (req != null) {
-                if (req.isAssignable()) {
-                    var supertypeFunction = supertypeGenerator.getIsSupertypeFunction(classInfo.getValueType());
-                    supertypeFunction.setReferenced(true);
-                    function.getBody().add(setClassField(classInfo, classSupertypeFunctionOffset,
-                            new WasmFunctionReference(supertypeFunction)));
-                }
-                if (req.newArray()) {
-                    var newArrayFunction = getArrayConstructor(ValueType.arrayOf(classInfo.getValueType()));
-                    newArrayFunction.setReferenced(true);
-                    function.getBody().add(setClassField(classInfo, classNewArrayOffset,
-                            new WasmFunctionReference(newArrayFunction)));
-                }
+            if (classInfo.supertypeFunction != null) {
+                function.getBody().add(setClassField(classInfo, classSupertypeFunctionOffset,
+                        new WasmFunctionReference(classInfo.supertypeFunction)));
+            }
+            if (classInfo.initArrayFunction != null) {
+                function.getBody().add(setClassField(classInfo, classNewArrayOffset,
+                        new WasmFunctionReference(classInfo.initArrayFunction)));
             }
         }
         for (var consumer : staticFieldInitializers) {
@@ -271,7 +263,11 @@ public class WasmGCClassGenerator implements WasmGCClassInfoProvider, WasmGCInit
         var classInfo = classInfoMap.get(type);
         if (classInfo == null) {
             classInfo = new WasmGCClassInfo(type);
-            classInfoQueue.add(classInfo);
+            var finalClassInfo = classInfo;
+            queue.add(() -> {
+                finalClassInfo.initializer.accept(initializerFunctionStatements);
+                finalClassInfo.initializer = null;
+            });
             classInfoMap.put(type, classInfo);
             WasmGCVirtualTable virtualTable = null;
             if (!(type instanceof ValueType.Primitive)) {
@@ -284,7 +280,6 @@ public class WasmGCClassGenerator implements WasmGCClassInfoProvider, WasmGCInit
                     isInterface = true;
                     classInfo.structure = standardClasses.objectClass().structure;
                 } else {
-                    var finalClassInfo = classInfo;
                     if (type instanceof ValueType.Array) {
                         var itemType = ((ValueType.Array) type).getItemType();
                         if (!(itemType instanceof ValueType.Primitive) && !itemType.equals(OBJECT_TYPE)) {
@@ -346,6 +341,18 @@ public class WasmGCClassGenerator implements WasmGCClassInfoProvider, WasmGCInit
             } else if (type instanceof ValueType.Object) {
                 initRegularClass(classInfo, virtualTable, classStructure, ((ValueType.Object) type).getClassName());
             }
+            var req = metadataRequirements.getInfo(type);
+            if (req != null) {
+                if (req.newArray()) {
+                    classInfo.initArrayFunction = getArrayConstructor(ValueType.arrayOf(classInfo.getValueType()));
+                    classInfo.initArrayFunction.setReferenced(true);
+                }
+                if (req.isAssignable()) {
+                    var supertypeFunction = supertypeGenerator.getIsSupertypeFunction(classInfo.getValueType());
+                    supertypeFunction.setReferenced(true);
+                    classInfo.supertypeFunction = supertypeFunction;
+                }
+            }
         }
         return classInfo;
     }
@@ -490,6 +497,7 @@ public class WasmGCClassGenerator implements WasmGCClassInfoProvider, WasmGCInit
     private void initRegularClass(WasmGCClassInfo classInfo, WasmGCVirtualTable virtualTable,
             WasmStructure classStructure, String name) {
         var cls = classSource.get(name);
+
         if (classInitializerInfo.isDynamicInitializer(name)) {
             if (cls != null && cls.getMethod(CLINIT_METHOD_DESC) != null) {
                 var clinitType = functionTypes.of(null);
diff --git a/core/src/main/java/org/teavm/backend/wasm/generate/gc/classes/WasmGCClassInfo.java b/core/src/main/java/org/teavm/backend/wasm/generate/gc/classes/WasmGCClassInfo.java
index 1eb8bf7b3..f2a4b396e 100644
--- a/core/src/main/java/org/teavm/backend/wasm/generate/gc/classes/WasmGCClassInfo.java
+++ b/core/src/main/java/org/teavm/backend/wasm/generate/gc/classes/WasmGCClassInfo.java
@@ -33,6 +33,8 @@ public class WasmGCClassInfo {
     WasmGlobal initializerPointer;
     Consumer<List<WasmExpression>> initializer;
     WasmFunction newArrayFunction;
+    WasmFunction initArrayFunction;
+    WasmFunction supertypeFunction;
 
     WasmGCClassInfo(ValueType valueType) {
         this.valueType = valueType;
diff --git a/core/src/main/java/org/teavm/backend/wasm/generate/gc/classes/WasmGCNewArrayFunctionGenerator.java b/core/src/main/java/org/teavm/backend/wasm/generate/gc/classes/WasmGCNewArrayFunctionGenerator.java
index 0dee70222..c89931de6 100644
--- a/core/src/main/java/org/teavm/backend/wasm/generate/gc/classes/WasmGCNewArrayFunctionGenerator.java
+++ b/core/src/main/java/org/teavm/backend/wasm/generate/gc/classes/WasmGCNewArrayFunctionGenerator.java
@@ -16,6 +16,7 @@
 package org.teavm.backend.wasm.generate.gc.classes;
 
 import java.util.List;
+import java.util.Queue;
 import org.teavm.backend.wasm.WasmFunctionTypes;
 import org.teavm.backend.wasm.generate.TemporaryVariablePool;
 import org.teavm.backend.wasm.generate.gc.WasmGCNameProvider;
@@ -35,13 +36,16 @@ class WasmGCNewArrayFunctionGenerator {
     private WasmGCClassInfoProvider classInfoProvider;
     private WasmFunctionType newArrayFunctionType;
     private WasmGCNameProvider names;
+    private Queue<Runnable> queue;
 
     WasmGCNewArrayFunctionGenerator(WasmModule module, WasmFunctionTypes functionTypes,
-            WasmGCClassInfoProvider classInfoProvider, WasmGCNameProvider names) {
+            WasmGCClassInfoProvider classInfoProvider, WasmGCNameProvider names,
+            Queue<Runnable> queue) {
         this.module = module;
         this.functionTypes = functionTypes;
         this.classInfoProvider = classInfoProvider;
         this.names = names;
+        this.queue = queue;
     }
 
     WasmFunction generateNewArrayFunction(ValueType itemType) {
@@ -53,14 +57,17 @@ class WasmGCNewArrayFunctionGenerator {
         var function = new WasmFunction(functionType);
         function.setName(names.topLevel("Array<" + names.suggestForType(itemType) + ">@new"));
         module.functions.add(function);
-        var sizeLocal = new WasmLocal(WasmType.INT32, "length");
-        function.add(sizeLocal);
-        var tempVars = new TemporaryVariablePool(function);
-        var genUtil = new WasmGCGenerationUtil(classInfoProvider, tempVars);
-        var targetVar = new WasmLocal(classInfo.getType(), "result");
-        function.add(targetVar);
-        genUtil.allocateArray(itemType, () -> new WasmGetLocal(sizeLocal), null, targetVar, function.getBody());
-        function.getBody().add(new WasmReturn(new WasmGetLocal(targetVar)));
+
+        queue.add(() -> {
+            var sizeLocal = new WasmLocal(WasmType.INT32, "length");
+            function.add(sizeLocal);
+            var tempVars = new TemporaryVariablePool(function);
+            var genUtil = new WasmGCGenerationUtil(classInfoProvider, tempVars);
+            var targetVar = new WasmLocal(classInfo.getType(), "result");
+            function.add(targetVar);
+            genUtil.allocateArray(itemType, () -> new WasmGetLocal(sizeLocal), null, targetVar, function.getBody());
+            function.getBody().add(new WasmReturn(new WasmGetLocal(targetVar)));
+        });
         return function;
     }
 
diff --git a/core/src/main/java/org/teavm/backend/wasm/generate/gc/methods/WasmGCGenerationVisitor.java b/core/src/main/java/org/teavm/backend/wasm/generate/gc/methods/WasmGCGenerationVisitor.java
index b2515e903..f448479d7 100644
--- a/core/src/main/java/org/teavm/backend/wasm/generate/gc/methods/WasmGCGenerationVisitor.java
+++ b/core/src/main/java/org/teavm/backend/wasm/generate/gc/methods/WasmGCGenerationVisitor.java
@@ -53,6 +53,7 @@ import org.teavm.backend.wasm.model.expression.WasmArrayGet;
 import org.teavm.backend.wasm.model.expression.WasmArrayLength;
 import org.teavm.backend.wasm.model.expression.WasmArraySet;
 import org.teavm.backend.wasm.model.expression.WasmBlock;
+import org.teavm.backend.wasm.model.expression.WasmBranch;
 import org.teavm.backend.wasm.model.expression.WasmCall;
 import org.teavm.backend.wasm.model.expression.WasmCallReference;
 import org.teavm.backend.wasm.model.expression.WasmCast;
@@ -443,26 +444,76 @@ public class WasmGCGenerationVisitor extends BaseWasmGenerationVisitor {
 
     @Override
     public void visit(CastExpr expr) {
-        var type = expr.getTarget();
-        if (!expr.isWeak() && canCastNatively(type)) {
-            var wasmType = context.classInfoProvider().getClassInfo(type).getType();
-            var block = new WasmBlock(false);
-            acceptWithType(expr.getValue(), type);
-            var wasmValue = result;
-            result.acceptVisitor(typeInference);
-            var sourceWasmType = (WasmType.Reference) typeInference.getResult();
-            if (sourceWasmType == null || !validateCastTypes(sourceWasmType, wasmType, expr.getLocation())) {
+        var needsCast = true;
+        acceptWithType(expr.getValue(), expr.getTarget());
+        result.acceptVisitor(typeInference);
+        var sourceType = (WasmType.Reference) typeInference.getResult();
+        if (sourceType == null) {
+            return;
+        }
+
+        var targetType = (WasmType.Reference) context.typeMapper().mapType(expr.getTarget());
+        WasmStructure targetStruct = null;
+        if (targetType instanceof WasmType.CompositeReference) {
+            var targetComposite = ((WasmType.CompositeReference) targetType).composite;
+            if (targetComposite instanceof WasmStructure) {
+                targetStruct = (WasmStructure) targetComposite;
+            }
+        }
+
+        var canInsertCast = true;
+        if (targetStruct != null && sourceType instanceof WasmType.CompositeReference) {
+            var sourceComposite = (WasmType.CompositeReference) sourceType;
+            if (!sourceType.isNullable()) {
+                sourceType = sourceComposite.composite.getReference();
+            }
+            var sourceStruct = (WasmStructure) sourceComposite.composite;
+            if (targetStruct.isSupertypeOf(sourceStruct)) {
+                canInsertCast = false;
+            } else if (!sourceStruct.isSupertypeOf(targetStruct)) {
+                var block = new WasmBlock(false);
+                block.setLocation(expr.getLocation());
+                block.getBody().add(result);
+                block.getBody().add(new WasmUnreachable());
+                result = block;
                 return;
             }
+        }
 
-            block.setType(wasmType);
+        if (!expr.isWeak()) {
+            result.acceptVisitor(typeInference);
+
+            var block = new WasmBlock(false);
             block.setLocation(expr.getLocation());
-            block.getBody().add(new WasmCastBranch(WasmCastCondition.SUCCESS, wasmValue, sourceWasmType,
-                    wasmType, block));
+            block.setType(targetType);
+            if (canCastNatively(expr.getTarget())) {
+                if (!canInsertCast) {
+                    return;
+                }
+                block.getBody().add(new WasmCastBranch(WasmCastCondition.SUCCESS, result, sourceType,
+                        targetType, block));
+                result = block;
+            } else {
+                var nonNullValue = new WasmNullBranch(WasmNullCondition.NULL, result, block);
+                nonNullValue.setResult(new WasmNullConstant(sourceType));
+                var valueToCast = exprCache.create(nonNullValue, sourceType, expr.getLocation(), block.getBody());
+
+                var supertypeCall = generateInstanceOf(valueToCast.expr(), expr.getTarget());
+                var breakIfPassed = new WasmBranch(supertypeCall, block);
+                breakIfPassed.setResult(valueToCast.expr());
+                block.getBody().add(new WasmDrop(breakIfPassed));
+
+                result = block;
+                if (canInsertCast) {
+                    var cast = new WasmCast(result, targetType);
+                    cast.setLocation(expr.getLocation());
+                    result = cast;
+                }
+            }
             generateThrowCCE(expr.getLocation(), block.getBody());
-            result = block;
-        } else {
-            super.visit(expr);
+        } else if (canInsertCast) {
+            result = new WasmCast(result, targetType);
+            result.setLocation(expr.getLocation());
         }
     }
 
@@ -478,44 +529,6 @@ public class WasmGCGenerationVisitor extends BaseWasmGenerationVisitor {
         return !cls.hasModifier(ElementModifier.INTERFACE);
     }
 
-    @Override
-    protected WasmExpression generateCast(WasmExpression value, WasmType targetType) {
-        return new WasmCast(value, (WasmType.Reference) targetType);
-    }
-
-    @Override
-    protected WasmType mapCastSourceType(WasmType type) {
-        if (!(type instanceof WasmType.CompositeReference)) {
-            return type;
-        }
-        var refType = (WasmType.CompositeReference) type;
-        return refType.isNullable() ? refType : refType.composite.getReference();
-    }
-
-    @Override
-    protected boolean validateCastTypes(WasmType sourceType, WasmType targetType, TextLocation location) {
-        if (!(sourceType instanceof WasmType.CompositeReference)
-                || !(targetType instanceof WasmType.CompositeReference)) {
-            return false;
-        }
-        var sourceRefType = (WasmType.CompositeReference) sourceType;
-        var targetRefType = (WasmType.CompositeReference) targetType;
-        if (sourceRefType.composite instanceof WasmStructure
-                && targetRefType.composite instanceof WasmStructure) {
-            var sourceStruct = (WasmStructure) sourceRefType.composite;
-            var targetStruct = (WasmStructure) targetRefType.composite;
-            if (targetStruct.isSupertypeOf(sourceStruct)) {
-                return false;
-            }
-            if (!sourceStruct.isSupertypeOf(targetStruct)) {
-                result = new WasmUnreachable();
-                result.setLocation(location);
-                return false;
-            }
-        }
-        return true;
-    }
-
     @Override
     protected boolean needsClassInitializer(String className) {
         return context.classInfoProvider().getClassInfo(className).getInitializerPointer() != null;