classlib: fix ThreadLocal implementation for multi-thread case

This commit is contained in:
Alexey Andreev 2024-04-16 18:40:44 +02:00
parent 931f0f1f4a
commit e7c28da4e2
3 changed files with 166 additions and 8 deletions

View File

@ -38,6 +38,7 @@ public class TThread extends TObject implements TRunnable {
private final Object finishedLock = new Object();
private boolean interruptedFlag;
public TThreadInterruptHandler interruptHandler;
public Object key;
private String name;
private boolean alive = true;
@ -77,6 +78,7 @@ public class TThread extends TObject implements TRunnable {
try {
activeCount++;
setCurrentThread(TThread.this);
key = new Object();
TThread.this.run();
} catch (Throwable t) {
getUncaughtExceptionHandler().uncaughtException(this, t);
@ -86,6 +88,7 @@ public class TThread extends TObject implements TRunnable {
}
alive = false;
activeCount--;
key = null;
setCurrentThread(mainThread);
}
}

View File

@ -15,7 +15,12 @@
*/
package org.teavm.classlib.java.lang;
import java.util.Map;
import java.util.WeakHashMap;
import java.util.function.Supplier;
public class TThreadLocal<T> extends TObject {
private Map<Object, Object> map;
private boolean initialized;
private T value;
@ -27,21 +32,79 @@ public class TThreadLocal<T> extends TObject {
return null;
}
@SuppressWarnings("unchecked")
public T get() {
if (isInMainThread()) {
if (!initialized) {
value = initialValue();
initialized = true;
}
cleanupMap();
return value;
} else {
var key = TThread.currentThread().key;
initMap();
var value = map.get(key);
if (value == null) {
value = initialValue();
map.put(key, value == null ? NULL : value);
} else if (value == NULL) {
value = null;
}
cleanupMap();
return (T) value;
}
}
public void set(T value) {
if (isInMainThread()) {
initialized = true;
this.value = value;
cleanupMap();
} else {
initMap();
map.put(TThread.currentThread().key, value == null ? NULL : value);
cleanupMap();
}
}
public void remove() {
if (isInMainThread()) {
initialized = false;
value = null;
cleanupMap();
} else {
if (map != null) {
map.remove(TThread.currentThread().key);
cleanupMap();
}
}
}
private void initMap() {
if (map == null) {
map = new WeakHashMap<>();
}
}
private void cleanupMap() {
if (map != null && map.isEmpty()) {
map = null;
}
}
public static <S> TThreadLocal<S> withInitial(Supplier<? extends S> supplier) {
return new TThreadLocal<>() {
@Override
protected S initialValue() {
return supplier.get();
}
};
}
private static boolean isInMainThread() {
return TThread.currentThread() == TThread.getMainThread();
}
private static final Object NULL = new Object();
}

View File

@ -0,0 +1,92 @@
/*
* Copyright 2024 Alexey Andreev.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.teavm.classlib.java.lang;
import static org.junit.Assert.assertEquals;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.teavm.classlib.PlatformDetector;
import org.teavm.interop.Async;
import org.teavm.interop.AsyncCallback;
import org.teavm.jso.browser.Window;
import org.teavm.junit.OnlyPlatform;
import org.teavm.junit.TeaVMTestRunner;
import org.teavm.junit.TestPlatform;
@RunWith(TeaVMTestRunner.class)
public class ThreadLocalTest {
private volatile int counter;
@Test
@OnlyPlatform(TestPlatform.JAVASCRIPT)
public void concurrentUpdate() throws InterruptedException {
var local = new ThreadLocal<String>();
var monitor = new Object();
var results = new String[5];
counter = results.length;
for (var n = 0; n < results.length; ++n) {
var threadIndex = n;
new Thread(() -> {
var prefix = Character.toString((char) ('a' + threadIndex));
for (var i = 0; i < 10; ++i) {
var old = local.get();
if (old == null) {
old = "";
}
local.set(old + prefix + i);
sleep();
}
synchronized (monitor) {
results[threadIndex] = local.get();
--counter;
monitor.notifyAll();
}
}).start();
}
do {
synchronized (monitor) {
monitor.wait();
}
} while (counter > 0);
assertEquals("a0a1a2a3a4a5a6a7a8a9", results[0]);
assertEquals("b0b1b2b3b4b5b6b7b8b9", results[1]);
assertEquals("c0c1c2c3c4c5c6c7c8c9", results[2]);
assertEquals("d0d1d2d3d4d5d6d7d8d9", results[3]);
assertEquals("e0e1e2e3e4e5e6e7e8e9", results[4]);
}
private void sleep() {
if (!PlatformDetector.isJavaScript()) {
try {
Thread.sleep(50);
} catch (InterruptedException e) {
throw new RuntimeException(e);
}
} else {
sleepInJs();
}
}
@Async
private static native void sleepInJs();
private static void sleepInJs(AsyncCallback<Void> callback) {
Window.setTimeout(() -> callback.complete(null), 0);
}
}