From 697ad7376273339c44c67a1d0130b3b76ac368f5 Mon Sep 17 00:00:00 2001
From: Ivan Hetman <Ihromant@users.noreply.github.com>
Date: Tue, 16 May 2023 10:57:13 +0300
Subject: [PATCH] classlib: add Collectors grouping by

---
 .../teavm/classlib/java/util/THashMap.java    | 18 +++++++
 .../org/teavm/classlib/java/util/TMap.java    |  8 +++
 .../java/util/stream/TCollectors.java         | 49 +++++++++++++++++++
 .../org/teavm/classlib/java/util/MapTest.java | 13 +++++
 .../java/util/stream/CollectorsTest.java      | 12 +++++
 5 files changed, 100 insertions(+)

diff --git a/classlib/src/main/java/org/teavm/classlib/java/util/THashMap.java b/classlib/src/main/java/org/teavm/classlib/java/util/THashMap.java
index 1fcf6e248..3d1794518 100644
--- a/classlib/src/main/java/org/teavm/classlib/java/util/THashMap.java
+++ b/classlib/src/main/java/org/teavm/classlib/java/util/THashMap.java
@@ -34,6 +34,7 @@ package org.teavm.classlib.java.util;
 
 import java.util.Arrays;
 import java.util.function.BiConsumer;
+import java.util.function.BiFunction;
 import java.util.function.Consumer;
 import org.teavm.classlib.java.io.TSerializable;
 import org.teavm.classlib.java.lang.*;
@@ -668,6 +669,23 @@ public class THashMap<K, V> extends TAbstractMap<K, V> implements TCloneable, TS
         }
     }
 
+    @Override
+    public void replaceAll(BiFunction<? super K, ? super V, ? extends V> function) {
+        if (elementCount > 0) {
+            int prevModCount = modCount;
+            for (int i = 0; i < elementData.length; i++) {
+                HashEntry<K, V> entry = elementData[i];
+                while (entry != null) {
+                    entry.value = function.apply(entry.key, entry.value);
+                    entry = entry.next;
+                    if (prevModCount != modCount) {
+                        throw new TConcurrentModificationException();
+                    }
+                }
+            }
+        }
+    }
+
     static int computeHashCode(Object key) {
         return key.hashCode();
     }
diff --git a/classlib/src/main/java/org/teavm/classlib/java/util/TMap.java b/classlib/src/main/java/org/teavm/classlib/java/util/TMap.java
index 4d58b55cc..f5a5fd7d8 100644
--- a/classlib/src/main/java/org/teavm/classlib/java/util/TMap.java
+++ b/classlib/src/main/java/org/teavm/classlib/java/util/TMap.java
@@ -170,6 +170,14 @@ public interface TMap<K, V> {
         }
     }
 
+    default void replaceAll(BiFunction<? super K, ? super V, ? extends V> function) {
+        TIterator<Entry<K, V>> iterator = entrySet().iterator();
+        while (iterator.hasNext()) {
+            TMap.Entry<K, V> next = iterator.next();
+            next.setValue(function.apply(next.getKey(), next.getValue()));
+        }
+    }
+
     static <K, V> TMap<K, V> of() {
         return TCollections.emptyMap();
     }
diff --git a/classlib/src/main/java/org/teavm/classlib/java/util/stream/TCollectors.java b/classlib/src/main/java/org/teavm/classlib/java/util/stream/TCollectors.java
index 3056dd5e9..0bda46df3 100644
--- a/classlib/src/main/java/org/teavm/classlib/java/util/stream/TCollectors.java
+++ b/classlib/src/main/java/org/teavm/classlib/java/util/stream/TCollectors.java
@@ -124,6 +124,55 @@ public final class TCollectors {
                 TCollector.Characteristics.IDENTITY_FINISH);
     }
 
+    public static <E, K> TCollector<E, ?, Map<K, List<E>>> groupingBy(Function<? super E, ? extends K> keyExtractor) {
+        return groupingBy(keyExtractor, toList());
+    }
+
+    public static <E, K, V, I> TCollector<E, ?, Map<K, V>> groupingBy(
+            Function<? super E, ? extends K> keyExtractor,
+            TCollector<? super E, I, V> downstream) {
+        return groupingBy(keyExtractor, HashMap::new, downstream);
+    }
+
+    public static <E, K, V, I, M extends Map<K, V>> TCollector<E, ?, M> groupingBy(
+            Function<? super E, ? extends K> keyExtractor,
+            Supplier<M> mapFactory,
+            TCollector<? super E, I, V> downstream) {
+        BiConsumer<Map<K, I>, E> mapAppender = (m, t) -> {
+            K key = keyExtractor.apply(t);
+            I container = m.computeIfAbsent(key, k -> downstream.supplier().get());
+            downstream.accumulator().accept(container, t);
+        };
+        BinaryOperator<Map<K, I>> mapMerger = (m1, m2) -> {
+            for (Map.Entry<K, I> e : m2.entrySet()) {
+                m1.merge(e.getKey(), e.getValue(), downstream.combiner());
+            }
+            return m1;
+        };
+
+        if (downstream.characteristics().contains(TCollector.Characteristics.IDENTITY_FINISH)) {
+            return TCollector.of(castFactory(mapFactory), mapAppender, mapMerger,
+                    castFunction(Function.identity()), TCollector.Characteristics.IDENTITY_FINISH);
+        } else {
+            Function<I, I> replacer = castFunction(downstream.finisher());
+            Function<Map<K, I>, M> finisher = toReplace -> {
+                toReplace.replaceAll((k, v) -> replacer.apply(v));
+                return (M) toReplace;
+            };
+            return TCollector.of(castFactory(mapFactory), mapAppender, mapMerger, finisher);
+        }
+    }
+
+    @SuppressWarnings("unchecked")
+    private static <A, C> Supplier<A> castFactory(Supplier<C> supp) {
+        return (Supplier<A>) supp;
+    }
+
+    @SuppressWarnings("unchecked")
+    private static <A, B, C, D> Function<A, B> castFunction(Function<C, D> func) {
+        return (Function<A, B>) func;
+    }
+
     public static <T, A, R, K> TCollector<T, A, K> collectingAndThen(
             TCollector<T, A, R> downstream,
             Function<R, K> finisher) {
diff --git a/tests/src/test/java/org/teavm/classlib/java/util/MapTest.java b/tests/src/test/java/org/teavm/classlib/java/util/MapTest.java
index 168017f2f..8eea9b1b0 100644
--- a/tests/src/test/java/org/teavm/classlib/java/util/MapTest.java
+++ b/tests/src/test/java/org/teavm/classlib/java/util/MapTest.java
@@ -20,7 +20,9 @@ import static org.junit.Assert.assertFalse;
 import static org.junit.Assert.assertNull;
 import static org.junit.Assert.assertTrue;
 import static org.junit.Assert.fail;
+import java.util.HashMap;
 import java.util.Map;
+import java.util.TreeMap;
 import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.teavm.junit.TeaVMTestRunner;
@@ -122,4 +124,15 @@ public class MapTest {
             assertNull("Iterator did not return all of expected elements", e);
         }
     }
+
+    @Test
+    public void testReplaceAll() {
+        Map<String, Integer> base = Map.of("a", 1, "b", 2);
+        Map<String, Integer> hashMap = new HashMap<>(base);
+        Map<String, Integer> treeMap = new TreeMap<>(base);
+        hashMap.replaceAll((k, v) -> v * 10);
+        treeMap.replaceAll((k, v) -> v * 10);
+        assertEquals(Map.of("a", 10, "b", 20), hashMap);
+        assertEquals(Map.of("a", 10, "b", 20), treeMap);
+    }
 }
diff --git a/tests/src/test/java/org/teavm/classlib/java/util/stream/CollectorsTest.java b/tests/src/test/java/org/teavm/classlib/java/util/stream/CollectorsTest.java
index 64c36d9cf..c249ce4d0 100644
--- a/tests/src/test/java/org/teavm/classlib/java/util/stream/CollectorsTest.java
+++ b/tests/src/test/java/org/teavm/classlib/java/util/stream/CollectorsTest.java
@@ -78,4 +78,16 @@ public class CollectorsTest {
         assertEquals(expected,
                 IntStream.range(1, 4).boxed().collect(Collectors.toMap(Function.identity(), Function.identity())));
     }
+
+    @Test
+    public void groupingBy() {
+        List<Integer> numbers = List.of(1, 2, 2, 3, 3, 3, 4, 4, 4, 4);
+        assertEquals(Map.of(1, List.of(1), 2, List.of(2, 2),
+                        3, List.of(3, 3, 3), 4, List.of(4, 4, 4, 4)),
+                numbers.stream().collect(Collectors.groupingBy(Function.identity())));
+        assertEquals(Map.of(1, 1, 2, 4, 3, 9, 4, 16),
+                numbers.stream().collect(Collectors.groupingBy(Function.identity(),
+                        Collectors.collectingAndThen(Collectors.toList(),
+                                l -> l.stream().mapToInt(i -> i).sum()))));
+    }
 }