Further work on WASM tree -> C renderer

This commit is contained in:
Alexey Andreev 2016-08-23 23:22:29 +03:00
parent 1fb929e9ae
commit 199d91d28c
2 changed files with 411 additions and 22 deletions

View File

@ -20,6 +20,7 @@ import java.util.List;
class CExpression { class CExpression {
private String text; private String text;
private boolean relocatable;
private List<CLine> lines = new ArrayList<>(); private List<CLine> lines = new ArrayList<>();
public CExpression(String text) { public CExpression(String text) {
@ -29,6 +30,14 @@ class CExpression {
public CExpression() { public CExpression() {
} }
public boolean isRelocatable() {
return relocatable;
}
public void setRelocatable(boolean relocatable) {
this.relocatable = relocatable;
}
public String getText() { public String getText() {
return text; return text;
} }
@ -40,4 +49,10 @@ class CExpression {
public List<CLine> getLines() { public List<CLine> getLines() {
return lines; return lines;
} }
public static CExpression relocatable(String text) {
CExpression expression = new CExpression(text);
expression.setRelocatable(true);
return expression;
}
} }

View File

@ -16,9 +16,12 @@
package org.teavm.backend.wasm.render; package org.teavm.backend.wasm.render;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap; import java.util.HashMap;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import org.teavm.backend.wasm.model.WasmFunction;
import org.teavm.backend.wasm.model.WasmModule;
import org.teavm.backend.wasm.model.WasmType; import org.teavm.backend.wasm.model.WasmType;
import org.teavm.backend.wasm.model.expression.WasmBlock; import org.teavm.backend.wasm.model.expression.WasmBlock;
import org.teavm.backend.wasm.model.expression.WasmBranch; import org.teavm.backend.wasm.model.expression.WasmBranch;
@ -32,6 +35,7 @@ import org.teavm.backend.wasm.model.expression.WasmExpressionVisitor;
import org.teavm.backend.wasm.model.expression.WasmFloat32Constant; import org.teavm.backend.wasm.model.expression.WasmFloat32Constant;
import org.teavm.backend.wasm.model.expression.WasmFloat64Constant; import org.teavm.backend.wasm.model.expression.WasmFloat64Constant;
import org.teavm.backend.wasm.model.expression.WasmFloatBinary; import org.teavm.backend.wasm.model.expression.WasmFloatBinary;
import org.teavm.backend.wasm.model.expression.WasmFloatType;
import org.teavm.backend.wasm.model.expression.WasmFloatUnary; import org.teavm.backend.wasm.model.expression.WasmFloatUnary;
import org.teavm.backend.wasm.model.expression.WasmGetLocal; import org.teavm.backend.wasm.model.expression.WasmGetLocal;
import org.teavm.backend.wasm.model.expression.WasmIndirectCall; import org.teavm.backend.wasm.model.expression.WasmIndirectCall;
@ -61,9 +65,11 @@ class WasmCRenderingVisitor implements WasmExpressionVisitor {
private int temporaryIndex; private int temporaryIndex;
private int blockIndex; private int blockIndex;
private WasmType functionType; private WasmType functionType;
private WasmModule module;
public WasmCRenderingVisitor(WasmType functionType) { public WasmCRenderingVisitor(WasmType functionType, WasmModule module) {
this.functionType = functionType; this.functionType = functionType;
this.module = module;
} }
public CExpression getValue() { public CExpression getValue() {
@ -281,27 +287,27 @@ class WasmCRenderingVisitor implements WasmExpressionVisitor {
@Override @Override
public void visit(WasmInt32Constant expression) { public void visit(WasmInt32Constant expression) {
value = new CExpression("INT32_C(" + String.valueOf(expression.getValue()) + ")"); value = CExpression.relocatable("INT32_C(" + String.valueOf(expression.getValue()) + ")");
} }
@Override @Override
public void visit(WasmInt64Constant expression) { public void visit(WasmInt64Constant expression) {
value = new CExpression("INT64_C(" + String.valueOf(expression.getValue()) + ")"); value = CExpression.relocatable("INT64_C(" + String.valueOf(expression.getValue()) + ")");
} }
@Override @Override
public void visit(WasmFloat32Constant expression) { public void visit(WasmFloat32Constant expression) {
value = new CExpression(Float.toHexString(expression.getValue()) + "F"); value = CExpression.relocatable(Float.toHexString(expression.getValue()) + "F");
} }
@Override @Override
public void visit(WasmFloat64Constant expression) { public void visit(WasmFloat64Constant expression) {
value = new CExpression(Double.toHexString(expression.getValue())); value = CExpression.relocatable(Double.toHexString(expression.getValue()));
} }
@Override @Override
public void visit(WasmGetLocal expression) { public void visit(WasmGetLocal expression) {
value = new CExpression("var_" + expression.getLocal().getIndex()); value = CExpression.relocatable("var_" + expression.getLocal().getIndex());
} }
@Override @Override
@ -334,19 +340,18 @@ class WasmCRenderingVisitor implements WasmExpressionVisitor {
result.getLines().addAll(first.getLines()); result.getLines().addAll(first.getLines());
result.getLines().addAll(second.getLines()); result.getLines().addAll(second.getLines());
} else { } else {
String firstOp = first.getText();
String secondOp = second.getText();
result.getLines().addAll(first.getLines()); result.getLines().addAll(first.getLines());
if (!second.getLines().isEmpty()) { if (!second.getLines().isEmpty()) {
firstOp = "tmp_" + temporaryIndex; first = cacheIfNeeded(opType, first, result);
result.getLines().add(new CSingleLine(mapType(opType) + " " + firstOp + " = "
+ first.getText() + ";"));
result.getLines().addAll(second.getLines()); result.getLines().addAll(second.getLines());
} }
String unsingedType = "u" + mapType(opType); String firstOp = first.getText();
String firstOpUnsinged = "(" + unsingedType + ") " + firstOp; String secondOp = second.getText();
String secondOpUnsigned = "(" + unsingedType + ") " + secondOp; String typeText = mapType(opType);
String unsignedType = "u" + typeText;
String firstOpUnsigned = "(" + unsignedType + ") " + firstOp;
String secondOpUnsigned = "(" + unsignedType + ") " + secondOp;
switch (expression.getOperation()) { switch (expression.getOperation()) {
case ADD: case ADD:
result.setText("(" + firstOp + " + " + secondOp + ")"); result.setText("(" + firstOp + " + " + secondOp + ")");
@ -361,13 +366,13 @@ class WasmCRenderingVisitor implements WasmExpressionVisitor {
result.setText("(" + firstOp + " / " + secondOp + ")"); result.setText("(" + firstOp + " / " + secondOp + ")");
break; break;
case DIV_UNSIGNED: case DIV_UNSIGNED:
result.setText("(" + firstOpUnsinged + " / " + secondOpUnsigned + ")"); result.setText("(" + typeText + ") (" + firstOpUnsigned + " / " + secondOpUnsigned + ")");
break; break;
case REM_SIGNED: case REM_SIGNED:
result.setText("(" + firstOp + " % " + secondOp + ")"); result.setText("(" + firstOp + " % " + secondOp + ")");
break; break;
case REM_UNSIGNED: case REM_UNSIGNED:
result.setText("(" + firstOpUnsinged + " % " + secondOpUnsigned + ")"); result.setText("(" + typeText + ") (" + firstOpUnsigned + " % " + secondOpUnsigned + ")");
break; break;
case AND: case AND:
result.setText("(" + firstOp + " & " + secondOp + ")"); result.setText("(" + firstOp + " & " + secondOp + ")");
@ -385,7 +390,7 @@ class WasmCRenderingVisitor implements WasmExpressionVisitor {
result.setText("(" + firstOp + " >> " + secondOp + ")"); result.setText("(" + firstOp + " >> " + secondOp + ")");
break; break;
case SHR_UNSIGNED: case SHR_UNSIGNED:
result.setText("(" + firstOpUnsinged + " >> " + secondOp + ")"); result.setText("(" + typeText + ") (" + firstOpUnsigned + " >> " + secondOp + ")");
break; break;
case EQ: case EQ:
result.setText("(" + firstOp + " == " + secondOp + ")"); result.setText("(" + firstOp + " == " + secondOp + ")");
@ -397,25 +402,25 @@ class WasmCRenderingVisitor implements WasmExpressionVisitor {
result.setText("(" + firstOp + " > " + secondOp + ")"); result.setText("(" + firstOp + " > " + secondOp + ")");
break; break;
case GT_UNSIGNED: case GT_UNSIGNED:
result.setText("(" + firstOpUnsinged + " > " + secondOpUnsigned + ")"); result.setText("(" + firstOpUnsigned + " > " + secondOpUnsigned + ")");
break; break;
case GE_SIGNED: case GE_SIGNED:
result.setText("(" + firstOp + " >= " + secondOp + ")"); result.setText("(" + firstOp + " >= " + secondOp + ")");
break; break;
case GE_UNSIGNED: case GE_UNSIGNED:
result.setText("(" + firstOpUnsinged + " >= " + secondOpUnsigned + ")"); result.setText("(" + typeText + ") (" + firstOpUnsigned + " >= " + secondOpUnsigned + ")");
break; break;
case LT_SIGNED: case LT_SIGNED:
result.setText("(" + firstOp + " < " + secondOp + ")"); result.setText("(" + firstOp + " < " + secondOp + ")");
break; break;
case LT_UNSIGNED: case LT_UNSIGNED:
result.setText("(" + firstOpUnsinged + " < " + secondOpUnsigned + ")"); result.setText("(" + typeText + ") (" + firstOpUnsigned + " < " + secondOpUnsigned + ")");
break; break;
case LE_SIGNED: case LE_SIGNED:
result.setText("(" + firstOp + " <= " + secondOp + ")"); result.setText("(" + firstOp + " <= " + secondOp + ")");
break; break;
case LE_UNSIGNED: case LE_UNSIGNED:
result.setText("(" + firstOpUnsinged + " <= " + secondOpUnsigned + ")"); result.setText("(" + typeText + ") (" + firstOpUnsigned + " <= " + secondOpUnsigned + ")");
break; break;
case ROTL: case ROTL:
result.setText("rotl(" + firstOp + ", " + secondOp + ")"); result.setText("rotl(" + firstOp + ", " + secondOp + ")");
@ -424,6 +429,7 @@ class WasmCRenderingVisitor implements WasmExpressionVisitor {
result.setText("rotr(" + firstOp + ", " + secondOp + ")"); result.setText("rotr(" + firstOp + ", " + secondOp + ")");
break; break;
} }
result.setRelocatable(first.isRelocatable() && second.isRelocatable());
} }
value = result; value = result;
@ -431,34 +437,237 @@ class WasmCRenderingVisitor implements WasmExpressionVisitor {
@Override @Override
public void visit(WasmFloatBinary expression) { public void visit(WasmFloatBinary expression) {
WasmType type = requiredType;
WasmType opType = asWasmType(expression.getType());
CExpression result = new CExpression();
requiredType = opType;
expression.getFirst().acceptVisitor(this);
CExpression first = value;
requiredType = opType;
expression.getSecond().acceptVisitor(this);
CExpression second = value;
if (type == null) {
result.getLines().addAll(first.getLines());
result.getLines().addAll(second.getLines());
} else {
result.getLines().addAll(first.getLines());
if (!second.getLines().isEmpty()) {
first = cacheIfNeeded(opType, first, result);
result.getLines().addAll(second.getLines());
}
String firstOp = first.getText();
String secondOp = second.getText();
switch (expression.getOperation()) {
case ADD:
result.setText("(" + firstOp + " + " + secondOp + ")");
break;
case SUB:
result.setText("(" + firstOp + " - " + secondOp + ")");
break;
case MUL:
result.setText("(" + firstOp + " * " + secondOp + ")");
break;
case DIV:
result.setText("(" + firstOp + " / " + secondOp + ")");
break;
case EQ:
result.setText("(" + firstOp + " == " + secondOp + ")");
break;
case NE:
result.setText("(" + firstOp + " != " + secondOp + ")");
break;
case GT:
result.setText("(" + firstOp + " > " + secondOp + ")");
break;
case GE:
result.setText("(" + firstOp + " >= " + secondOp + ")");
break;
case LT:
result.setText("(" + firstOp + " < " + secondOp + ")");
break;
case LE:
result.setText("(" + firstOp + " <= " + secondOp + ")");
break;
case MIN: {
String function = expression.getType() == WasmFloatType.FLOAT32 ? "fminf" : "fmin";
result.setText(function + "(" + firstOp + ", " + secondOp + ")");
break;
}
case MAX: {
String function = expression.getType() == WasmFloatType.FLOAT32 ? "fmaxf" : "fmax";
result.setText(function + "(" + firstOp + ", " + secondOp + ")");
break;
}
}
result.setRelocatable(first.isRelocatable() && second.isRelocatable());
}
value = result;
}
private CExpression cacheIfNeeded(WasmType type, CExpression expression, CExpression target) {
if (expression.isRelocatable()) {
return expression;
}
String var = "tmp_" + temporaryIndex;
target.getLines().add(new CSingleLine(mapType(type) + " " + var + " = " + expression.getText() + ";"));
return CExpression.relocatable(var);
} }
@Override @Override
public void visit(WasmIntUnary expression) { public void visit(WasmIntUnary expression) {
WasmType type = requiredType;
WasmType opType = asWasmType(expression.getType());
CExpression result = new CExpression();
requiredType = opType;
expression.getOperand().acceptVisitor(this);
CExpression operand = value;
result.getLines().addAll(operand.getLines());
if (type != null) {
switch (expression.getOperation()) {
case POPCNT:
result.setText("popcnt(" + operand.getText() + ")");
break;
case CLZ:
result.setText("clz(" + operand.getText() + ")");
break;
case CTZ:
result.setText("ctz(" + operand.getText() + ")");
break;
}
result.setRelocatable(operand.isRelocatable());
}
value = result;
} }
@Override @Override
public void visit(WasmFloatUnary expression) { public void visit(WasmFloatUnary expression) {
WasmType type = requiredType;
WasmType opType = asWasmType(expression.getType());
CExpression result = new CExpression();
requiredType = opType;
expression.getOperand().acceptVisitor(this);
CExpression operand = value;
result.getLines().addAll(operand.getLines());
if (type != null) {
switch (expression.getOperation()) {
case ABS: {
String functionName = expression.getType() == WasmFloatType.FLOAT32 ? "fabsf" : "fabs";
result.setText(functionName + "(" + operand.getText() + ")");
break;
}
case CEIL: {
String functionName = expression.getType() == WasmFloatType.FLOAT32 ? "ceilf" : "ceil";
result.setText(functionName + "(" + operand.getText() + ")");
break;
}
case FLOOR: {
String functionName = expression.getType() == WasmFloatType.FLOAT32 ? "floorf" : "floor";
result.setText(functionName + "(" + operand.getText() + ")");
break;
}
case TRUNC: {
String functionName = expression.getType() == WasmFloatType.FLOAT32 ? "truncf" : "trunc";
result.setText(functionName + "(" + operand.getText() + ")");
break;
}
case NEAREST: {
String functionName = expression.getType() == WasmFloatType.FLOAT32 ? "roundf" : "round";
result.setText(functionName + "(" + operand.getText() + ")");
break;
}
case SQRT: {
String functionName = expression.getType() == WasmFloatType.FLOAT32 ? "sqrtf" : "sqrt";
result.setText(functionName + "(" + operand.getText() + ")");
break;
}
case NEG:
result.setText("(-" + operand.getText() + ")");
break;
case COPYSIGN: {
String functionName = expression.getType() == WasmFloatType.FLOAT32 ? "copysignf" : "copysign";
result.setText(functionName + "(" + operand.getText() + ")");
break;
}
}
result.setRelocatable(operand.isRelocatable());
}
value = result;
} }
@Override @Override
public void visit(WasmConversion expression) { public void visit(WasmConversion expression) {
CExpression result = new CExpression();
WasmType type = requiredType;
expression.getOperand().acceptVisitor(this);
CExpression operand = value;
result.getLines().addAll(operand.getLines());
if (type != null && expression.getSourceType() != expression.getSourceType()) {
switch (expression.getTargetType()) {
case INT32:
if (expression.isSigned()) {
result.setText("(int32_t) " + operand.getText());
} else {
result.setText("(uint32_t) " + operand.getText());
}
break;
case INT64:
if (expression.isSigned()) {
result.setText("(int64_t) " + operand.getText());
} else {
result.setText("(uint64_t) " + operand.getText());
}
break;
case FLOAT32:
if (expression.getSourceType() == WasmType.FLOAT64) {
result.setText("(float) " + operand.getText());
} else if (expression.isSigned()) {
result.setText("(float) (int64_t) " + operand.getText());
} else {
result.setText("(float) (uint64_t) " + operand.getText());
}
break;
case FLOAT64:
if (expression.getSourceType() == WasmType.FLOAT32) {
result.setText("(double) " + operand.getText());
} else if (expression.isSigned()) {
result.setText("(double) (int64_t) " + operand.getText());
} else {
result.setText("(double) (uint64_t) " + operand.getText());
}
break;
}
}
value = operand;
} }
@Override @Override
public void visit(WasmCall expression) { public void visit(WasmCall expression) {
CExpression result = new CExpression(); CExpression result = new CExpression();
WasmType type = requiredType;
StringBuilder sb = new StringBuilder(); StringBuilder sb = new StringBuilder();
sb.append(expression.getFunctionName()).append('('); sb.append(expression.getFunctionName()).append('(');
WasmFunction function = module.getFunctions().get(expression.getFunctionName());
translateArguments(expression.getArguments(), function.getParameters(), result, sb);
sb.append(')'); sb.append(')');
result.setText(sb.toString()); result.setText(sb.toString());
if (requiredType == null) { if (type == null) {
reportLocation(expression.getLocation(), result.getLines()); reportLocation(expression.getLocation(), result.getLines());
result.getLines().add(new CSingleLine(result.getText())); result.getLines().add(new CSingleLine(result.getText()));
result.setText(null); result.setText(null);
@ -468,7 +677,65 @@ class WasmCRenderingVisitor implements WasmExpressionVisitor {
@Override @Override
public void visit(WasmIndirectCall expression) { public void visit(WasmIndirectCall expression) {
CExpression result = new CExpression();
WasmType type = requiredType;
StringBuilder sb = new StringBuilder();
sb.append("(*(" + mapType(expression.getReturnType()) + " (*)(");
for (int i = 0; i < expression.getParameterTypes().size(); ++i) {
if (i > 0) {
sb.append(", ");
}
sb.append(mapType(expression.getParameterTypes().get(i)));
}
sb.append(") ");
requiredType = WasmType.INT32;
expression.getSelector().acceptVisitor(this);
value = cacheIfNeeded(WasmType.INT32, value, result);
result.getLines().addAll(value.getLines());
sb.append("wasm_table[" + result.getText() + "])(");
translateArguments(expression.getArguments(), expression.getParameterTypes(), result, sb);
sb.append(")");
if (type == null) {
reportLocation(expression.getLocation(), result.getLines());
result.getLines().add(new CSingleLine(result.getText()));
result.setText(null);
}
value = result;
}
private void translateArguments(List<WasmExpression> wasmArguments, List<WasmType> signature,
CExpression result, StringBuilder sb) {
if (wasmArguments.isEmpty()) {
return;
}
List<CExpression> arguments = new ArrayList<>();
int needsCachingUntil = 0;
for (int i = wasmArguments.size() - 1; i >= 0; --i) {
requiredType = signature.get(i);
wasmArguments.get(i).acceptVisitor(this);
arguments.add(value);
if (!result.getLines().isEmpty() && needsCachingUntil == 0) {
needsCachingUntil = i;
}
}
Collections.reverse(arguments);
for (int i = 0; i < arguments.size(); ++i) {
CExpression argument = arguments.get(i);
result.getLines().addAll(argument.getLines());
if (i < needsCachingUntil) {
argument = cacheIfNeeded(signature.get(i), argument, result);
}
if (i > 0) {
sb.append(", ");
}
sb.append(argument.getText());
}
value = result;
} }
@Override @Override
@ -483,22 +750,116 @@ class WasmCRenderingVisitor implements WasmExpressionVisitor {
@Override @Override
public void visit(WasmLoadInt32 expression) { public void visit(WasmLoadInt32 expression) {
CExpression result = new CExpression();
WasmType type = requiredType;
requiredType = WasmType.INT32;
expression.getIndex();
CExpression index = value;
if (type == null) {
value = index;
return;
}
result.getLines().addAll(index.getLines());
switch (expression.getConvertFrom()) {
case INT8:
result.setText("(int32_t) (int8_t) wasm_heap[" + index + "]");
break;
case UINT8:
result.setText("(int32_t) (uint8_t) wasm_heap[" + index + "]");
break;
case INT16:
result.setText("(int32_t) *((int16_t *) &wasm_heap[" + index + "])");
break;
case UINT16:
result.setText("(int32_t) *((uint16_t *) &wasm_heap[" + index + "])");
break;
case INT32:
result.setText("*((int32_t *) &wasm_heap[" + index + "])");
break;
}
value = result;
} }
@Override @Override
public void visit(WasmLoadInt64 expression) { public void visit(WasmLoadInt64 expression) {
CExpression result = new CExpression();
WasmType type = requiredType;
requiredType = WasmType.INT32;
expression.getIndex();
CExpression index = value;
if (type == null) {
value = index;
return;
}
result.getLines().addAll(index.getLines());
switch (expression.getConvertFrom()) {
case INT8:
result.setText("(int64_t) (int8_t) wasm_heap[" + index + "]");
break;
case UINT8:
result.setText("(int64_t) (uint8_t) wasm_heap[" + index + "]");
break;
case INT16:
result.setText("(int64_t) *((int16_t *) &wasm_heap[" + index + "])");
break;
case UINT16:
result.setText("(int64_t) *((uint16_t *) &wasm_heap[" + index + "])");
break;
case INT32:
result.setText("(int64_t) *((int32_t *) &wasm_heap[" + index + "])");
break;
case UINT32:
result.setText("(int64_t) *((uint32_t *) &wasm_heap[" + index + "])");
break;
case INT64:
result.setText("*((int64_t *) &wasm_heap[" + index + "])");
break;
}
value = result;
} }
@Override @Override
public void visit(WasmLoadFloat32 expression) { public void visit(WasmLoadFloat32 expression) {
CExpression result = new CExpression();
WasmType type = requiredType;
requiredType = WasmType.INT32;
expression.getIndex();
CExpression index = value;
if (type == null) {
value = index;
return;
}
result.getLines().addAll(index.getLines());
result.setText("*((float *) &wasm_heap[" + index + "])");
value = result;
} }
@Override @Override
public void visit(WasmLoadFloat64 expression) { public void visit(WasmLoadFloat64 expression) {
CExpression result = new CExpression();
WasmType type = requiredType;
requiredType = WasmType.INT32;
expression.getIndex();
CExpression index = value;
if (type == null) {
value = index;
return;
}
result.getLines().addAll(index.getLines());
result.setText("*((double *) &wasm_heap[" + index + "])");
value = result;
} }
@Override @Override
@ -532,6 +893,9 @@ class WasmCRenderingVisitor implements WasmExpressionVisitor {
} }
private static String mapType(WasmType type) { private static String mapType(WasmType type) {
if (type == null) {
return "void";
}
switch (type) { switch (type) {
case INT32: case INT32:
return "int32_t"; return "int32_t";
@ -555,6 +919,16 @@ class WasmCRenderingVisitor implements WasmExpressionVisitor {
throw new AssertionError(type.toString()); throw new AssertionError(type.toString());
} }
private static WasmType asWasmType(WasmFloatType type) {
switch (type) {
case FLOAT32:
return WasmType.FLOAT32;
case FLOAT64:
return WasmType.FLOAT64;
}
throw new AssertionError(type.toString());
}
static class BlockInfo { static class BlockInfo {
int index; int index;
String label; String label;