WASM: when lookupswitch table has labels that diff by more 256,

translate it to binary search code.
See #261
This commit is contained in:
Alexey Andreev 2017-04-09 23:13:26 +03:00
parent ef1618ec36
commit 248a49c7dd

View File

@ -121,6 +121,7 @@ import org.teavm.runtime.ShadowStack;
class WasmGenerationVisitor implements StatementVisitor, ExprVisitor {
private static FieldReference tagField = new FieldReference(RuntimeClass.class.getName(), "tag");
private static final int SWITCH_TABLE_THRESHOLD = 256;
private WasmGenerationContext context;
private WasmClassGenerator classGenerator;
private WasmTypeInference typeInference;
@ -700,8 +701,6 @@ class WasmGenerationVisitor implements StatementVisitor, ExprVisitor {
@Override
public void visit(SwitchStatement statement) {
WasmBlock defaultBlock = new WasmBlock(false);
int min = statement.getClauses().stream()
.flatMapToInt(clause -> Arrays.stream(clause.getConditions()))
.min().orElse(0);
@ -709,28 +708,23 @@ class WasmGenerationVisitor implements StatementVisitor, ExprVisitor {
.flatMapToInt(clause -> Arrays.stream(clause.getConditions()))
.max().orElse(0);
WasmBlock defaultBlock = new WasmBlock(false);
breakTargets.put(statement, defaultBlock);
IdentifiedStatement oldBreakTarget = currentBreakTarget;
currentBreakTarget = statement;
WasmBlock wrapper = new WasmBlock(false);
accept(statement.getValue());
if (min > 0) {
result = new WasmIntBinary(WasmIntType.INT32, WasmIntBinaryOperation.SUB, result,
new WasmInt32Constant(min));
}
WasmExpression condition = result;
WasmBlock initialWrapper = wrapper;
WasmSwitch wasmSwitch = new WasmSwitch(result, wrapper);
wrapper.getBody().add(wasmSwitch);
WasmBlock[] targets = new WasmBlock[max - min + 1];
for (SwitchClause clause : statement.getClauses()) {
List<SwitchClause> clauses = statement.getClauses();
WasmBlock[] targets = new WasmBlock[clauses.size()];
for (int i = 0; i < clauses.size(); i++) {
SwitchClause clause = clauses.get(i);
WasmBlock caseBlock = new WasmBlock(false);
caseBlock.getBody().add(wrapper);
for (int condition : clause.getConditions()) {
targets[condition - min] = wrapper;
}
targets[i] = wrapper;
for (Statement part : clause.getBody()) {
accept(part);
@ -748,11 +742,13 @@ class WasmGenerationVisitor implements StatementVisitor, ExprVisitor {
defaultBlock.getBody().add(result);
}
}
wasmSwitch.setDefaultTarget(wrapper);
WasmBlock defaultTarget = wrapper;
wrapper = defaultBlock;
for (WasmBlock target : targets) {
wasmSwitch.getTargets().add(target != null ? target : wasmSwitch.getDefaultTarget());
if (max - min >= SWITCH_TABLE_THRESHOLD) {
translateSwitchToBinarySearch(statement, condition, initialWrapper, defaultTarget, targets);
} else {
translateSwitchToWasmSwitch(statement, condition, initialWrapper, defaultTarget, targets, min, max);
}
breakTargets.remove(statement);
@ -761,6 +757,83 @@ class WasmGenerationVisitor implements StatementVisitor, ExprVisitor {
result = wrapper;
}
private void translateSwitchToBinarySearch(SwitchStatement statement, WasmExpression condition,
WasmBlock initialWrapper, WasmBlock defaultTarget, WasmBlock[] targets) {
List<TableEntry> entries = new ArrayList<>();
for (int i = 0; i < statement.getClauses().size(); i++) {
SwitchClause clause = statement.getClauses().get(i);
for (int label : clause.getConditions()) {
entries.add(new TableEntry(label, targets[i]));
}
}
entries.sort(Comparator.comparingInt(entry -> entry.label));
WasmLocal conditionVar = getTemporary(WasmType.INT32);
initialWrapper.getBody().add(new WasmSetLocal(conditionVar, condition));
generateBinarySearch(entries, 0, entries.size() - 1, initialWrapper, defaultTarget, conditionVar);
}
private void generateBinarySearch(List<TableEntry> entries, int lower, int upper, WasmBlock consumer,
WasmBlock defaultTarget, WasmLocal conditionVar) {
if (upper - lower == 0) {
int label = entries.get(lower).label;
WasmExpression condition = new WasmIntBinary(WasmIntType.INT32, WasmIntBinaryOperation.EQ,
new WasmGetLocal(conditionVar), new WasmInt32Constant(label));
WasmConditional conditional = new WasmConditional(condition);
consumer.getBody().add(conditional);
conditional.getThenBlock().getBody().add(new WasmBreak(entries.get(lower).target));
conditional.getElseBlock().getBody().add(new WasmBreak(defaultTarget));
} else if (upper - lower <= 0) {
consumer.getBody().add(new WasmBreak(defaultTarget));
} else {
int mid = (upper + lower) / 2;
int label = entries.get(mid).label;
WasmExpression condition = new WasmIntBinary(WasmIntType.INT32, WasmIntBinaryOperation.GT_UNSIGNED,
new WasmGetLocal(conditionVar), new WasmInt32Constant(label));
WasmConditional conditional = new WasmConditional(condition);
consumer.getBody().add(conditional);
generateBinarySearch(entries, mid + 1, upper, conditional.getThenBlock(), defaultTarget, conditionVar);
generateBinarySearch(entries, lower, mid, conditional.getElseBlock(), defaultTarget, conditionVar);
}
}
static class TableEntry {
final int label;
final WasmBlock target;
public TableEntry(int label, WasmBlock target) {
this.label = label;
this.target = target;
}
}
private void translateSwitchToWasmSwitch(SwitchStatement statement, WasmExpression condition,
WasmBlock initialWrapper, WasmBlock defaultTarget, WasmBlock[] targets, int min, int max) {
if (min > 0) {
condition = new WasmIntBinary(WasmIntType.INT32, WasmIntBinaryOperation.SUB, condition,
new WasmInt32Constant(min));
}
WasmSwitch wasmSwitch = new WasmSwitch(condition, initialWrapper);
initialWrapper.getBody().add(wasmSwitch);
wasmSwitch.setDefaultTarget(defaultTarget);
WasmBlock[] expandedTargets = new WasmBlock[max - min + 1];
for (int i = 0; i < statement.getClauses().size(); i++) {
SwitchClause clause = statement.getClauses().get(i);
for (int label : clause.getConditions()) {
expandedTargets[label - min] = targets[i];
}
}
for (WasmBlock target : expandedTargets) {
wasmSwitch.getTargets().add(target != null ? target : wasmSwitch.getDefaultTarget());
}
}
@Override
public void visit(UnwrapArrayExpr expr) {
accept(expr.getArray());