wasm gc: trying to fix virtual calls

This commit is contained in:
Alexey Andreev 2024-08-19 15:02:11 +02:00
parent 59259c314d
commit 29f29cea1d
7 changed files with 70 additions and 50 deletions

View File

@ -100,7 +100,7 @@ public class WasmGCClassGenerator implements WasmGCClassInfoProvider, WasmGCInit
private int classTagOffset;
private int classFlagsOffset;
private int classNameOffset = -1;
private int classNameOffset;
private int classParentOffset;
private int classArrayOffset;
private int classArrayItemOffset;
@ -162,7 +162,10 @@ public class WasmGCClassGenerator implements WasmGCClassInfoProvider, WasmGCInit
if (classInfo.virtualTableStructure != null && classInfo.getValueType() instanceof ValueType.Object
&& classInfo.hasOwnVirtualTable) {
var className = ((ValueType.Object) classInfo.getValueType()).getClassName();
classInfo.virtualTableStructure.setSupertype(findVirtualTableSupertype(className));
var candidate = findVirtualTableSupertype(className);
if (candidate != null) {
classInfo.virtualTableStructure.setSupertype(candidate);
}
}
}
}
@ -235,21 +238,29 @@ public class WasmGCClassGenerator implements WasmGCClassInfoProvider, WasmGCInit
var name = type instanceof ValueType.Object
? ((ValueType.Object) type).getClassName()
: null;
var isInterface = false;
var classReader = name != null ? classSource.get(name) : null;
if (classReader != null && classReader.hasModifier(ElementModifier.INTERFACE)) {
isInterface = true;
classInfo.structure = standardClasses.objectClass().structure;
} else {
classInfo.structure = new WasmStructure(name != null ? names.forClass(name) : null);
module.types.add(classInfo.structure);
}
if (name != null) {
var classReader = classSource.get(name);
if (classReader == null || !classReader.hasModifier(ElementModifier.INTERFACE)) {
if (!isInterface) {
virtualTable = virtualTables.lookup(name);
}
if (classReader != null && classReader.getParent() != null) {
if (classReader != null && classReader.getParent() != null && !isInterface) {
classInfo.structure.setSupertype(getClassInfo(classReader.getParent()).structure);
}
} else {
classInfo.structure.setSupertype(standardClasses.objectClass().structure);
}
module.types.add(classInfo.structure);
if (!isInterface) {
fillFields(classInfo, type);
}
}
var pointerName = names.forClassInstance(type);
classInfo.hasOwnVirtualTable = virtualTable != null && virtualTable.hasValidEntries();
var classStructure = classInfo.hasOwnVirtualTable
@ -405,10 +416,7 @@ public class WasmGCClassGenerator implements WasmGCClassInfoProvider, WasmGCInit
var function = functionProvider.forInstanceMethod(entry.getImplementor());
if (!virtualTable.getClassName().equals(entry.getImplementor().getClassName())
|| expectedFunctionType != function.getType()) {
var functionType = typeMapper.getFunctionType(virtualTable.getClassName(), method, true);
functionType.getSupertypes().add(expectedFunctionType);
expectedFunctionType.setFinal(false);
var wrapperFunction = new WasmFunction(functionType);
var wrapperFunction = new WasmFunction(expectedFunctionType);
module.functions.add(wrapperFunction);
var call = new WasmCall(function);
var instanceParam = new WasmLocal(getClassInfo(virtualTable.getClassName()).getType());
@ -451,7 +459,8 @@ public class WasmGCClassGenerator implements WasmGCClassInfoProvider, WasmGCInit
if (methodDesc == null) {
structure.getFields().add(WasmType.Reference.FUNC.asStorage());
} else {
var functionType = typeMapper.getFunctionType(virtualTable.getClassName(), methodDesc, false);
var originalVirtualTable = virtualTable.findMethodContainer(methodDesc);
var functionType = typeMapper.getFunctionType(originalVirtualTable.getClassName(), methodDesc, false);
structure.getFields().add(functionType.getReference().asStorage());
}
}
@ -570,8 +579,9 @@ public class WasmGCClassGenerator implements WasmGCClassInfoProvider, WasmGCInit
fields.add(supertypeGenerator.getFunctionType().getReference().asStorage());
classNewArrayOffset = fields.size();
fields.add(newArrayGenerator.getNewArrayFunctionType().getReference().asStorage());
classNameOffset = fields.size();
fields.add(standardClasses.stringClass().getType().asStorage());
virtualTableFieldOffset = fields.size();
classNameOffset = fieldIndexes.getOrDefault(new FieldReference(className, "name"), -1);
}
}
@ -654,14 +664,12 @@ public class WasmGCClassGenerator implements WasmGCClassInfoProvider, WasmGCInit
classFlagsOffset,
flagsExpr
));
if (classNameOffset >= 0) {
function.getBody().add(new WasmStructSet(
standardClasses.classClass().getStructure(),
new WasmGetLocal(targetVar),
classNameOffset,
new WasmGetLocal(nameVar)
));
}
function.getBody().add(new WasmStructSet(
standardClasses.classClass().getStructure(),
new WasmGetLocal(targetVar),

View File

@ -260,42 +260,41 @@ public class WasmGCGenerationVisitor extends BaseWasmGenerationVisitor {
List<WasmExpression> arguments) {
var vtable = context.virtualTables().lookup(method.getClassName());
if (vtable != null) {
var cls = context.classes().get(method.getClassName());
assert cls != null : "Virtual table can't be generated for absent class";
if (cls.hasModifier(ElementModifier.INTERFACE)) {
vtable = pickVirtualTableForInterfaceCall(vtable, method.getDescriptor());
}
vtable = vtable.findMethodContainer(method.getDescriptor());
}
if (vtable == null) {
return new WasmUnreachable();
}
var cls = context.classes().get(method.getClassName());
assert cls != null : "Virtual table can't be generated for absent class";
if (cls.hasModifier(ElementModifier.INTERFACE)) {
vtable = pickVirtualTableForInterfaceCall(vtable, method.getDescriptor());
}
int vtableIndex = vtable.getMethods().indexOf(method.getDescriptor());
if (vtable.getParent() != null) {
vtableIndex += vtable.getParent().size();
}
var instanceStruct = context.classInfoProvider().getClassInfo(vtable.getClassName()).getStructure();
var actualInstanceType = (WasmType.CompositeReference) instance.getType();
var actualInstanceStruct = (WasmStructure) actualInstanceType.composite;
var actualVtableType = (WasmType.CompositeReference) actualInstanceStruct.getFields().get(0).asUnpackedType();
var actualVtableStruct = (WasmStructure) actualVtableType.composite;
WasmExpression classRef = new WasmStructGet(instanceStruct, new WasmGetLocal(instance),
WasmGCClassInfoProvider.CLASS_FIELD_OFFSET);
WasmExpression classRef = new WasmStructGet(context.standardClasses().objectClass().getStructure(),
new WasmGetLocal(instance), WasmGCClassInfoProvider.CLASS_FIELD_OFFSET);
var index = context.classInfoProvider().getVirtualMethodsOffset() + vtableIndex;
var vtableStruct = context.classInfoProvider().getClassInfo(vtable.getClassName())
.getVirtualTableStructure();
if (!vtableStruct.isSupertypeOf(actualVtableStruct)) {
var expectedInstanceClassInfo = context.classInfoProvider().getClassInfo(vtable.getClassName());
var vtableStruct = expectedInstanceClassInfo.getVirtualTableStructure();
classRef = new WasmCast(classRef, vtableStruct.getReference());
}
var functionRef = new WasmStructGet(vtableStruct, classRef, index);
var functionTypeRef = (WasmType.CompositeReference) vtableStruct.getFields().get(index).asUnpackedType();
var invoke = new WasmCallReference(functionRef, (WasmFunctionType) functionTypeRef.composite);
invoke.getArguments().addAll(arguments);
WasmExpression instanceRef = new WasmGetLocal(instance);
var instanceType = (WasmType.CompositeReference) instance.getType();
var instanceStruct = (WasmStructure) instanceType.composite;
if (!expectedInstanceClassInfo.getStructure().isSupertypeOf(instanceStruct)) {
instanceRef = new WasmCast(instanceRef, expectedInstanceClassInfo.getType());
}
invoke.getArguments().add(instanceRef);
invoke.getArguments().addAll(arguments.subList(1, arguments.size()));
return invoke;
}

View File

@ -43,6 +43,9 @@ public class WasmCollection<T extends WasmEntity> implements Iterable<T> {
}
public void add(T entity) {
if (entity.collection != null) {
throw new IllegalArgumentException("Entity already belongs some collection");
}
if (!indexesInvalid) {
entity.index = items.size();
}

View File

@ -98,7 +98,7 @@ public class WasmModule {
}
private void prepareTypes() {
var typeGraph = WasmTypeGraphBuilder.buildTypeGraph(types, types.size());
var typeGraph = WasmTypeGraphBuilder.buildTypeGraph(this, types, types.size());
var sccs = GraphUtils.findStronglyConnectedComponents(typeGraph);
var sccStartNode = new int[types.size()];
for (var i = 0; i < sccStartNode.length; ++i) {

View File

@ -22,21 +22,23 @@ final class WasmTypeGraphBuilder {
private WasmTypeGraphBuilder() {
}
static Graph buildTypeGraph(Iterable<WasmCompositeType> types, int size) {
static Graph buildTypeGraph(WasmModule module, Iterable<WasmCompositeType> types, int size) {
var graphBuilder = new GraphBuilder(size);
var visitor = new GraphBuilderVisitor(graphBuilder);
var visitor = new GraphBuilderVisitor(module, graphBuilder);
for (var type : types) {
visitor.currentIndex = type.index;
visitor.currentIndex = module.types.indexOf(type);
type.acceptVisitor(visitor);
}
return graphBuilder.build();
}
private static class GraphBuilderVisitor implements WasmCompositeTypeVisitor {
final WasmModule module;
final GraphBuilder graphBuilder;
int currentIndex;
GraphBuilderVisitor(GraphBuilder graphBuilder) {
GraphBuilderVisitor(WasmModule module, GraphBuilder graphBuilder) {
this.module = module;
this.graphBuilder = graphBuilder;
}
@ -71,7 +73,7 @@ final class WasmTypeGraphBuilder {
private void addEdge(WasmType type) {
if (type instanceof WasmType.CompositeReference) {
var composite = ((WasmType.CompositeReference) type).composite;
graphBuilder.addEdge(currentIndex, composite.index);
graphBuilder.addEdge(currentIndex, module.types.indexOf(composite));
}
}
}

View File

@ -25,6 +25,7 @@ public class WasmStructGet extends WasmExpression {
private WasmSignedType signedType;
public WasmStructGet(WasmStructure type, WasmExpression instance, int fieldIndex) {
checkFieldIndex(fieldIndex);
this.type = Objects.requireNonNull(type);
this.instance = Objects.requireNonNull(instance);
this.fieldIndex = fieldIndex;
@ -51,6 +52,7 @@ public class WasmStructGet extends WasmExpression {
}
public void setFieldIndex(int fieldIndex) {
checkFieldIndex(fieldIndex);
this.fieldIndex = fieldIndex;
}
@ -66,4 +68,10 @@ public class WasmStructGet extends WasmExpression {
public void acceptVisitor(WasmExpressionVisitor visitor) {
visitor.visit(this);
}
private static void checkFieldIndex(int fieldIndex) {
if (fieldIndex < 0) {
throw new IllegalArgumentException("Field index must be >= 0");
}
}
}

View File

@ -316,12 +316,12 @@ public abstract class BaseTypeInference<T> {
@Override
public void visit(ClassConstantInstruction insn) {
type(insn.getReceiver(), ValueType.object("java/lang/Class"));
type(insn.getReceiver(), ValueType.object("java.lang.Class"));
}
@Override
public void visit(StringConstantInstruction insn) {
type(insn.getReceiver(), ValueType.object("java/lang/String"));
type(insn.getReceiver(), ValueType.object("java.lang.String"));
}
@Override