From 29dec0962b1cbeaf9c5a9fa8fd25d17d8a6571a6 Mon Sep 17 00:00:00 2001 From: Alexey Andreev Date: Wed, 4 Sep 2024 20:58:29 +0200 Subject: [PATCH] wasm gc: fix issues with type inference --- .../gc/WasmGCVariableCategoryProvider.java | 9 +---- .../gc/methods/WasmGCMethodGenerator.java | 6 ++- .../model/analysis/BaseTypeInference.java | 5 +++ .../teavm/model/util/RegisterAllocator.java | 38 +++++++++++++------ 4 files changed, 38 insertions(+), 20 deletions(-) diff --git a/core/src/main/java/org/teavm/backend/wasm/gc/WasmGCVariableCategoryProvider.java b/core/src/main/java/org/teavm/backend/wasm/gc/WasmGCVariableCategoryProvider.java index b77560eae..9f9c16cc4 100644 --- a/core/src/main/java/org/teavm/backend/wasm/gc/WasmGCVariableCategoryProvider.java +++ b/core/src/main/java/org/teavm/backend/wasm/gc/WasmGCVariableCategoryProvider.java @@ -22,20 +22,15 @@ import org.teavm.model.util.VariableCategoryProvider; public class WasmGCVariableCategoryProvider implements VariableCategoryProvider { private ClassHierarchy hierarchy; - private PreciseTypeInference inference; public WasmGCVariableCategoryProvider(ClassHierarchy hierarchy) { this.hierarchy = hierarchy; } - public PreciseTypeInference getTypeInference() { - return inference; - } - @Override public Object[] getCategories(Program program, MethodReference method) { - inference = new PreciseTypeInference(program, method, hierarchy); - inference.setPhisSkipped(true); + var inference = new PreciseTypeInference(program, method, hierarchy); + inference.setPhisSkipped(false); inference.setBackPropagation(true); var result = new Object[program.variableCount()]; for (int i = 0; i < program.variableCount(); ++i) { diff --git a/core/src/main/java/org/teavm/backend/wasm/generate/gc/methods/WasmGCMethodGenerator.java b/core/src/main/java/org/teavm/backend/wasm/generate/gc/methods/WasmGCMethodGenerator.java index 9816f89f9..d127a58bc 100644 --- a/core/src/main/java/org/teavm/backend/wasm/generate/gc/methods/WasmGCMethodGenerator.java +++ b/core/src/main/java/org/teavm/backend/wasm/generate/gc/methods/WasmGCMethodGenerator.java @@ -25,6 +25,7 @@ import java.util.Set; import org.teavm.ast.decompilation.Decompiler; import org.teavm.backend.wasm.BaseWasmFunctionRepository; import org.teavm.backend.wasm.WasmFunctionTypes; +import org.teavm.backend.wasm.gc.PreciseTypeInference; import org.teavm.backend.wasm.gc.WasmGCVariableCategoryProvider; import org.teavm.backend.wasm.gc.vtable.WasmGCVirtualTableProvider; import org.teavm.backend.wasm.generate.gc.WasmGCNameProvider; @@ -220,7 +221,10 @@ public class WasmGCMethodGenerator implements BaseWasmFunctionRepository { allocator.allocateRegisters(method.getReference(), method.getProgram(), friendlyToDebugger); var ast = decompiler.decompileRegular(method); var firstVar = method.hasModifier(ElementModifier.STATIC) ? 1 : 0; - var typeInference = categoryProvider.getTypeInference(); + var typeInference = new PreciseTypeInference(method.getProgram(), method.getReference(), hierarchy); + typeInference.setPhisSkipped(true); + typeInference.setBackPropagation(true); + typeInference.ensure(); var registerCount = 0; for (var i = 0; i < method.getProgram().variableCount(); ++i) { diff --git a/core/src/main/java/org/teavm/model/analysis/BaseTypeInference.java b/core/src/main/java/org/teavm/model/analysis/BaseTypeInference.java index 538219e40..4aed8c685 100644 --- a/core/src/main/java/org/teavm/model/analysis/BaseTypeInference.java +++ b/core/src/main/java/org/teavm/model/analysis/BaseTypeInference.java @@ -547,6 +547,11 @@ public abstract class BaseTypeInference { push(insn.getValue(), insn.getFieldType()); } + @Override + public void visit(CastInstruction insn) { + push(insn.getValue(), insn.getTargetType()); + } + private void push(Variable variable, ValueType type) { if (nullTypes[variable.getIndex()]) { stack.push(variable.getIndex()); diff --git a/core/src/main/java/org/teavm/model/util/RegisterAllocator.java b/core/src/main/java/org/teavm/model/util/RegisterAllocator.java index 0fe85ac5e..c761b217b 100644 --- a/core/src/main/java/org/teavm/model/util/RegisterAllocator.java +++ b/core/src/main/java/org/teavm/model/util/RegisterAllocator.java @@ -22,6 +22,7 @@ import java.util.HashMap; import java.util.Iterator; import java.util.List; import java.util.Map; +import java.util.Objects; import org.teavm.common.DisjointSet; import org.teavm.common.MutableGraphEdge; import org.teavm.common.MutableGraphNode; @@ -45,7 +46,10 @@ public class RegisterAllocator { } public void allocateRegisters(MethodReference method, Program program, boolean debuggerFriendly) { - insertPhiArgumentsCopies(program); + var categories = variableCategoryProvider.getCategories(program, method); + var categoryList = new ArrayList<>(Arrays.asList(categories)); + insertPhiArgumentsCopies(program, categoryList); + categories = categoryList.toArray(); InterferenceGraphBuilder interferenceBuilder = new InterferenceGraphBuilder(); LivenessAnalyzer liveness = new LivenessAnalyzer(); liveness.analyze(program, method.getDescriptor()); @@ -53,7 +57,7 @@ public class RegisterAllocator { program, method.parameterCount(), liveness); DisjointSet congruenceClasses = buildPhiCongruenceClasses(program); joinClassNodes(interferenceGraph, congruenceClasses); - removeRedundantCopies(program, interferenceGraph, congruenceClasses); + removeRedundantCopies(program, interferenceGraph, congruenceClasses, categories); int[] classArray = congruenceClasses.pack(program.variableCount()); renameVariables(program, classArray); int[] colors = new int[program.variableCount()]; @@ -68,8 +72,13 @@ public class RegisterAllocator { for (int cls : classArray) { maxClass = Math.max(maxClass, cls + 1); } - var categories = variableCategoryProvider.getCategories(program, method); String[] names = getVariableNames(program, debuggerFriendly); + var newCategories = new Object[categories.length]; + for (int i = 0; i < categories.length; ++i) { + var cls = classArray[i]; + newCategories[cls] = categories[i]; + } + categories = newCategories; colorer.colorize(MutableGraphNode.toGraph(interferenceGraph), colors, categories, names); int maxColor = 0; @@ -129,7 +138,7 @@ public class RegisterAllocator { } } - private void insertPhiArgumentsCopies(Program program) { + private void insertPhiArgumentsCopies(Program program, List categories) { List> catchIncomingsByVariable = new ArrayList<>( Collections.nCopies(program.variableCount(), null)); @@ -151,14 +160,14 @@ public class RegisterAllocator { } catchIncomings.add(incoming); } else { - insertCopy(incoming, blockMap); + insertCopy(incoming, blockMap, categories); incomingsToRepeat.add(incoming); } } } for (Incoming incoming : incomingsToRepeat) { - insertCopy(incoming, blockMap); + insertCopy(incoming, blockMap, categories); } } @@ -167,7 +176,7 @@ public class RegisterAllocator { for (BasicBlock block : program.getBasicBlocks()) { for (Phi phi : block.getPhis()) { addExceptionHandlingCopies(catchIncomingsByVariable, phi.getReceiver(), block, - program, block.getFirstInstruction().getLocation(), nextInstructions); + program, block.getFirstInstruction().getLocation(), nextInstructions, categories); } if (!nextInstructions.isEmpty()) { @@ -180,7 +189,7 @@ public class RegisterAllocator { Variable[] definedVariables = definitionExtractor.getDefinedVariables(); for (Variable definedVariable : definedVariables) { addExceptionHandlingCopies(catchIncomingsByVariable, definedVariable, block, - program, instruction.getLocation(), nextInstructions); + program, instruction.getLocation(), nextInstructions, categories); } if (!nextInstructions.isEmpty()) { @@ -198,6 +207,7 @@ public class RegisterAllocator { BasicBlock block = incoming.getSource(); Variable copy = program.createVariable(); + categories.add(categories.get(incoming.getPhi().getReceiver().getIndex())); copy.setLabel(incoming.getPhi().getReceiver().getLabel()); copy.setDebugName(incoming.getPhi().getReceiver().getDebugName()); @@ -213,7 +223,8 @@ public class RegisterAllocator { } private void addExceptionHandlingCopies(List> catchIncomingsByVariable, Variable definedVariable, - BasicBlock block, Program program, TextLocation location, List nextInstructions) { + BasicBlock block, Program program, TextLocation location, List nextInstructions, + List categories) { if (definedVariable.getIndex() >= catchIncomingsByVariable.size()) { return; } @@ -228,6 +239,7 @@ public class RegisterAllocator { Variable copy = program.createVariable(); copy.setLabel(incoming.getPhi().getReceiver().getLabel()); copy.setDebugName(incoming.getPhi().getReceiver().getDebugName()); + categories.add(categories.get(incoming.getPhi().getReceiver().getIndex())); AssignInstruction copyInstruction = new AssignInstruction(); copyInstruction.setReceiver(copy); @@ -242,11 +254,12 @@ public class RegisterAllocator { } } - private void insertCopy(Incoming incoming, Map blockMap) { + private void insertCopy(Incoming incoming, Map blockMap, List categories) { Phi phi = incoming.getPhi(); Program program = phi.getBasicBlock().getProgram(); AssignInstruction copyInstruction = new AssignInstruction(); Variable firstCopy = program.createVariable(); + categories.add(categories.get(phi.getReceiver().getIndex())); firstCopy.setLabel(phi.getReceiver().getLabel()); firstCopy.setDebugName(phi.getReceiver().getDebugName()); copyInstruction.setReceiver(firstCopy); @@ -274,7 +287,7 @@ public class RegisterAllocator { } private void removeRedundantCopies(Program program, List interferenceGraph, - DisjointSet congruenceClasses) { + DisjointSet congruenceClasses, Object[] categories) { for (int i = 0; i < program.basicBlockCount(); ++i) { BasicBlock block = program.basicBlockAt(i); Instruction nextInsn; @@ -297,7 +310,8 @@ public class RegisterAllocator { break; } } - if (!interfere) { + if (!interfere && Objects.equals(categories[assignment.getReceiver().getIndex()], + categories[assignment.getAssignee().getIndex()])) { int newClass = congruenceClasses.union(copyClass, origClass); insn.delete(); if (newClass == interferenceGraph.size()) {