wasm: use non-trapping conversion instructions when casting floats and doubles to ints and longs

Fix #976
This commit is contained in:
Alexey Andreev 2024-11-30 18:50:43 +01:00
parent 146083565c
commit f11a5474d0
7 changed files with 168 additions and 44 deletions

View File

@ -641,7 +641,8 @@ public class DisassemblyCodeListener extends BaseDisassemblyListener implements
} }
@Override @Override
public void convert(WasmNumType sourceType, WasmNumType targetType, boolean signed, boolean reinterpret) { public void convert(WasmNumType sourceType, WasmNumType targetType, boolean signed, boolean reinterpret,
boolean nonTrapping) {
switch (targetType) { switch (targetType) {
case INT32: case INT32:
writer.write("i32."); writer.write("i32.");
@ -649,18 +650,20 @@ public class DisassemblyCodeListener extends BaseDisassemblyListener implements
case FLOAT32: case FLOAT32:
if (reinterpret) { if (reinterpret) {
writer.write("reinterpret_f32"); writer.write("reinterpret_f32");
} else if (signed) {
writer.write("trunc_f32_s");
} else { } else {
writer.write("trunc_f32_u"); writer.write("trunc_");
if (nonTrapping) {
writer.write("sat_");
}
writer.write("f32_").write(signed ? "s" : "u");
} }
break; break;
case FLOAT64: case FLOAT64:
if (signed) { writer.write("trunc_");
writer.write("trunc_f64_s"); if (nonTrapping) {
} else { writer.write("sat_");
writer.write("trunc_f64_u");
} }
writer.write("f64_").write(signed ? "s" : "u");
break; break;
case INT64: case INT64:
writer.write("wrap_i64"); writer.write("wrap_i64");
@ -674,19 +677,21 @@ public class DisassemblyCodeListener extends BaseDisassemblyListener implements
writer.write("i64."); writer.write("i64.");
switch (sourceType) { switch (sourceType) {
case FLOAT32: case FLOAT32:
if (signed) { writer.write("trunc_");
writer.write("trunc_f32_s"); if (nonTrapping) {
} else { writer.write("sat_");
writer.write("trunc_f32_u");
} }
writer.write("f32_").write(signed ? "s" : "u");
break; break;
case FLOAT64: case FLOAT64:
if (reinterpret) { if (reinterpret) {
writer.write("reinterpret_f64"); writer.write("reinterpret_f64");
} else if (signed) {
writer.write("trunc_f64_s");
} else { } else {
writer.write("trunc_f64_u"); writer.write("trunc_");
if (nonTrapping) {
writer.write("sat_");
}
writer.write("f64_").write(signed ? "s" : "u");
} }
break; break;
case INT32: case INT32:

View File

@ -1172,9 +1172,11 @@ public abstract class BaseWasmGenerationVisitor implements StatementVisitor, Exp
@Override @Override
public void visit(PrimitiveCastExpr expr) { public void visit(PrimitiveCastExpr expr) {
accept(expr.getValue()); accept(expr.getValue());
result = new WasmConversion(WasmGeneratorUtil.mapType(expr.getSource()), var conversion = new WasmConversion(WasmGeneratorUtil.mapType(expr.getSource()),
WasmGeneratorUtil.mapType(expr.getTarget()), true, result); WasmGeneratorUtil.mapType(expr.getTarget()), true, result);
result.setLocation(expr.getLocation()); conversion.setNonTrapping(true);
conversion.setLocation(expr.getLocation());
result = conversion;
} }
@Override @Override

View File

@ -24,6 +24,7 @@ public class WasmConversion extends WasmExpression {
private boolean signed; private boolean signed;
private WasmExpression operand; private WasmExpression operand;
private boolean reinterpret; private boolean reinterpret;
private boolean nonTrapping;
public WasmConversion(WasmNumType sourceType, WasmNumType targetType, boolean signed, WasmExpression operand) { public WasmConversion(WasmNumType sourceType, WasmNumType targetType, boolean signed, WasmExpression operand) {
Objects.requireNonNull(sourceType); Objects.requireNonNull(sourceType);
@ -78,6 +79,14 @@ public class WasmConversion extends WasmExpression {
this.operand = operand; this.operand = operand;
} }
public boolean isNonTrapping() {
return nonTrapping;
}
public void setNonTrapping(boolean nonTrapping) {
this.nonTrapping = nonTrapping;
}
@Override @Override
public void acceptVisitor(WasmExpressionVisitor visitor) { public void acceptVisitor(WasmExpressionVisitor visitor) {
visitor.visit(this); visitor.visit(this);

View File

@ -112,7 +112,8 @@ public interface CodeListener {
default void storeFloat64(int align, int offset) { default void storeFloat64(int align, int offset) {
} }
default void convert(WasmNumType sourceType, WasmNumType targetType, boolean signed, boolean reinterpret) { default void convert(WasmNumType sourceType, WasmNumType targetType, boolean signed, boolean reinterpret,
boolean nonTrapping) {
} }
default void memoryGrow() { default void memoryGrow() {

View File

@ -516,76 +516,76 @@ public class CodeParser extends BaseSectionParser {
break; break;
case 0xA7: case 0xA7:
codeListener.convert(WasmNumType.INT64, WasmNumType.INT32, false, false); codeListener.convert(WasmNumType.INT64, WasmNumType.INT32, false, false, false);
break; break;
case 0xA8: case 0xA8:
codeListener.convert(WasmNumType.FLOAT32, WasmNumType.INT32, false, false); codeListener.convert(WasmNumType.FLOAT32, WasmNumType.INT32, false, false, false);
break; break;
case 0xA9: case 0xA9:
codeListener.convert(WasmNumType.FLOAT32, WasmNumType.INT32, true, false); codeListener.convert(WasmNumType.FLOAT32, WasmNumType.INT32, true, false, false);
break; break;
case 0xAA: case 0xAA:
codeListener.convert(WasmNumType.FLOAT64, WasmNumType.INT32, false, false); codeListener.convert(WasmNumType.FLOAT64, WasmNumType.INT32, false, false, false);
break; break;
case 0xAB: case 0xAB:
codeListener.convert(WasmNumType.FLOAT64, WasmNumType.INT32, true, false); codeListener.convert(WasmNumType.FLOAT64, WasmNumType.INT32, true, false, false);
break; break;
case 0xAC: case 0xAC:
codeListener.convert(WasmNumType.INT32, WasmNumType.INT64, false, false); codeListener.convert(WasmNumType.INT32, WasmNumType.INT64, false, false, false);
break; break;
case 0xAD: case 0xAD:
codeListener.convert(WasmNumType.INT32, WasmNumType.INT64, true, false); codeListener.convert(WasmNumType.INT32, WasmNumType.INT64, true, false, false);
break; break;
case 0xAE: case 0xAE:
codeListener.convert(WasmNumType.FLOAT32, WasmNumType.INT64, false, false); codeListener.convert(WasmNumType.FLOAT32, WasmNumType.INT64, false, false, false);
break; break;
case 0xAF: case 0xAF:
codeListener.convert(WasmNumType.FLOAT32, WasmNumType.INT64, true, false); codeListener.convert(WasmNumType.FLOAT32, WasmNumType.INT64, true, false, false);
break; break;
case 0xB0: case 0xB0:
codeListener.convert(WasmNumType.FLOAT64, WasmNumType.INT64, false, false); codeListener.convert(WasmNumType.FLOAT64, WasmNumType.INT64, false, false, false);
break; break;
case 0xB1: case 0xB1:
codeListener.convert(WasmNumType.FLOAT64, WasmNumType.INT64, true, false); codeListener.convert(WasmNumType.FLOAT64, WasmNumType.INT64, true, false, false);
break; break;
case 0xB2: case 0xB2:
codeListener.convert(WasmNumType.INT32, WasmNumType.FLOAT32, false, false); codeListener.convert(WasmNumType.INT32, WasmNumType.FLOAT32, false, false, false);
break; break;
case 0xB3: case 0xB3:
codeListener.convert(WasmNumType.INT32, WasmNumType.FLOAT32, true, false); codeListener.convert(WasmNumType.INT32, WasmNumType.FLOAT32, true, false, false);
break; break;
case 0xB4: case 0xB4:
codeListener.convert(WasmNumType.INT64, WasmNumType.FLOAT32, false, false); codeListener.convert(WasmNumType.INT64, WasmNumType.FLOAT32, false, false, false);
break; break;
case 0xB5: case 0xB5:
codeListener.convert(WasmNumType.INT64, WasmNumType.FLOAT32, true, false); codeListener.convert(WasmNumType.INT64, WasmNumType.FLOAT32, true, false, false);
break; break;
case 0xB6: case 0xB6:
codeListener.convert(WasmNumType.FLOAT64, WasmNumType.FLOAT32, true, false); codeListener.convert(WasmNumType.FLOAT64, WasmNumType.FLOAT32, true, false, false);
break; break;
case 0xB7: case 0xB7:
codeListener.convert(WasmNumType.INT32, WasmNumType.FLOAT64, false, false); codeListener.convert(WasmNumType.INT32, WasmNumType.FLOAT64, false, false, false);
break; break;
case 0xB8: case 0xB8:
codeListener.convert(WasmNumType.INT32, WasmNumType.FLOAT64, true, false); codeListener.convert(WasmNumType.INT32, WasmNumType.FLOAT64, true, false, false);
break; break;
case 0xB9: case 0xB9:
codeListener.convert(WasmNumType.INT64, WasmNumType.FLOAT64, false, false); codeListener.convert(WasmNumType.INT64, WasmNumType.FLOAT64, false, false, false);
break; break;
case 0xBA: case 0xBA:
codeListener.convert(WasmNumType.INT64, WasmNumType.FLOAT64, true, false); codeListener.convert(WasmNumType.INT64, WasmNumType.FLOAT64, true, false, false);
break; break;
case 0xBC: case 0xBC:
codeListener.convert(WasmNumType.FLOAT32, WasmNumType.INT32, false, true); codeListener.convert(WasmNumType.FLOAT32, WasmNumType.INT32, false, true, false);
break; break;
case 0xBD: case 0xBD:
codeListener.convert(WasmNumType.FLOAT64, WasmNumType.INT64, false, true); codeListener.convert(WasmNumType.FLOAT64, WasmNumType.INT64, false, true, false);
break; break;
case 0xBE: case 0xBE:
codeListener.convert(WasmNumType.INT32, WasmNumType.FLOAT32, false, true); codeListener.convert(WasmNumType.INT32, WasmNumType.FLOAT32, false, true, false);
break; break;
case 0xBF: case 0xBF:
codeListener.convert(WasmNumType.INT64, WasmNumType.FLOAT64, false, true); codeListener.convert(WasmNumType.INT64, WasmNumType.FLOAT64, false, true, false);
break; break;
case 0xD0: case 0xD0:
@ -623,6 +623,31 @@ public class CodeParser extends BaseSectionParser {
private boolean parseExtExpr() { private boolean parseExtExpr() {
switch (readLEB()) { switch (readLEB()) {
case 0:
codeListener.convert(WasmNumType.INT32, WasmNumType.FLOAT32, true, false, true);
return true;
case 1:
codeListener.convert(WasmNumType.INT32, WasmNumType.FLOAT32, false, false, true);
return true;
case 2:
codeListener.convert(WasmNumType.INT32, WasmNumType.FLOAT64, true, false, true);
return true;
case 3:
codeListener.convert(WasmNumType.INT32, WasmNumType.FLOAT64, false, false, true);
return true;
case 4:
codeListener.convert(WasmNumType.INT64, WasmNumType.FLOAT32, true, false, true);
return true;
case 5:
codeListener.convert(WasmNumType.INT64, WasmNumType.FLOAT32, false, false, true);
return true;
case 6:
codeListener.convert(WasmNumType.INT64, WasmNumType.FLOAT64, true, false, true);
return true;
case 7:
codeListener.convert(WasmNumType.INT64, WasmNumType.FLOAT64, false, false, true);
return true;
case 10: { case 10: {
if (reader.data[reader.ptr++] != 0 || reader.data[reader.ptr++] != 0) { if (reader.data[reader.ptr++] != 0 || reader.data[reader.ptr++] != 0) {
return false; return false;

View File

@ -839,12 +839,20 @@ class WasmBinaryRenderingVisitor implements WasmExpressionVisitor {
case INT32: case INT32:
if (expression.isReinterpret()) { if (expression.isReinterpret()) {
writer.writeByte(0xBC); writer.writeByte(0xBC);
} else if (expression.isNonTrapping()) {
writer.writeByte(0xFC);
writer.writeByte(expression.isSigned() ? 0 : 1);
} else { } else {
writer.writeByte(expression.isSigned() ? 0xA8 : 0xA9); writer.writeByte(expression.isSigned() ? 0xA8 : 0xA9);
} }
break; break;
case INT64: case INT64:
if (expression.isNonTrapping()) {
writer.writeByte(0xFC);
writer.writeByte(expression.isSigned() ? 4 : 5);
} else {
writer.writeByte(expression.isSigned() ? 0xAE : 0xAF); writer.writeByte(expression.isSigned() ? 0xAE : 0xAF);
}
break; break;
case FLOAT32: case FLOAT32:
break; break;
@ -856,11 +864,19 @@ class WasmBinaryRenderingVisitor implements WasmExpressionVisitor {
case FLOAT64: case FLOAT64:
switch (expression.getTargetType()) { switch (expression.getTargetType()) {
case INT32: case INT32:
if (expression.isNonTrapping()) {
writer.writeByte(0xFC);
writer.writeByte(expression.isSigned() ? 2 : 3);
} else {
writer.writeByte(expression.isSigned() ? 0xAA : 0xAB); writer.writeByte(expression.isSigned() ? 0xAA : 0xAB);
}
break; break;
case INT64: case INT64:
if (expression.isReinterpret()) { if (expression.isReinterpret()) {
writer.writeByte(0xBD); writer.writeByte(0xBD);
} else if (expression.isNonTrapping()) {
writer.writeByte(0xFC);
writer.writeByte(expression.isSigned() ? 6 : 7);
} else { } else {
writer.writeByte(expression.isSigned() ? 0xB0 : 0xB1); writer.writeByte(expression.isSigned() ? 0xB0 : 0xB1);
} }

View File

@ -0,0 +1,66 @@
/*
* Copyright 2024 Alexey Andreev.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.teavm.vm;
import static org.junit.Assert.assertEquals;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.teavm.junit.SkipPlatform;
import org.teavm.junit.TeaVMTestRunner;
import org.teavm.junit.TestPlatform;
@RunWith(TeaVMTestRunner.class)
public class NumericConversionTest {
@Test
@SkipPlatform({TestPlatform.JAVASCRIPT, TestPlatform.C})
public void floatOverflow() {
assertEquals(2147483647, (int) (floatOne() * (1 << 30) * (1 << 3)));
assertEquals(2147483647, (int) (floatOne() * Float.POSITIVE_INFINITY));
assertEquals(-2147483648, (int) (-floatOne() * (1 << 30) * (1 << 3)));
assertEquals(-2147483648, (int) (-floatOne() * Float.POSITIVE_INFINITY));
assertEquals(0, (int) (floatOne() * Float.NaN));
assertEquals((1L << 63) - 1, (long) (floatOne() * (1L << 60) * (1 << 5)));
assertEquals((1L << 63) - 1, (long) (floatOne() * Float.POSITIVE_INFINITY));
assertEquals(1L << 63, (long) (-floatOne() * (1L << 60) * (1 << 5)));
assertEquals(1L << 63, (long) (-floatOne() * Float.POSITIVE_INFINITY));
assertEquals(0, (long) (floatOne() * Float.NaN));
}
@Test
@SkipPlatform({TestPlatform.JAVASCRIPT, TestPlatform.C})
public void doubleOverflow() {
assertEquals(2147483647, (int) (doubleOne() * (1 << 30) * (1 << 3)));
assertEquals(2147483647, (int) (doubleOne() * Float.POSITIVE_INFINITY));
assertEquals(-2147483648, (int) (-doubleOne() * (1 << 30) * (1 << 3)));
assertEquals(-2147483648, (int) (-doubleOne() * Float.POSITIVE_INFINITY));
assertEquals(0, (int) (doubleOne() * Double.NaN));
assertEquals((1L << 63) - 1, (long) (doubleOne() * (1L << 60) * (1 << 5)));
assertEquals((1L << 63) - 1, (long) (doubleOne() * Double.POSITIVE_INFINITY));
assertEquals(1L << 63, (long) (-doubleOne() * (1L << 60) * (1 << 5)));
assertEquals(1L << 63, (long) (-doubleOne() * Double.POSITIVE_INFINITY));
assertEquals(0, (long) (doubleOne() * Double.NaN));
}
private float floatOne() {
return 1;
}
private double doubleOne() {
return 1;
}
}