wasm gc: trying to fix virtual calls

This commit is contained in:
Alexey Andreev 2024-08-19 15:02:11 +02:00
parent 59259c314d
commit 29f29cea1d
7 changed files with 70 additions and 50 deletions

View File

@ -100,7 +100,7 @@ public class WasmGCClassGenerator implements WasmGCClassInfoProvider, WasmGCInit
private int classTagOffset; private int classTagOffset;
private int classFlagsOffset; private int classFlagsOffset;
private int classNameOffset = -1; private int classNameOffset;
private int classParentOffset; private int classParentOffset;
private int classArrayOffset; private int classArrayOffset;
private int classArrayItemOffset; private int classArrayItemOffset;
@ -162,7 +162,10 @@ public class WasmGCClassGenerator implements WasmGCClassInfoProvider, WasmGCInit
if (classInfo.virtualTableStructure != null && classInfo.getValueType() instanceof ValueType.Object if (classInfo.virtualTableStructure != null && classInfo.getValueType() instanceof ValueType.Object
&& classInfo.hasOwnVirtualTable) { && classInfo.hasOwnVirtualTable) {
var className = ((ValueType.Object) classInfo.getValueType()).getClassName(); var className = ((ValueType.Object) classInfo.getValueType()).getClassName();
classInfo.virtualTableStructure.setSupertype(findVirtualTableSupertype(className)); var candidate = findVirtualTableSupertype(className);
if (candidate != null) {
classInfo.virtualTableStructure.setSupertype(candidate);
}
} }
} }
} }
@ -235,20 +238,28 @@ public class WasmGCClassGenerator implements WasmGCClassInfoProvider, WasmGCInit
var name = type instanceof ValueType.Object var name = type instanceof ValueType.Object
? ((ValueType.Object) type).getClassName() ? ((ValueType.Object) type).getClassName()
: null; : null;
classInfo.structure = new WasmStructure(name != null ? names.forClass(name) : null); var isInterface = false;
var classReader = name != null ? classSource.get(name) : null;
if (classReader != null && classReader.hasModifier(ElementModifier.INTERFACE)) {
isInterface = true;
classInfo.structure = standardClasses.objectClass().structure;
} else {
classInfo.structure = new WasmStructure(name != null ? names.forClass(name) : null);
module.types.add(classInfo.structure);
}
if (name != null) { if (name != null) {
var classReader = classSource.get(name); if (!isInterface) {
if (classReader == null || !classReader.hasModifier(ElementModifier.INTERFACE)) {
virtualTable = virtualTables.lookup(name); virtualTable = virtualTables.lookup(name);
} }
if (classReader != null && classReader.getParent() != null) { if (classReader != null && classReader.getParent() != null && !isInterface) {
classInfo.structure.setSupertype(getClassInfo(classReader.getParent()).structure); classInfo.structure.setSupertype(getClassInfo(classReader.getParent()).structure);
} }
} else { } else {
classInfo.structure.setSupertype(standardClasses.objectClass().structure); classInfo.structure.setSupertype(standardClasses.objectClass().structure);
} }
module.types.add(classInfo.structure); if (!isInterface) {
fillFields(classInfo, type); fillFields(classInfo, type);
}
} }
var pointerName = names.forClassInstance(type); var pointerName = names.forClassInstance(type);
classInfo.hasOwnVirtualTable = virtualTable != null && virtualTable.hasValidEntries(); classInfo.hasOwnVirtualTable = virtualTable != null && virtualTable.hasValidEntries();
@ -405,10 +416,7 @@ public class WasmGCClassGenerator implements WasmGCClassInfoProvider, WasmGCInit
var function = functionProvider.forInstanceMethod(entry.getImplementor()); var function = functionProvider.forInstanceMethod(entry.getImplementor());
if (!virtualTable.getClassName().equals(entry.getImplementor().getClassName()) if (!virtualTable.getClassName().equals(entry.getImplementor().getClassName())
|| expectedFunctionType != function.getType()) { || expectedFunctionType != function.getType()) {
var functionType = typeMapper.getFunctionType(virtualTable.getClassName(), method, true); var wrapperFunction = new WasmFunction(expectedFunctionType);
functionType.getSupertypes().add(expectedFunctionType);
expectedFunctionType.setFinal(false);
var wrapperFunction = new WasmFunction(functionType);
module.functions.add(wrapperFunction); module.functions.add(wrapperFunction);
var call = new WasmCall(function); var call = new WasmCall(function);
var instanceParam = new WasmLocal(getClassInfo(virtualTable.getClassName()).getType()); var instanceParam = new WasmLocal(getClassInfo(virtualTable.getClassName()).getType());
@ -451,7 +459,8 @@ public class WasmGCClassGenerator implements WasmGCClassInfoProvider, WasmGCInit
if (methodDesc == null) { if (methodDesc == null) {
structure.getFields().add(WasmType.Reference.FUNC.asStorage()); structure.getFields().add(WasmType.Reference.FUNC.asStorage());
} else { } else {
var functionType = typeMapper.getFunctionType(virtualTable.getClassName(), methodDesc, false); var originalVirtualTable = virtualTable.findMethodContainer(methodDesc);
var functionType = typeMapper.getFunctionType(originalVirtualTable.getClassName(), methodDesc, false);
structure.getFields().add(functionType.getReference().asStorage()); structure.getFields().add(functionType.getReference().asStorage());
} }
} }
@ -570,8 +579,9 @@ public class WasmGCClassGenerator implements WasmGCClassInfoProvider, WasmGCInit
fields.add(supertypeGenerator.getFunctionType().getReference().asStorage()); fields.add(supertypeGenerator.getFunctionType().getReference().asStorage());
classNewArrayOffset = fields.size(); classNewArrayOffset = fields.size();
fields.add(newArrayGenerator.getNewArrayFunctionType().getReference().asStorage()); fields.add(newArrayGenerator.getNewArrayFunctionType().getReference().asStorage());
classNameOffset = fields.size();
fields.add(standardClasses.stringClass().getType().asStorage());
virtualTableFieldOffset = fields.size(); virtualTableFieldOffset = fields.size();
classNameOffset = fieldIndexes.getOrDefault(new FieldReference(className, "name"), -1);
} }
} }
@ -654,14 +664,12 @@ public class WasmGCClassGenerator implements WasmGCClassInfoProvider, WasmGCInit
classFlagsOffset, classFlagsOffset,
flagsExpr flagsExpr
)); ));
if (classNameOffset >= 0) { function.getBody().add(new WasmStructSet(
function.getBody().add(new WasmStructSet( standardClasses.classClass().getStructure(),
standardClasses.classClass().getStructure(), new WasmGetLocal(targetVar),
new WasmGetLocal(targetVar), classNameOffset,
classNameOffset, new WasmGetLocal(nameVar)
new WasmGetLocal(nameVar) ));
));
}
function.getBody().add(new WasmStructSet( function.getBody().add(new WasmStructSet(
standardClasses.classClass().getStructure(), standardClasses.classClass().getStructure(),
new WasmGetLocal(targetVar), new WasmGetLocal(targetVar),

View File

@ -260,42 +260,41 @@ public class WasmGCGenerationVisitor extends BaseWasmGenerationVisitor {
List<WasmExpression> arguments) { List<WasmExpression> arguments) {
var vtable = context.virtualTables().lookup(method.getClassName()); var vtable = context.virtualTables().lookup(method.getClassName());
if (vtable != null) { if (vtable != null) {
var cls = context.classes().get(method.getClassName());
assert cls != null : "Virtual table can't be generated for absent class";
if (cls.hasModifier(ElementModifier.INTERFACE)) {
vtable = pickVirtualTableForInterfaceCall(vtable, method.getDescriptor());
}
vtable = vtable.findMethodContainer(method.getDescriptor()); vtable = vtable.findMethodContainer(method.getDescriptor());
} }
if (vtable == null) { if (vtable == null) {
return new WasmUnreachable(); return new WasmUnreachable();
} }
var cls = context.classes().get(method.getClassName());
assert cls != null : "Virtual table can't be generated for absent class";
if (cls.hasModifier(ElementModifier.INTERFACE)) {
vtable = pickVirtualTableForInterfaceCall(vtable, method.getDescriptor());
}
int vtableIndex = vtable.getMethods().indexOf(method.getDescriptor()); int vtableIndex = vtable.getMethods().indexOf(method.getDescriptor());
if (vtable.getParent() != null) { if (vtable.getParent() != null) {
vtableIndex += vtable.getParent().size(); vtableIndex += vtable.getParent().size();
} }
var instanceStruct = context.classInfoProvider().getClassInfo(vtable.getClassName()).getStructure();
var actualInstanceType = (WasmType.CompositeReference) instance.getType(); WasmExpression classRef = new WasmStructGet(context.standardClasses().objectClass().getStructure(),
var actualInstanceStruct = (WasmStructure) actualInstanceType.composite; new WasmGetLocal(instance), WasmGCClassInfoProvider.CLASS_FIELD_OFFSET);
var actualVtableType = (WasmType.CompositeReference) actualInstanceStruct.getFields().get(0).asUnpackedType();
var actualVtableStruct = (WasmStructure) actualVtableType.composite;
WasmExpression classRef = new WasmStructGet(instanceStruct, new WasmGetLocal(instance),
WasmGCClassInfoProvider.CLASS_FIELD_OFFSET);
var index = context.classInfoProvider().getVirtualMethodsOffset() + vtableIndex; var index = context.classInfoProvider().getVirtualMethodsOffset() + vtableIndex;
var vtableStruct = context.classInfoProvider().getClassInfo(vtable.getClassName()) var expectedInstanceClassInfo = context.classInfoProvider().getClassInfo(vtable.getClassName());
.getVirtualTableStructure(); var vtableStruct = expectedInstanceClassInfo.getVirtualTableStructure();
if (!vtableStruct.isSupertypeOf(actualVtableStruct)) { classRef = new WasmCast(classRef, vtableStruct.getReference());
classRef = new WasmCast(classRef, vtableStruct.getReference());
}
var functionRef = new WasmStructGet(vtableStruct, classRef, index); var functionRef = new WasmStructGet(vtableStruct, classRef, index);
var functionTypeRef = (WasmType.CompositeReference) vtableStruct.getFields().get(index).asUnpackedType(); var functionTypeRef = (WasmType.CompositeReference) vtableStruct.getFields().get(index).asUnpackedType();
var invoke = new WasmCallReference(functionRef, (WasmFunctionType) functionTypeRef.composite); var invoke = new WasmCallReference(functionRef, (WasmFunctionType) functionTypeRef.composite);
invoke.getArguments().addAll(arguments); WasmExpression instanceRef = new WasmGetLocal(instance);
var instanceType = (WasmType.CompositeReference) instance.getType();
var instanceStruct = (WasmStructure) instanceType.composite;
if (!expectedInstanceClassInfo.getStructure().isSupertypeOf(instanceStruct)) {
instanceRef = new WasmCast(instanceRef, expectedInstanceClassInfo.getType());
}
invoke.getArguments().add(instanceRef);
invoke.getArguments().addAll(arguments.subList(1, arguments.size()));
return invoke; return invoke;
} }

View File

@ -43,6 +43,9 @@ public class WasmCollection<T extends WasmEntity> implements Iterable<T> {
} }
public void add(T entity) { public void add(T entity) {
if (entity.collection != null) {
throw new IllegalArgumentException("Entity already belongs some collection");
}
if (!indexesInvalid) { if (!indexesInvalid) {
entity.index = items.size(); entity.index = items.size();
} }

View File

@ -98,7 +98,7 @@ public class WasmModule {
} }
private void prepareTypes() { private void prepareTypes() {
var typeGraph = WasmTypeGraphBuilder.buildTypeGraph(types, types.size()); var typeGraph = WasmTypeGraphBuilder.buildTypeGraph(this, types, types.size());
var sccs = GraphUtils.findStronglyConnectedComponents(typeGraph); var sccs = GraphUtils.findStronglyConnectedComponents(typeGraph);
var sccStartNode = new int[types.size()]; var sccStartNode = new int[types.size()];
for (var i = 0; i < sccStartNode.length; ++i) { for (var i = 0; i < sccStartNode.length; ++i) {

View File

@ -22,21 +22,23 @@ final class WasmTypeGraphBuilder {
private WasmTypeGraphBuilder() { private WasmTypeGraphBuilder() {
} }
static Graph buildTypeGraph(Iterable<WasmCompositeType> types, int size) { static Graph buildTypeGraph(WasmModule module, Iterable<WasmCompositeType> types, int size) {
var graphBuilder = new GraphBuilder(size); var graphBuilder = new GraphBuilder(size);
var visitor = new GraphBuilderVisitor(graphBuilder); var visitor = new GraphBuilderVisitor(module, graphBuilder);
for (var type : types) { for (var type : types) {
visitor.currentIndex = type.index; visitor.currentIndex = module.types.indexOf(type);
type.acceptVisitor(visitor); type.acceptVisitor(visitor);
} }
return graphBuilder.build(); return graphBuilder.build();
} }
private static class GraphBuilderVisitor implements WasmCompositeTypeVisitor { private static class GraphBuilderVisitor implements WasmCompositeTypeVisitor {
final WasmModule module;
final GraphBuilder graphBuilder; final GraphBuilder graphBuilder;
int currentIndex; int currentIndex;
GraphBuilderVisitor(GraphBuilder graphBuilder) { GraphBuilderVisitor(WasmModule module, GraphBuilder graphBuilder) {
this.module = module;
this.graphBuilder = graphBuilder; this.graphBuilder = graphBuilder;
} }
@ -71,7 +73,7 @@ final class WasmTypeGraphBuilder {
private void addEdge(WasmType type) { private void addEdge(WasmType type) {
if (type instanceof WasmType.CompositeReference) { if (type instanceof WasmType.CompositeReference) {
var composite = ((WasmType.CompositeReference) type).composite; var composite = ((WasmType.CompositeReference) type).composite;
graphBuilder.addEdge(currentIndex, composite.index); graphBuilder.addEdge(currentIndex, module.types.indexOf(composite));
} }
} }
} }

View File

@ -25,6 +25,7 @@ public class WasmStructGet extends WasmExpression {
private WasmSignedType signedType; private WasmSignedType signedType;
public WasmStructGet(WasmStructure type, WasmExpression instance, int fieldIndex) { public WasmStructGet(WasmStructure type, WasmExpression instance, int fieldIndex) {
checkFieldIndex(fieldIndex);
this.type = Objects.requireNonNull(type); this.type = Objects.requireNonNull(type);
this.instance = Objects.requireNonNull(instance); this.instance = Objects.requireNonNull(instance);
this.fieldIndex = fieldIndex; this.fieldIndex = fieldIndex;
@ -51,6 +52,7 @@ public class WasmStructGet extends WasmExpression {
} }
public void setFieldIndex(int fieldIndex) { public void setFieldIndex(int fieldIndex) {
checkFieldIndex(fieldIndex);
this.fieldIndex = fieldIndex; this.fieldIndex = fieldIndex;
} }
@ -66,4 +68,10 @@ public class WasmStructGet extends WasmExpression {
public void acceptVisitor(WasmExpressionVisitor visitor) { public void acceptVisitor(WasmExpressionVisitor visitor) {
visitor.visit(this); visitor.visit(this);
} }
private static void checkFieldIndex(int fieldIndex) {
if (fieldIndex < 0) {
throw new IllegalArgumentException("Field index must be >= 0");
}
}
} }

View File

@ -316,12 +316,12 @@ public abstract class BaseTypeInference<T> {
@Override @Override
public void visit(ClassConstantInstruction insn) { public void visit(ClassConstantInstruction insn) {
type(insn.getReceiver(), ValueType.object("java/lang/Class")); type(insn.getReceiver(), ValueType.object("java.lang.Class"));
} }
@Override @Override
public void visit(StringConstantInstruction insn) { public void visit(StringConstantInstruction insn) {
type(insn.getReceiver(), ValueType.object("java/lang/String")); type(insn.getReceiver(), ValueType.object("java.lang.String"));
} }
@Override @Override