diff --git a/core/src/main/java/org/teavm/backend/wasm/generate/WasmGenerationVisitor.java b/core/src/main/java/org/teavm/backend/wasm/generate/WasmGenerationVisitor.java index 8908025af..c76c6f934 100644 --- a/core/src/main/java/org/teavm/backend/wasm/generate/WasmGenerationVisitor.java +++ b/core/src/main/java/org/teavm/backend/wasm/generate/WasmGenerationVisitor.java @@ -121,6 +121,7 @@ import org.teavm.runtime.ShadowStack; class WasmGenerationVisitor implements StatementVisitor, ExprVisitor { private static FieldReference tagField = new FieldReference(RuntimeClass.class.getName(), "tag"); + private static final int SWITCH_TABLE_THRESHOLD = 256; private WasmGenerationContext context; private WasmClassGenerator classGenerator; private WasmTypeInference typeInference; @@ -700,8 +701,6 @@ class WasmGenerationVisitor implements StatementVisitor, ExprVisitor { @Override public void visit(SwitchStatement statement) { - WasmBlock defaultBlock = new WasmBlock(false); - int min = statement.getClauses().stream() .flatMapToInt(clause -> Arrays.stream(clause.getConditions())) .min().orElse(0); @@ -709,28 +708,23 @@ class WasmGenerationVisitor implements StatementVisitor, ExprVisitor { .flatMapToInt(clause -> Arrays.stream(clause.getConditions())) .max().orElse(0); + WasmBlock defaultBlock = new WasmBlock(false); breakTargets.put(statement, defaultBlock); IdentifiedStatement oldBreakTarget = currentBreakTarget; currentBreakTarget = statement; WasmBlock wrapper = new WasmBlock(false); accept(statement.getValue()); - if (min > 0) { - result = new WasmIntBinary(WasmIntType.INT32, WasmIntBinaryOperation.SUB, result, - new WasmInt32Constant(min)); - } + WasmExpression condition = result; + WasmBlock initialWrapper = wrapper; - WasmSwitch wasmSwitch = new WasmSwitch(result, wrapper); - wrapper.getBody().add(wasmSwitch); - WasmBlock[] targets = new WasmBlock[max - min + 1]; - - for (SwitchClause clause : statement.getClauses()) { + List clauses = statement.getClauses(); + WasmBlock[] targets = new WasmBlock[clauses.size()]; + for (int i = 0; i < clauses.size(); i++) { + SwitchClause clause = clauses.get(i); WasmBlock caseBlock = new WasmBlock(false); caseBlock.getBody().add(wrapper); - - for (int condition : clause.getConditions()) { - targets[condition - min] = wrapper; - } + targets[i] = wrapper; for (Statement part : clause.getBody()) { accept(part); @@ -748,11 +742,13 @@ class WasmGenerationVisitor implements StatementVisitor, ExprVisitor { defaultBlock.getBody().add(result); } } - wasmSwitch.setDefaultTarget(wrapper); + WasmBlock defaultTarget = wrapper; wrapper = defaultBlock; - for (WasmBlock target : targets) { - wasmSwitch.getTargets().add(target != null ? target : wasmSwitch.getDefaultTarget()); + if (max - min >= SWITCH_TABLE_THRESHOLD) { + translateSwitchToBinarySearch(statement, condition, initialWrapper, defaultTarget, targets); + } else { + translateSwitchToWasmSwitch(statement, condition, initialWrapper, defaultTarget, targets, min, max); } breakTargets.remove(statement); @@ -761,6 +757,83 @@ class WasmGenerationVisitor implements StatementVisitor, ExprVisitor { result = wrapper; } + private void translateSwitchToBinarySearch(SwitchStatement statement, WasmExpression condition, + WasmBlock initialWrapper, WasmBlock defaultTarget, WasmBlock[] targets) { + List entries = new ArrayList<>(); + for (int i = 0; i < statement.getClauses().size(); i++) { + SwitchClause clause = statement.getClauses().get(i); + for (int label : clause.getConditions()) { + entries.add(new TableEntry(label, targets[i])); + } + } + entries.sort(Comparator.comparingInt(entry -> entry.label)); + + WasmLocal conditionVar = getTemporary(WasmType.INT32); + initialWrapper.getBody().add(new WasmSetLocal(conditionVar, condition)); + + generateBinarySearch(entries, 0, entries.size() - 1, initialWrapper, defaultTarget, conditionVar); + } + + private void generateBinarySearch(List entries, int lower, int upper, WasmBlock consumer, + WasmBlock defaultTarget, WasmLocal conditionVar) { + if (upper - lower == 0) { + int label = entries.get(lower).label; + WasmExpression condition = new WasmIntBinary(WasmIntType.INT32, WasmIntBinaryOperation.EQ, + new WasmGetLocal(conditionVar), new WasmInt32Constant(label)); + WasmConditional conditional = new WasmConditional(condition); + consumer.getBody().add(conditional); + + conditional.getThenBlock().getBody().add(new WasmBreak(entries.get(lower).target)); + conditional.getElseBlock().getBody().add(new WasmBreak(defaultTarget)); + } else if (upper - lower <= 0) { + consumer.getBody().add(new WasmBreak(defaultTarget)); + } else { + int mid = (upper + lower) / 2; + int label = entries.get(mid).label; + WasmExpression condition = new WasmIntBinary(WasmIntType.INT32, WasmIntBinaryOperation.GT_UNSIGNED, + new WasmGetLocal(conditionVar), new WasmInt32Constant(label)); + WasmConditional conditional = new WasmConditional(condition); + consumer.getBody().add(conditional); + + generateBinarySearch(entries, mid + 1, upper, conditional.getThenBlock(), defaultTarget, conditionVar); + generateBinarySearch(entries, lower, mid, conditional.getElseBlock(), defaultTarget, conditionVar); + } + } + + static class TableEntry { + final int label; + final WasmBlock target; + + public TableEntry(int label, WasmBlock target) { + this.label = label; + this.target = target; + } + } + + private void translateSwitchToWasmSwitch(SwitchStatement statement, WasmExpression condition, + WasmBlock initialWrapper, WasmBlock defaultTarget, WasmBlock[] targets, int min, int max) { + if (min > 0) { + condition = new WasmIntBinary(WasmIntType.INT32, WasmIntBinaryOperation.SUB, condition, + new WasmInt32Constant(min)); + } + + WasmSwitch wasmSwitch = new WasmSwitch(condition, initialWrapper); + initialWrapper.getBody().add(wasmSwitch); + wasmSwitch.setDefaultTarget(defaultTarget); + + WasmBlock[] expandedTargets = new WasmBlock[max - min + 1]; + for (int i = 0; i < statement.getClauses().size(); i++) { + SwitchClause clause = statement.getClauses().get(i); + for (int label : clause.getConditions()) { + expandedTargets[label - min] = targets[i]; + } + } + + for (WasmBlock target : expandedTargets) { + wasmSwitch.getTargets().add(target != null ? target : wasmSwitch.getDefaultTarget()); + } + } + @Override public void visit(UnwrapArrayExpr expr) { accept(expr.getArray());