From f11a5474d0b25646aa81ddb44563816e71baebec Mon Sep 17 00:00:00 2001 From: Alexey Andreev Date: Sat, 30 Nov 2024 18:50:43 +0100 Subject: [PATCH] wasm: use non-trapping conversion instructions when casting floats and doubles to ints and longs Fix #976 --- .../wasm/disasm/DisassemblyCodeListener.java | 35 +++++---- .../methods/BaseWasmGenerationVisitor.java | 6 +- .../wasm/model/expression/WasmConversion.java | 9 +++ .../backend/wasm/parser/CodeListener.java | 3 +- .../teavm/backend/wasm/parser/CodeParser.java | 73 +++++++++++++------ .../render/WasmBinaryRenderingVisitor.java | 20 ++++- .../org/teavm/vm/NumericConversionTest.java | 66 +++++++++++++++++ 7 files changed, 168 insertions(+), 44 deletions(-) create mode 100644 tests/src/test/java/org/teavm/vm/NumericConversionTest.java diff --git a/core/src/main/java/org/teavm/backend/wasm/disasm/DisassemblyCodeListener.java b/core/src/main/java/org/teavm/backend/wasm/disasm/DisassemblyCodeListener.java index 11c1785ce..53222e16d 100644 --- a/core/src/main/java/org/teavm/backend/wasm/disasm/DisassemblyCodeListener.java +++ b/core/src/main/java/org/teavm/backend/wasm/disasm/DisassemblyCodeListener.java @@ -641,7 +641,8 @@ public class DisassemblyCodeListener extends BaseDisassemblyListener implements } @Override - public void convert(WasmNumType sourceType, WasmNumType targetType, boolean signed, boolean reinterpret) { + public void convert(WasmNumType sourceType, WasmNumType targetType, boolean signed, boolean reinterpret, + boolean nonTrapping) { switch (targetType) { case INT32: writer.write("i32."); @@ -649,18 +650,20 @@ public class DisassemblyCodeListener extends BaseDisassemblyListener implements case FLOAT32: if (reinterpret) { writer.write("reinterpret_f32"); - } else if (signed) { - writer.write("trunc_f32_s"); } else { - writer.write("trunc_f32_u"); + writer.write("trunc_"); + if (nonTrapping) { + writer.write("sat_"); + } + writer.write("f32_").write(signed ? "s" : "u"); } break; case FLOAT64: - if (signed) { - writer.write("trunc_f64_s"); - } else { - writer.write("trunc_f64_u"); + writer.write("trunc_"); + if (nonTrapping) { + writer.write("sat_"); } + writer.write("f64_").write(signed ? "s" : "u"); break; case INT64: writer.write("wrap_i64"); @@ -674,19 +677,21 @@ public class DisassemblyCodeListener extends BaseDisassemblyListener implements writer.write("i64."); switch (sourceType) { case FLOAT32: - if (signed) { - writer.write("trunc_f32_s"); - } else { - writer.write("trunc_f32_u"); + writer.write("trunc_"); + if (nonTrapping) { + writer.write("sat_"); } + writer.write("f32_").write(signed ? "s" : "u"); break; case FLOAT64: if (reinterpret) { writer.write("reinterpret_f64"); - } else if (signed) { - writer.write("trunc_f64_s"); } else { - writer.write("trunc_f64_u"); + writer.write("trunc_"); + if (nonTrapping) { + writer.write("sat_"); + } + writer.write("f64_").write(signed ? "s" : "u"); } break; case INT32: diff --git a/core/src/main/java/org/teavm/backend/wasm/generate/common/methods/BaseWasmGenerationVisitor.java b/core/src/main/java/org/teavm/backend/wasm/generate/common/methods/BaseWasmGenerationVisitor.java index 07e7e920d..d9df74a38 100644 --- a/core/src/main/java/org/teavm/backend/wasm/generate/common/methods/BaseWasmGenerationVisitor.java +++ b/core/src/main/java/org/teavm/backend/wasm/generate/common/methods/BaseWasmGenerationVisitor.java @@ -1172,9 +1172,11 @@ public abstract class BaseWasmGenerationVisitor implements StatementVisitor, Exp @Override public void visit(PrimitiveCastExpr expr) { accept(expr.getValue()); - result = new WasmConversion(WasmGeneratorUtil.mapType(expr.getSource()), + var conversion = new WasmConversion(WasmGeneratorUtil.mapType(expr.getSource()), WasmGeneratorUtil.mapType(expr.getTarget()), true, result); - result.setLocation(expr.getLocation()); + conversion.setNonTrapping(true); + conversion.setLocation(expr.getLocation()); + result = conversion; } @Override diff --git a/core/src/main/java/org/teavm/backend/wasm/model/expression/WasmConversion.java b/core/src/main/java/org/teavm/backend/wasm/model/expression/WasmConversion.java index 949a8bf5e..d387adb13 100644 --- a/core/src/main/java/org/teavm/backend/wasm/model/expression/WasmConversion.java +++ b/core/src/main/java/org/teavm/backend/wasm/model/expression/WasmConversion.java @@ -24,6 +24,7 @@ public class WasmConversion extends WasmExpression { private boolean signed; private WasmExpression operand; private boolean reinterpret; + private boolean nonTrapping; public WasmConversion(WasmNumType sourceType, WasmNumType targetType, boolean signed, WasmExpression operand) { Objects.requireNonNull(sourceType); @@ -78,6 +79,14 @@ public class WasmConversion extends WasmExpression { this.operand = operand; } + public boolean isNonTrapping() { + return nonTrapping; + } + + public void setNonTrapping(boolean nonTrapping) { + this.nonTrapping = nonTrapping; + } + @Override public void acceptVisitor(WasmExpressionVisitor visitor) { visitor.visit(this); diff --git a/core/src/main/java/org/teavm/backend/wasm/parser/CodeListener.java b/core/src/main/java/org/teavm/backend/wasm/parser/CodeListener.java index 7876c43fc..91de65675 100644 --- a/core/src/main/java/org/teavm/backend/wasm/parser/CodeListener.java +++ b/core/src/main/java/org/teavm/backend/wasm/parser/CodeListener.java @@ -112,7 +112,8 @@ public interface CodeListener { default void storeFloat64(int align, int offset) { } - default void convert(WasmNumType sourceType, WasmNumType targetType, boolean signed, boolean reinterpret) { + default void convert(WasmNumType sourceType, WasmNumType targetType, boolean signed, boolean reinterpret, + boolean nonTrapping) { } default void memoryGrow() { diff --git a/core/src/main/java/org/teavm/backend/wasm/parser/CodeParser.java b/core/src/main/java/org/teavm/backend/wasm/parser/CodeParser.java index ec9766862..d95fda18b 100644 --- a/core/src/main/java/org/teavm/backend/wasm/parser/CodeParser.java +++ b/core/src/main/java/org/teavm/backend/wasm/parser/CodeParser.java @@ -516,76 +516,76 @@ public class CodeParser extends BaseSectionParser { break; case 0xA7: - codeListener.convert(WasmNumType.INT64, WasmNumType.INT32, false, false); + codeListener.convert(WasmNumType.INT64, WasmNumType.INT32, false, false, false); break; case 0xA8: - codeListener.convert(WasmNumType.FLOAT32, WasmNumType.INT32, false, false); + codeListener.convert(WasmNumType.FLOAT32, WasmNumType.INT32, false, false, false); break; case 0xA9: - codeListener.convert(WasmNumType.FLOAT32, WasmNumType.INT32, true, false); + codeListener.convert(WasmNumType.FLOAT32, WasmNumType.INT32, true, false, false); break; case 0xAA: - codeListener.convert(WasmNumType.FLOAT64, WasmNumType.INT32, false, false); + codeListener.convert(WasmNumType.FLOAT64, WasmNumType.INT32, false, false, false); break; case 0xAB: - codeListener.convert(WasmNumType.FLOAT64, WasmNumType.INT32, true, false); + codeListener.convert(WasmNumType.FLOAT64, WasmNumType.INT32, true, false, false); break; case 0xAC: - codeListener.convert(WasmNumType.INT32, WasmNumType.INT64, false, false); + codeListener.convert(WasmNumType.INT32, WasmNumType.INT64, false, false, false); break; case 0xAD: - codeListener.convert(WasmNumType.INT32, WasmNumType.INT64, true, false); + codeListener.convert(WasmNumType.INT32, WasmNumType.INT64, true, false, false); break; case 0xAE: - codeListener.convert(WasmNumType.FLOAT32, WasmNumType.INT64, false, false); + codeListener.convert(WasmNumType.FLOAT32, WasmNumType.INT64, false, false, false); break; case 0xAF: - codeListener.convert(WasmNumType.FLOAT32, WasmNumType.INT64, true, false); + codeListener.convert(WasmNumType.FLOAT32, WasmNumType.INT64, true, false, false); break; case 0xB0: - codeListener.convert(WasmNumType.FLOAT64, WasmNumType.INT64, false, false); + codeListener.convert(WasmNumType.FLOAT64, WasmNumType.INT64, false, false, false); break; case 0xB1: - codeListener.convert(WasmNumType.FLOAT64, WasmNumType.INT64, true, false); + codeListener.convert(WasmNumType.FLOAT64, WasmNumType.INT64, true, false, false); break; case 0xB2: - codeListener.convert(WasmNumType.INT32, WasmNumType.FLOAT32, false, false); + codeListener.convert(WasmNumType.INT32, WasmNumType.FLOAT32, false, false, false); break; case 0xB3: - codeListener.convert(WasmNumType.INT32, WasmNumType.FLOAT32, true, false); + codeListener.convert(WasmNumType.INT32, WasmNumType.FLOAT32, true, false, false); break; case 0xB4: - codeListener.convert(WasmNumType.INT64, WasmNumType.FLOAT32, false, false); + codeListener.convert(WasmNumType.INT64, WasmNumType.FLOAT32, false, false, false); break; case 0xB5: - codeListener.convert(WasmNumType.INT64, WasmNumType.FLOAT32, true, false); + codeListener.convert(WasmNumType.INT64, WasmNumType.FLOAT32, true, false, false); break; case 0xB6: - codeListener.convert(WasmNumType.FLOAT64, WasmNumType.FLOAT32, true, false); + codeListener.convert(WasmNumType.FLOAT64, WasmNumType.FLOAT32, true, false, false); break; case 0xB7: - codeListener.convert(WasmNumType.INT32, WasmNumType.FLOAT64, false, false); + codeListener.convert(WasmNumType.INT32, WasmNumType.FLOAT64, false, false, false); break; case 0xB8: - codeListener.convert(WasmNumType.INT32, WasmNumType.FLOAT64, true, false); + codeListener.convert(WasmNumType.INT32, WasmNumType.FLOAT64, true, false, false); break; case 0xB9: - codeListener.convert(WasmNumType.INT64, WasmNumType.FLOAT64, false, false); + codeListener.convert(WasmNumType.INT64, WasmNumType.FLOAT64, false, false, false); break; case 0xBA: - codeListener.convert(WasmNumType.INT64, WasmNumType.FLOAT64, true, false); + codeListener.convert(WasmNumType.INT64, WasmNumType.FLOAT64, true, false, false); break; case 0xBC: - codeListener.convert(WasmNumType.FLOAT32, WasmNumType.INT32, false, true); + codeListener.convert(WasmNumType.FLOAT32, WasmNumType.INT32, false, true, false); break; case 0xBD: - codeListener.convert(WasmNumType.FLOAT64, WasmNumType.INT64, false, true); + codeListener.convert(WasmNumType.FLOAT64, WasmNumType.INT64, false, true, false); break; case 0xBE: - codeListener.convert(WasmNumType.INT32, WasmNumType.FLOAT32, false, true); + codeListener.convert(WasmNumType.INT32, WasmNumType.FLOAT32, false, true, false); break; case 0xBF: - codeListener.convert(WasmNumType.INT64, WasmNumType.FLOAT64, false, true); + codeListener.convert(WasmNumType.INT64, WasmNumType.FLOAT64, false, true, false); break; case 0xD0: @@ -623,6 +623,31 @@ public class CodeParser extends BaseSectionParser { private boolean parseExtExpr() { switch (readLEB()) { + case 0: + codeListener.convert(WasmNumType.INT32, WasmNumType.FLOAT32, true, false, true); + return true; + case 1: + codeListener.convert(WasmNumType.INT32, WasmNumType.FLOAT32, false, false, true); + return true; + case 2: + codeListener.convert(WasmNumType.INT32, WasmNumType.FLOAT64, true, false, true); + return true; + case 3: + codeListener.convert(WasmNumType.INT32, WasmNumType.FLOAT64, false, false, true); + return true; + case 4: + codeListener.convert(WasmNumType.INT64, WasmNumType.FLOAT32, true, false, true); + return true; + case 5: + codeListener.convert(WasmNumType.INT64, WasmNumType.FLOAT32, false, false, true); + return true; + case 6: + codeListener.convert(WasmNumType.INT64, WasmNumType.FLOAT64, true, false, true); + return true; + case 7: + codeListener.convert(WasmNumType.INT64, WasmNumType.FLOAT64, false, false, true); + return true; + case 10: { if (reader.data[reader.ptr++] != 0 || reader.data[reader.ptr++] != 0) { return false; diff --git a/core/src/main/java/org/teavm/backend/wasm/render/WasmBinaryRenderingVisitor.java b/core/src/main/java/org/teavm/backend/wasm/render/WasmBinaryRenderingVisitor.java index 1823586c6..536fc6ddf 100644 --- a/core/src/main/java/org/teavm/backend/wasm/render/WasmBinaryRenderingVisitor.java +++ b/core/src/main/java/org/teavm/backend/wasm/render/WasmBinaryRenderingVisitor.java @@ -839,12 +839,20 @@ class WasmBinaryRenderingVisitor implements WasmExpressionVisitor { case INT32: if (expression.isReinterpret()) { writer.writeByte(0xBC); + } else if (expression.isNonTrapping()) { + writer.writeByte(0xFC); + writer.writeByte(expression.isSigned() ? 0 : 1); } else { writer.writeByte(expression.isSigned() ? 0xA8 : 0xA9); } break; case INT64: - writer.writeByte(expression.isSigned() ? 0xAE : 0xAF); + if (expression.isNonTrapping()) { + writer.writeByte(0xFC); + writer.writeByte(expression.isSigned() ? 4 : 5); + } else { + writer.writeByte(expression.isSigned() ? 0xAE : 0xAF); + } break; case FLOAT32: break; @@ -856,11 +864,19 @@ class WasmBinaryRenderingVisitor implements WasmExpressionVisitor { case FLOAT64: switch (expression.getTargetType()) { case INT32: - writer.writeByte(expression.isSigned() ? 0xAA : 0xAB); + if (expression.isNonTrapping()) { + writer.writeByte(0xFC); + writer.writeByte(expression.isSigned() ? 2 : 3); + } else { + writer.writeByte(expression.isSigned() ? 0xAA : 0xAB); + } break; case INT64: if (expression.isReinterpret()) { writer.writeByte(0xBD); + } else if (expression.isNonTrapping()) { + writer.writeByte(0xFC); + writer.writeByte(expression.isSigned() ? 6 : 7); } else { writer.writeByte(expression.isSigned() ? 0xB0 : 0xB1); } diff --git a/tests/src/test/java/org/teavm/vm/NumericConversionTest.java b/tests/src/test/java/org/teavm/vm/NumericConversionTest.java new file mode 100644 index 000000000..0655a7399 --- /dev/null +++ b/tests/src/test/java/org/teavm/vm/NumericConversionTest.java @@ -0,0 +1,66 @@ +/* + * Copyright 2024 Alexey Andreev. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.teavm.vm; + +import static org.junit.Assert.assertEquals; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.teavm.junit.SkipPlatform; +import org.teavm.junit.TeaVMTestRunner; +import org.teavm.junit.TestPlatform; + +@RunWith(TeaVMTestRunner.class) +public class NumericConversionTest { + @Test + @SkipPlatform({TestPlatform.JAVASCRIPT, TestPlatform.C}) + public void floatOverflow() { + assertEquals(2147483647, (int) (floatOne() * (1 << 30) * (1 << 3))); + assertEquals(2147483647, (int) (floatOne() * Float.POSITIVE_INFINITY)); + assertEquals(-2147483648, (int) (-floatOne() * (1 << 30) * (1 << 3))); + assertEquals(-2147483648, (int) (-floatOne() * Float.POSITIVE_INFINITY)); + assertEquals(0, (int) (floatOne() * Float.NaN)); + + assertEquals((1L << 63) - 1, (long) (floatOne() * (1L << 60) * (1 << 5))); + assertEquals((1L << 63) - 1, (long) (floatOne() * Float.POSITIVE_INFINITY)); + assertEquals(1L << 63, (long) (-floatOne() * (1L << 60) * (1 << 5))); + assertEquals(1L << 63, (long) (-floatOne() * Float.POSITIVE_INFINITY)); + assertEquals(0, (long) (floatOne() * Float.NaN)); + } + + @Test + @SkipPlatform({TestPlatform.JAVASCRIPT, TestPlatform.C}) + public void doubleOverflow() { + assertEquals(2147483647, (int) (doubleOne() * (1 << 30) * (1 << 3))); + assertEquals(2147483647, (int) (doubleOne() * Float.POSITIVE_INFINITY)); + assertEquals(-2147483648, (int) (-doubleOne() * (1 << 30) * (1 << 3))); + assertEquals(-2147483648, (int) (-doubleOne() * Float.POSITIVE_INFINITY)); + assertEquals(0, (int) (doubleOne() * Double.NaN)); + + assertEquals((1L << 63) - 1, (long) (doubleOne() * (1L << 60) * (1 << 5))); + assertEquals((1L << 63) - 1, (long) (doubleOne() * Double.POSITIVE_INFINITY)); + assertEquals(1L << 63, (long) (-doubleOne() * (1L << 60) * (1 << 5))); + assertEquals(1L << 63, (long) (-doubleOne() * Double.POSITIVE_INFINITY)); + assertEquals(0, (long) (doubleOne() * Double.NaN)); + } + + private float floatOne() { + return 1; + } + + private double doubleOne() { + return 1; + } +}