From 199d91d28c4e0714051905023f2e104da5c532e4 Mon Sep 17 00:00:00 2001 From: Alexey Andreev Date: Tue, 23 Aug 2016 23:22:29 +0300 Subject: [PATCH] Further work on WASM tree -> C renderer --- .../backend/wasm/render/CExpression.java | 15 + .../wasm/render/WasmCRenderingVisitor.java | 418 +++++++++++++++++- 2 files changed, 411 insertions(+), 22 deletions(-) diff --git a/core/src/main/java/org/teavm/backend/wasm/render/CExpression.java b/core/src/main/java/org/teavm/backend/wasm/render/CExpression.java index 5e1c2aaf6..abbdeff94 100644 --- a/core/src/main/java/org/teavm/backend/wasm/render/CExpression.java +++ b/core/src/main/java/org/teavm/backend/wasm/render/CExpression.java @@ -20,6 +20,7 @@ import java.util.List; class CExpression { private String text; + private boolean relocatable; private List lines = new ArrayList<>(); public CExpression(String text) { @@ -29,6 +30,14 @@ class CExpression { public CExpression() { } + public boolean isRelocatable() { + return relocatable; + } + + public void setRelocatable(boolean relocatable) { + this.relocatable = relocatable; + } + public String getText() { return text; } @@ -40,4 +49,10 @@ class CExpression { public List getLines() { return lines; } + + public static CExpression relocatable(String text) { + CExpression expression = new CExpression(text); + expression.setRelocatable(true); + return expression; + } } diff --git a/core/src/main/java/org/teavm/backend/wasm/render/WasmCRenderingVisitor.java b/core/src/main/java/org/teavm/backend/wasm/render/WasmCRenderingVisitor.java index ba277f5b5..99ed05543 100644 --- a/core/src/main/java/org/teavm/backend/wasm/render/WasmCRenderingVisitor.java +++ b/core/src/main/java/org/teavm/backend/wasm/render/WasmCRenderingVisitor.java @@ -16,9 +16,12 @@ package org.teavm.backend.wasm.render; import java.util.ArrayList; +import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; +import org.teavm.backend.wasm.model.WasmFunction; +import org.teavm.backend.wasm.model.WasmModule; import org.teavm.backend.wasm.model.WasmType; import org.teavm.backend.wasm.model.expression.WasmBlock; import org.teavm.backend.wasm.model.expression.WasmBranch; @@ -32,6 +35,7 @@ import org.teavm.backend.wasm.model.expression.WasmExpressionVisitor; import org.teavm.backend.wasm.model.expression.WasmFloat32Constant; import org.teavm.backend.wasm.model.expression.WasmFloat64Constant; import org.teavm.backend.wasm.model.expression.WasmFloatBinary; +import org.teavm.backend.wasm.model.expression.WasmFloatType; import org.teavm.backend.wasm.model.expression.WasmFloatUnary; import org.teavm.backend.wasm.model.expression.WasmGetLocal; import org.teavm.backend.wasm.model.expression.WasmIndirectCall; @@ -61,9 +65,11 @@ class WasmCRenderingVisitor implements WasmExpressionVisitor { private int temporaryIndex; private int blockIndex; private WasmType functionType; + private WasmModule module; - public WasmCRenderingVisitor(WasmType functionType) { + public WasmCRenderingVisitor(WasmType functionType, WasmModule module) { this.functionType = functionType; + this.module = module; } public CExpression getValue() { @@ -281,27 +287,27 @@ class WasmCRenderingVisitor implements WasmExpressionVisitor { @Override public void visit(WasmInt32Constant expression) { - value = new CExpression("INT32_C(" + String.valueOf(expression.getValue()) + ")"); + value = CExpression.relocatable("INT32_C(" + String.valueOf(expression.getValue()) + ")"); } @Override public void visit(WasmInt64Constant expression) { - value = new CExpression("INT64_C(" + String.valueOf(expression.getValue()) + ")"); + value = CExpression.relocatable("INT64_C(" + String.valueOf(expression.getValue()) + ")"); } @Override public void visit(WasmFloat32Constant expression) { - value = new CExpression(Float.toHexString(expression.getValue()) + "F"); + value = CExpression.relocatable(Float.toHexString(expression.getValue()) + "F"); } @Override public void visit(WasmFloat64Constant expression) { - value = new CExpression(Double.toHexString(expression.getValue())); + value = CExpression.relocatable(Double.toHexString(expression.getValue())); } @Override public void visit(WasmGetLocal expression) { - value = new CExpression("var_" + expression.getLocal().getIndex()); + value = CExpression.relocatable("var_" + expression.getLocal().getIndex()); } @Override @@ -334,19 +340,18 @@ class WasmCRenderingVisitor implements WasmExpressionVisitor { result.getLines().addAll(first.getLines()); result.getLines().addAll(second.getLines()); } else { - String firstOp = first.getText(); - String secondOp = second.getText(); result.getLines().addAll(first.getLines()); if (!second.getLines().isEmpty()) { - firstOp = "tmp_" + temporaryIndex; - result.getLines().add(new CSingleLine(mapType(opType) + " " + firstOp + " = " - + first.getText() + ";")); + first = cacheIfNeeded(opType, first, result); result.getLines().addAll(second.getLines()); } - String unsingedType = "u" + mapType(opType); - String firstOpUnsinged = "(" + unsingedType + ") " + firstOp; - String secondOpUnsigned = "(" + unsingedType + ") " + secondOp; + String firstOp = first.getText(); + String secondOp = second.getText(); + String typeText = mapType(opType); + String unsignedType = "u" + typeText; + String firstOpUnsigned = "(" + unsignedType + ") " + firstOp; + String secondOpUnsigned = "(" + unsignedType + ") " + secondOp; switch (expression.getOperation()) { case ADD: result.setText("(" + firstOp + " + " + secondOp + ")"); @@ -361,13 +366,13 @@ class WasmCRenderingVisitor implements WasmExpressionVisitor { result.setText("(" + firstOp + " / " + secondOp + ")"); break; case DIV_UNSIGNED: - result.setText("(" + firstOpUnsinged + " / " + secondOpUnsigned + ")"); + result.setText("(" + typeText + ") (" + firstOpUnsigned + " / " + secondOpUnsigned + ")"); break; case REM_SIGNED: result.setText("(" + firstOp + " % " + secondOp + ")"); break; case REM_UNSIGNED: - result.setText("(" + firstOpUnsinged + " % " + secondOpUnsigned + ")"); + result.setText("(" + typeText + ") (" + firstOpUnsigned + " % " + secondOpUnsigned + ")"); break; case AND: result.setText("(" + firstOp + " & " + secondOp + ")"); @@ -385,7 +390,7 @@ class WasmCRenderingVisitor implements WasmExpressionVisitor { result.setText("(" + firstOp + " >> " + secondOp + ")"); break; case SHR_UNSIGNED: - result.setText("(" + firstOpUnsinged + " >> " + secondOp + ")"); + result.setText("(" + typeText + ") (" + firstOpUnsigned + " >> " + secondOp + ")"); break; case EQ: result.setText("(" + firstOp + " == " + secondOp + ")"); @@ -397,25 +402,25 @@ class WasmCRenderingVisitor implements WasmExpressionVisitor { result.setText("(" + firstOp + " > " + secondOp + ")"); break; case GT_UNSIGNED: - result.setText("(" + firstOpUnsinged + " > " + secondOpUnsigned + ")"); + result.setText("(" + firstOpUnsigned + " > " + secondOpUnsigned + ")"); break; case GE_SIGNED: result.setText("(" + firstOp + " >= " + secondOp + ")"); break; case GE_UNSIGNED: - result.setText("(" + firstOpUnsinged + " >= " + secondOpUnsigned + ")"); + result.setText("(" + typeText + ") (" + firstOpUnsigned + " >= " + secondOpUnsigned + ")"); break; case LT_SIGNED: result.setText("(" + firstOp + " < " + secondOp + ")"); break; case LT_UNSIGNED: - result.setText("(" + firstOpUnsinged + " < " + secondOpUnsigned + ")"); + result.setText("(" + typeText + ") (" + firstOpUnsigned + " < " + secondOpUnsigned + ")"); break; case LE_SIGNED: result.setText("(" + firstOp + " <= " + secondOp + ")"); break; case LE_UNSIGNED: - result.setText("(" + firstOpUnsinged + " <= " + secondOpUnsigned + ")"); + result.setText("(" + typeText + ") (" + firstOpUnsigned + " <= " + secondOpUnsigned + ")"); break; case ROTL: result.setText("rotl(" + firstOp + ", " + secondOp + ")"); @@ -424,6 +429,7 @@ class WasmCRenderingVisitor implements WasmExpressionVisitor { result.setText("rotr(" + firstOp + ", " + secondOp + ")"); break; } + result.setRelocatable(first.isRelocatable() && second.isRelocatable()); } value = result; @@ -431,34 +437,237 @@ class WasmCRenderingVisitor implements WasmExpressionVisitor { @Override public void visit(WasmFloatBinary expression) { + WasmType type = requiredType; + WasmType opType = asWasmType(expression.getType()); + CExpression result = new CExpression(); + requiredType = opType; + expression.getFirst().acceptVisitor(this); + CExpression first = value; + + requiredType = opType; + expression.getSecond().acceptVisitor(this); + CExpression second = value; + + if (type == null) { + result.getLines().addAll(first.getLines()); + result.getLines().addAll(second.getLines()); + } else { + result.getLines().addAll(first.getLines()); + if (!second.getLines().isEmpty()) { + first = cacheIfNeeded(opType, first, result); + result.getLines().addAll(second.getLines()); + } + + String firstOp = first.getText(); + String secondOp = second.getText(); + + switch (expression.getOperation()) { + case ADD: + result.setText("(" + firstOp + " + " + secondOp + ")"); + break; + case SUB: + result.setText("(" + firstOp + " - " + secondOp + ")"); + break; + case MUL: + result.setText("(" + firstOp + " * " + secondOp + ")"); + break; + case DIV: + result.setText("(" + firstOp + " / " + secondOp + ")"); + break; + case EQ: + result.setText("(" + firstOp + " == " + secondOp + ")"); + break; + case NE: + result.setText("(" + firstOp + " != " + secondOp + ")"); + break; + case GT: + result.setText("(" + firstOp + " > " + secondOp + ")"); + break; + case GE: + result.setText("(" + firstOp + " >= " + secondOp + ")"); + break; + case LT: + result.setText("(" + firstOp + " < " + secondOp + ")"); + break; + case LE: + result.setText("(" + firstOp + " <= " + secondOp + ")"); + break; + case MIN: { + String function = expression.getType() == WasmFloatType.FLOAT32 ? "fminf" : "fmin"; + result.setText(function + "(" + firstOp + ", " + secondOp + ")"); + break; + } + case MAX: { + String function = expression.getType() == WasmFloatType.FLOAT32 ? "fmaxf" : "fmax"; + result.setText(function + "(" + firstOp + ", " + secondOp + ")"); + break; + } + } + result.setRelocatable(first.isRelocatable() && second.isRelocatable()); + } + + value = result; + } + + private CExpression cacheIfNeeded(WasmType type, CExpression expression, CExpression target) { + if (expression.isRelocatable()) { + return expression; + } + String var = "tmp_" + temporaryIndex; + target.getLines().add(new CSingleLine(mapType(type) + " " + var + " = " + expression.getText() + ";")); + return CExpression.relocatable(var); } @Override public void visit(WasmIntUnary expression) { + WasmType type = requiredType; + WasmType opType = asWasmType(expression.getType()); + CExpression result = new CExpression(); + requiredType = opType; + expression.getOperand().acceptVisitor(this); + CExpression operand = value; + + result.getLines().addAll(operand.getLines()); + if (type != null) { + switch (expression.getOperation()) { + case POPCNT: + result.setText("popcnt(" + operand.getText() + ")"); + break; + case CLZ: + result.setText("clz(" + operand.getText() + ")"); + break; + case CTZ: + result.setText("ctz(" + operand.getText() + ")"); + break; + } + result.setRelocatable(operand.isRelocatable()); + } + + value = result; } @Override public void visit(WasmFloatUnary expression) { + WasmType type = requiredType; + WasmType opType = asWasmType(expression.getType()); + CExpression result = new CExpression(); + requiredType = opType; + expression.getOperand().acceptVisitor(this); + CExpression operand = value; + + result.getLines().addAll(operand.getLines()); + if (type != null) { + switch (expression.getOperation()) { + case ABS: { + String functionName = expression.getType() == WasmFloatType.FLOAT32 ? "fabsf" : "fabs"; + result.setText(functionName + "(" + operand.getText() + ")"); + break; + } + case CEIL: { + String functionName = expression.getType() == WasmFloatType.FLOAT32 ? "ceilf" : "ceil"; + result.setText(functionName + "(" + operand.getText() + ")"); + break; + } + case FLOOR: { + String functionName = expression.getType() == WasmFloatType.FLOAT32 ? "floorf" : "floor"; + result.setText(functionName + "(" + operand.getText() + ")"); + break; + } + case TRUNC: { + String functionName = expression.getType() == WasmFloatType.FLOAT32 ? "truncf" : "trunc"; + result.setText(functionName + "(" + operand.getText() + ")"); + break; + } + case NEAREST: { + String functionName = expression.getType() == WasmFloatType.FLOAT32 ? "roundf" : "round"; + result.setText(functionName + "(" + operand.getText() + ")"); + break; + } + case SQRT: { + String functionName = expression.getType() == WasmFloatType.FLOAT32 ? "sqrtf" : "sqrt"; + result.setText(functionName + "(" + operand.getText() + ")"); + break; + } + case NEG: + result.setText("(-" + operand.getText() + ")"); + break; + case COPYSIGN: { + String functionName = expression.getType() == WasmFloatType.FLOAT32 ? "copysignf" : "copysign"; + result.setText(functionName + "(" + operand.getText() + ")"); + break; + } + } + result.setRelocatable(operand.isRelocatable()); + } + + value = result; } @Override public void visit(WasmConversion expression) { + CExpression result = new CExpression(); + WasmType type = requiredType; + expression.getOperand().acceptVisitor(this); + CExpression operand = value; + + result.getLines().addAll(operand.getLines()); + if (type != null && expression.getSourceType() != expression.getSourceType()) { + switch (expression.getTargetType()) { + case INT32: + if (expression.isSigned()) { + result.setText("(int32_t) " + operand.getText()); + } else { + result.setText("(uint32_t) " + operand.getText()); + } + break; + case INT64: + if (expression.isSigned()) { + result.setText("(int64_t) " + operand.getText()); + } else { + result.setText("(uint64_t) " + operand.getText()); + } + break; + case FLOAT32: + if (expression.getSourceType() == WasmType.FLOAT64) { + result.setText("(float) " + operand.getText()); + } else if (expression.isSigned()) { + result.setText("(float) (int64_t) " + operand.getText()); + } else { + result.setText("(float) (uint64_t) " + operand.getText()); + } + break; + case FLOAT64: + if (expression.getSourceType() == WasmType.FLOAT32) { + result.setText("(double) " + operand.getText()); + } else if (expression.isSigned()) { + result.setText("(double) (int64_t) " + operand.getText()); + } else { + result.setText("(double) (uint64_t) " + operand.getText()); + } + break; + } + } + + value = operand; } @Override public void visit(WasmCall expression) { CExpression result = new CExpression(); + WasmType type = requiredType; StringBuilder sb = new StringBuilder(); sb.append(expression.getFunctionName()).append('('); + WasmFunction function = module.getFunctions().get(expression.getFunctionName()); + translateArguments(expression.getArguments(), function.getParameters(), result, sb); sb.append(')'); result.setText(sb.toString()); - if (requiredType == null) { + if (type == null) { reportLocation(expression.getLocation(), result.getLines()); result.getLines().add(new CSingleLine(result.getText())); result.setText(null); @@ -468,7 +677,65 @@ class WasmCRenderingVisitor implements WasmExpressionVisitor { @Override public void visit(WasmIndirectCall expression) { + CExpression result = new CExpression(); + WasmType type = requiredType; + StringBuilder sb = new StringBuilder(); + sb.append("(*(" + mapType(expression.getReturnType()) + " (*)("); + for (int i = 0; i < expression.getParameterTypes().size(); ++i) { + if (i > 0) { + sb.append(", "); + } + sb.append(mapType(expression.getParameterTypes().get(i))); + } + sb.append(") "); + + requiredType = WasmType.INT32; + expression.getSelector().acceptVisitor(this); + value = cacheIfNeeded(WasmType.INT32, value, result); + result.getLines().addAll(value.getLines()); + sb.append("wasm_table[" + result.getText() + "])("); + translateArguments(expression.getArguments(), expression.getParameterTypes(), result, sb); + sb.append(")"); + + if (type == null) { + reportLocation(expression.getLocation(), result.getLines()); + result.getLines().add(new CSingleLine(result.getText())); + result.setText(null); + } + value = result; + } + + private void translateArguments(List wasmArguments, List signature, + CExpression result, StringBuilder sb) { + if (wasmArguments.isEmpty()) { + return; + } + List arguments = new ArrayList<>(); + int needsCachingUntil = 0; + for (int i = wasmArguments.size() - 1; i >= 0; --i) { + requiredType = signature.get(i); + wasmArguments.get(i).acceptVisitor(this); + arguments.add(value); + if (!result.getLines().isEmpty() && needsCachingUntil == 0) { + needsCachingUntil = i; + } + } + Collections.reverse(arguments); + + for (int i = 0; i < arguments.size(); ++i) { + CExpression argument = arguments.get(i); + result.getLines().addAll(argument.getLines()); + if (i < needsCachingUntil) { + argument = cacheIfNeeded(signature.get(i), argument, result); + } + if (i > 0) { + sb.append(", "); + } + sb.append(argument.getText()); + } + + value = result; } @Override @@ -483,22 +750,116 @@ class WasmCRenderingVisitor implements WasmExpressionVisitor { @Override public void visit(WasmLoadInt32 expression) { + CExpression result = new CExpression(); + WasmType type = requiredType; + requiredType = WasmType.INT32; + expression.getIndex(); + CExpression index = value; + if (type == null) { + value = index; + return; + } + + result.getLines().addAll(index.getLines()); + switch (expression.getConvertFrom()) { + case INT8: + result.setText("(int32_t) (int8_t) wasm_heap[" + index + "]"); + break; + case UINT8: + result.setText("(int32_t) (uint8_t) wasm_heap[" + index + "]"); + break; + case INT16: + result.setText("(int32_t) *((int16_t *) &wasm_heap[" + index + "])"); + break; + case UINT16: + result.setText("(int32_t) *((uint16_t *) &wasm_heap[" + index + "])"); + break; + case INT32: + result.setText("*((int32_t *) &wasm_heap[" + index + "])"); + break; + } + + value = result; } @Override public void visit(WasmLoadInt64 expression) { + CExpression result = new CExpression(); + WasmType type = requiredType; + requiredType = WasmType.INT32; + expression.getIndex(); + CExpression index = value; + if (type == null) { + value = index; + return; + } + + result.getLines().addAll(index.getLines()); + switch (expression.getConvertFrom()) { + case INT8: + result.setText("(int64_t) (int8_t) wasm_heap[" + index + "]"); + break; + case UINT8: + result.setText("(int64_t) (uint8_t) wasm_heap[" + index + "]"); + break; + case INT16: + result.setText("(int64_t) *((int16_t *) &wasm_heap[" + index + "])"); + break; + case UINT16: + result.setText("(int64_t) *((uint16_t *) &wasm_heap[" + index + "])"); + break; + case INT32: + result.setText("(int64_t) *((int32_t *) &wasm_heap[" + index + "])"); + break; + case UINT32: + result.setText("(int64_t) *((uint32_t *) &wasm_heap[" + index + "])"); + break; + case INT64: + result.setText("*((int64_t *) &wasm_heap[" + index + "])"); + break; + } + + value = result; } @Override public void visit(WasmLoadFloat32 expression) { + CExpression result = new CExpression(); + WasmType type = requiredType; + requiredType = WasmType.INT32; + expression.getIndex(); + CExpression index = value; + if (type == null) { + value = index; + return; + } + + result.getLines().addAll(index.getLines()); + result.setText("*((float *) &wasm_heap[" + index + "])"); + + value = result; } @Override public void visit(WasmLoadFloat64 expression) { + CExpression result = new CExpression(); + WasmType type = requiredType; + requiredType = WasmType.INT32; + expression.getIndex(); + CExpression index = value; + if (type == null) { + value = index; + return; + } + + result.getLines().addAll(index.getLines()); + result.setText("*((double *) &wasm_heap[" + index + "])"); + + value = result; } @Override @@ -532,6 +893,9 @@ class WasmCRenderingVisitor implements WasmExpressionVisitor { } private static String mapType(WasmType type) { + if (type == null) { + return "void"; + } switch (type) { case INT32: return "int32_t"; @@ -555,6 +919,16 @@ class WasmCRenderingVisitor implements WasmExpressionVisitor { throw new AssertionError(type.toString()); } + private static WasmType asWasmType(WasmFloatType type) { + switch (type) { + case FLOAT32: + return WasmType.FLOAT32; + case FLOAT64: + return WasmType.FLOAT64; + } + throw new AssertionError(type.toString()); + } + static class BlockInfo { int index; String label;