From e7c28da4e28b052ad1ee694e519c50d6787be2c9 Mon Sep 17 00:00:00 2001 From: Alexey Andreev Date: Tue, 16 Apr 2024 18:40:44 +0200 Subject: [PATCH] classlib: fix ThreadLocal implementation for multi-thread case --- .../org/teavm/classlib/java/lang/TThread.java | 3 + .../classlib/java/lang/TThreadLocal.java | 79 ++++++++++++++-- .../classlib/java/lang/ThreadLocalTest.java | 92 +++++++++++++++++++ 3 files changed, 166 insertions(+), 8 deletions(-) create mode 100644 tests/src/test/java/org/teavm/classlib/java/lang/ThreadLocalTest.java diff --git a/classlib/src/main/java/org/teavm/classlib/java/lang/TThread.java b/classlib/src/main/java/org/teavm/classlib/java/lang/TThread.java index b6bc09e6d..e8e58592a 100644 --- a/classlib/src/main/java/org/teavm/classlib/java/lang/TThread.java +++ b/classlib/src/main/java/org/teavm/classlib/java/lang/TThread.java @@ -38,6 +38,7 @@ public class TThread extends TObject implements TRunnable { private final Object finishedLock = new Object(); private boolean interruptedFlag; public TThreadInterruptHandler interruptHandler; + public Object key; private String name; private boolean alive = true; @@ -77,6 +78,7 @@ public class TThread extends TObject implements TRunnable { try { activeCount++; setCurrentThread(TThread.this); + key = new Object(); TThread.this.run(); } catch (Throwable t) { getUncaughtExceptionHandler().uncaughtException(this, t); @@ -86,6 +88,7 @@ public class TThread extends TObject implements TRunnable { } alive = false; activeCount--; + key = null; setCurrentThread(mainThread); } } diff --git a/classlib/src/main/java/org/teavm/classlib/java/lang/TThreadLocal.java b/classlib/src/main/java/org/teavm/classlib/java/lang/TThreadLocal.java index 061f987a6..b3b02e7da 100644 --- a/classlib/src/main/java/org/teavm/classlib/java/lang/TThreadLocal.java +++ b/classlib/src/main/java/org/teavm/classlib/java/lang/TThreadLocal.java @@ -15,7 +15,12 @@ */ package org.teavm.classlib.java.lang; +import java.util.Map; +import java.util.WeakHashMap; +import java.util.function.Supplier; + public class TThreadLocal extends TObject { + private Map map; private boolean initialized; private T value; @@ -27,21 +32,79 @@ public class TThreadLocal extends TObject { return null; } + @SuppressWarnings("unchecked") public T get() { - if (!initialized) { - value = initialValue(); - initialized = true; + if (isInMainThread()) { + if (!initialized) { + value = initialValue(); + initialized = true; + } + cleanupMap(); + return value; + } else { + var key = TThread.currentThread().key; + initMap(); + var value = map.get(key); + if (value == null) { + value = initialValue(); + map.put(key, value == null ? NULL : value); + } else if (value == NULL) { + value = null; + } + cleanupMap(); + return (T) value; } - return value; } public void set(T value) { - initialized = true; - this.value = value; + if (isInMainThread()) { + initialized = true; + this.value = value; + cleanupMap(); + } else { + initMap(); + map.put(TThread.currentThread().key, value == null ? NULL : value); + cleanupMap(); + } } public void remove() { - initialized = false; - value = null; + if (isInMainThread()) { + initialized = false; + value = null; + cleanupMap(); + } else { + if (map != null) { + map.remove(TThread.currentThread().key); + cleanupMap(); + } + } } + + private void initMap() { + if (map == null) { + map = new WeakHashMap<>(); + } + } + + private void cleanupMap() { + if (map != null && map.isEmpty()) { + map = null; + } + } + + public static TThreadLocal withInitial(Supplier supplier) { + return new TThreadLocal<>() { + @Override + protected S initialValue() { + return supplier.get(); + } + }; + } + + private static boolean isInMainThread() { + return TThread.currentThread() == TThread.getMainThread(); + } + + private static final Object NULL = new Object(); } diff --git a/tests/src/test/java/org/teavm/classlib/java/lang/ThreadLocalTest.java b/tests/src/test/java/org/teavm/classlib/java/lang/ThreadLocalTest.java new file mode 100644 index 000000000..06d1d3561 --- /dev/null +++ b/tests/src/test/java/org/teavm/classlib/java/lang/ThreadLocalTest.java @@ -0,0 +1,92 @@ +/* + * Copyright 2024 Alexey Andreev. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.teavm.classlib.java.lang; + +import static org.junit.Assert.assertEquals; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.teavm.classlib.PlatformDetector; +import org.teavm.interop.Async; +import org.teavm.interop.AsyncCallback; +import org.teavm.jso.browser.Window; +import org.teavm.junit.OnlyPlatform; +import org.teavm.junit.TeaVMTestRunner; +import org.teavm.junit.TestPlatform; + +@RunWith(TeaVMTestRunner.class) +public class ThreadLocalTest { + private volatile int counter; + + @Test + @OnlyPlatform(TestPlatform.JAVASCRIPT) + public void concurrentUpdate() throws InterruptedException { + var local = new ThreadLocal(); + var monitor = new Object(); + var results = new String[5]; + counter = results.length; + + for (var n = 0; n < results.length; ++n) { + var threadIndex = n; + new Thread(() -> { + var prefix = Character.toString((char) ('a' + threadIndex)); + for (var i = 0; i < 10; ++i) { + var old = local.get(); + if (old == null) { + old = ""; + } + local.set(old + prefix + i); + sleep(); + } + synchronized (monitor) { + results[threadIndex] = local.get(); + --counter; + monitor.notifyAll(); + } + }).start(); + } + + do { + synchronized (monitor) { + monitor.wait(); + } + } while (counter > 0); + + assertEquals("a0a1a2a3a4a5a6a7a8a9", results[0]); + assertEquals("b0b1b2b3b4b5b6b7b8b9", results[1]); + assertEquals("c0c1c2c3c4c5c6c7c8c9", results[2]); + assertEquals("d0d1d2d3d4d5d6d7d8d9", results[3]); + assertEquals("e0e1e2e3e4e5e6e7e8e9", results[4]); + } + + private void sleep() { + if (!PlatformDetector.isJavaScript()) { + try { + Thread.sleep(50); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + } else { + sleepInJs(); + } + } + + @Async + private static native void sleepInJs(); + + private static void sleepInJs(AsyncCallback callback) { + Window.setTimeout(() -> callback.complete(null), 0); + } +}