/*
 * Decompiled with CFR 0.152.
 */
package org.apache.myfaces.push.cdi;

import jakarta.annotation.PostConstruct;
import jakarta.enterprise.context.ApplicationScoped;
import jakarta.faces.context.ExternalContext;
import jakarta.websocket.CloseReason;
import jakarta.websocket.Session;
import java.io.IOException;
import java.io.Serializable;
import java.lang.ref.Reference;
import java.lang.ref.SoftReference;
import java.util.Collection;
import java.util.Collections;
import java.util.HashSet;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Queue;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.Future;
import java.util.logging.Level;
import java.util.logging.Logger;
import java.util.stream.Collectors;
import org.apache.myfaces.config.webparameters.MyfacesConfig;
import org.apache.myfaces.push.Json;
import org.apache.myfaces.push.WebsocketSessionClusterSerializedRestore;
import org.apache.myfaces.util.lang.ConcurrentLRUCache;
import org.apache.myfaces.util.lang.Lazy;

/**
 * Application-scoped registry used by the push (websocket) support.
 *
 * <p>It keeps two lookup structures:
 * <ul>
 *   <li>a {@link ConcurrentLRUCache} from channel token to the collection of
 *       (softly referenced) websocket {@link Session}s registered for that token;</li>
 *   <li>a map from {@link UserChannelKey} (user + channel) to the channel tokens
 *       registered for that pair, each with a registration counter.</li>
 * </ul>
 * Both structures are created lazily on first access.
 */
@ApplicationScoped
public class WebsocketSessionManager {
    private static final CloseReason REASON_EXPIRED = new CloseReason(CloseReason.CloseCodes.NORMAL_CLOSURE, "Expired");
    private static final Logger LOG = Logger.getLogger(WebsocketSessionManager.class.getName());
    private static final String WARNING_TOMCAT_WEB_SOCKET_BOMBED = "Tomcat cannot handle concurrent push messages. A push message has been sent only after %s retries. Consider rate limiting sending push messages. For example, once every 500ms.";

    private Lazy<ConcurrentLRUCache<String, Collection<Reference<Session>>>> sessionMap;
    private Lazy<ConcurrentHashMap<UserChannelKey, ConcurrentMap<String, Integer>>> userMap;
    private Queue<String> restoreQueue;

    /**
     * Sets up the lazy holders and the restore queue. Until
     * {@link #initSessionMap(ExternalContext)} is called, the session cache is
     * sized from {@code MyfacesConfig.WEBSOCKET_MAX_CONNECTIONS_DEFAULT}.
     */
    @PostConstruct
    public void init() {
        this.sessionMap = new Lazy<>(() -> newSessionCache(MyfacesConfig.WEBSOCKET_MAX_CONNECTIONS_DEFAULT));
        this.restoreQueue = new ConcurrentLinkedQueue<>();
        this.userMap = new Lazy<>(() -> new ConcurrentHashMap<>());
    }

    /**
     * Creates an empty session cache for the given size, so that every place
     * that builds the cache uses the same sizing formula.
     */
    private static ConcurrentLRUCache<String, Collection<Reference<Session>>> newSessionCache(int size) {
        return new ConcurrentLRUCache<>((size * 4 + 3) / 3, size);
    }

    /**
     * @return the channel-token to sessions cache, creating it on first access
     */
    public ConcurrentLRUCache<String, Collection<Reference<Session>>> getSessionMap() {
        return this.sessionMap.get();
    }

    /**
     * @return the (user, channel) to channel-token-counter map, creating it on first access
     */
    public ConcurrentMap<UserChannelKey, ConcurrentMap<String, Integer>> getUserMap() {
        return this.userMap.get();
    }

    /**
     * Ensures the session cache has an (initially empty) session collection
     * for the given channel token. An existing collection is left untouched.
     *
     * @param channelToken the channel token to register
     */
    public void registerSessionToken(String channelToken) {
        ConcurrentLRUCache<String, Collection<Reference<Session>>> cache = this.getSessionMap();
        if (cache.get(channelToken) == null) {
            cache.put(channelToken, new ConcurrentLinkedQueue<Reference<Session>>());
        }
    }

    /**
     * Registers a channel token for the given user and channel, incrementing
     * the token's registration counter (starting at 1).
     *
     * @param user         the user identifier
     * @param channel      the channel name
     * @param channelToken the channel token to associate with the user/channel pair
     */
    public void registerUser(Serializable user, String channel, String channelToken) {
        UserChannelKey userChannelKey = new UserChannelKey(user, channel);
        // Create-or-update inside a single compute() on the outer map so that a
        // concurrent deregisterUser() cannot drop the inner map between the
        // lookup and the counter increment.
        this.getUserMap().compute(userChannelKey, (key, existing) -> {
            ConcurrentMap<String, Integer> tokens = existing != null ? existing : new ConcurrentHashMap<>(1);
            tokens.merge(channelToken, 1, Integer::sum);
            return tokens;
        });
    }

    /**
     * Reverts one {@link #registerUser(Serializable, String, String)} call:
     * decrements the token's counter, removes the token when the counter
     * reaches zero, and removes the user/channel entry when it has no tokens
     * left. Unknown users, channels or tokens are ignored.
     *
     * @param user         the user identifier
     * @param channel      the channel name
     * @param channelToken the channel token to deregister
     */
    public void deregisterUser(Serializable user, String channel, String channelToken) {
        UserChannelKey userChannelKey = new UserChannelKey(user, channel);
        this.getUserMap().computeIfPresent(userChannelKey, (key, tokens) -> {
            // Returning null from the remapping function removes the mapping.
            tokens.computeIfPresent(channelToken, (token, count) -> count > 1 ? count - 1 : null);
            return tokens.isEmpty() ? null : tokens;
        });
    }

    /**
     * Looks up the channel tokens registered for a user on a channel.
     *
     * @param user    the user identifier
     * @param channel the channel name
     * @return the key set of the registered tokens (a view of the underlying
     *         map), or {@code null} when nothing is registered for the pair
     */
    public Set<String> getChannelTokensForUser(Serializable user, String channel) {
        UserChannelKey userChannelKey = new UserChannelKey(user, channel);
        ConcurrentMap<String, Integer> tokens = this.getUserMap().get(userChannelKey);
        return tokens == null ? null : tokens.keySet();
    }

    /**
     * Replaces the session cache with one sized from the configured maximum
     * number of websocket connections. If the previous cache was already
     * initialized, its most recently accessed entries (up to
     * {@code MyfacesConfig.WEBSOCKET_MAX_CONNECTIONS_DEFAULT}) are carried over,
     * keeping only references to sessions that are still open.
     *
     * @param context the external context used to read the configuration
     */
    public void initSessionMap(ExternalContext context) {
        int size = MyfacesConfig.getCurrentInstance(context).getWebsocketMaxConnections();
        ConcurrentLRUCache<String, Collection<Reference<Session>>> replacement = newSessionCache(size);
        synchronized (this.sessionMap) {
            if (this.sessionMap.isInitialized()) {
                Map<String, Collection<Reference<Session>>> latest =
                        this.sessionMap.get().getLatestAccessedItems(MyfacesConfig.WEBSOCKET_MAX_CONNECTIONS_DEFAULT);
                for (Map.Entry<String, Collection<Reference<Session>>> entry : latest.entrySet()) {
                    Collection<Reference<Session>> references = entry.getValue();
                    if (references == null) {
                        continue;
                    }
                    Collection<Reference<Session>> openReferences = references.stream()
                            .filter(WebsocketSessionManager::refersToOpenSession)
                            .distinct()
                            .collect(Collectors.toCollection(ConcurrentLinkedQueue::new));
                    replacement.put(entry.getKey(), openReferences);
                }
            }
            this.sessionMap.reset(replacement);
        }
    }

    /**
     * Tells whether the reference still points to an open session. The
     * referent is read only once so it cannot be cleared between the null
     * check and the {@code isOpen()} call.
     */
    private static boolean refersToOpenSession(Reference<Session> reference) {
        if (reference == null) {
            return false;
        }
        Session session = reference.get();
        return session != null && session.isOpen();
    }

    /**
     * Clears the session cache (if it was ever created) and the restore queue.
     */
    public void clearSessions() {
        if (this.sessionMap.isInitialized()) {
            this.sessionMap.get().clear();
        }
        this.restoreQueue.clear();
    }

    /**
     * Associates a session with a channel token, registering the token first
     * if necessary. A session that is already registered for the token is not
     * added a second time.
     *
     * @param channelToken the channel token
     * @param session      the websocket session
     * @return {@code true} if the session was already registered, otherwise
     *         the result of adding it to the token's session collection
     */
    public boolean addOrUpdateSession(String channelToken, Session session) {
        if (LOG.isLoggable(Level.FINE)) {
            LOG.log(Level.FINE, "WebsocketSessionManager: addOrUpdateSession for channelToken = {0}, session.id = {1}", new Object[]{channelToken, session.getId()});
        }
        ConcurrentLRUCache<String, Collection<Reference<Session>>> cache = this.getSessionMap();
        Collection<Reference<Session>> sessions = cache.get(channelToken);
        if (sessions == null) {
            this.registerSessionToken(channelToken);
            sessions = cache.get(channelToken);
        }
        boolean alreadyRegistered = sessions.stream().anyMatch(ref -> Objects.equals(ref.get(), session));
        return alreadyRegistered || sessions.add(new SoftReference<>(session));
    }

    /**
     * Removes the first reference to the given session from the channel
     * token's session collection. Does nothing when the token has no session
     * collection or the session is not registered.
     *
     * @param channelToken the channel token
     * @param session      the websocket session to remove
     */
    public void removeSession(String channelToken, Session session) {
        if (LOG.isLoggable(Level.FINE)) {
            LOG.log(Level.FINE, "WebsocketSessionManager: removeSession for channelToken = {0}, session.id = {1}", new Object[]{channelToken, session.getId()});
        }
        Collection<Reference<Session>> sessions = this.getSessionMap().get(channelToken);
        if (sessions == null) {
            // The token may be unknown or no longer present in the cache.
            return;
        }
        sessions.stream()
                .filter(ref -> Objects.equals(ref.get(), session))
                .findFirst()
                .ifPresent(sessions::remove);
    }

    /**
     * Closes every open session registered for the channel token with the
     * "Expired" close reason and removes the token from the session cache.
     * Failures while closing a session are logged and do not stop the cleanup.
     *
     * @param channelToken the channel token to remove
     */
    public void removeChannelToken(String channelToken) {
        Collection<Reference<Session>> sessions = this.getSessionMap().get(channelToken);
        if (sessions != null) {
            for (Reference<Session> sessionReference : sessions) {
                Session session = sessionReference.get();
                if (session == null || !session.isOpen()) {
                    continue;
                }
                try {
                    session.close(REASON_EXPIRED);
                }
                catch (IOException e) {
                    // Best effort: the token is dropped regardless, but leave a trace.
                    LOG.log(Level.FINE, "WebsocketSessionManager: could not close session for channelToken = " + channelToken, e);
                }
            }
        }
        this.getSessionMap().remove(channelToken);
    }

    /**
     * Sends the JSON-encoded message to every open session registered for the
     * channel token. Sessions found closed are removed from the token's
     * collection.
     *
     * @param channelToken the channel token; {@code null} results in no send
     * @param message      the message to encode and send
     * @return the futures of the asynchronous sends, or an empty set when the
     *         token has no registered sessions
     */
    protected Set<Future<Void>> send(String channelToken, Object message) {
        this.synchronizeSessionInstances();
        Set<Future<Void>> results = new HashSet<>(1);
        Collection<Reference<Session>> sessions = channelToken != null ? this.getSessionMap().get(channelToken) : null;
        if (sessions == null || sessions.isEmpty()) {
            return Collections.emptySet();
        }
        String json = Json.encode(message);
        for (Reference<Session> sessionRef : sessions) {
            // Dereference once: a soft reference may be cleared at any time.
            Session session = sessionRef != null ? sessionRef.get() : null;
            if (session == null) {
                continue;
            }
            if (session.isOpen()) {
                this.send(session, json, results, 0);
            } else {
                this.removeSession(channelToken, session);
            }
        }
        return results;
    }

    /**
     * Sends the text asynchronously and records the resulting future. When the
     * send fails with the Tomcat-specific concurrent-write error (see
     * {@link #isTomcatWebSocketBombed(Session, IllegalStateException)}), the
     * send is retried while holding the session's monitor; any other
     * {@link IllegalStateException} is rethrown.
     *
     * @param session the target session
     * @param text    the text to send
     * @param results the set collecting the send futures
     * @param retries the number of retries performed so far
     */
    private void send(Session session, String text, Set<Future<Void>> results, int retries) {
        try {
            results.add(session.getAsyncRemote().sendText(text));
            if (retries > 0) {
                LOG.warning(String.format(WARNING_TOMCAT_WEB_SOCKET_BOMBED, retries));
            }
        }
        catch (IllegalStateException e) {
            if (!this.isTomcatWebSocketBombed(session, e)) {
                throw e;
            }
            // Only rethrow for unrelated failures; a successful retry must not
            // surface the original exception to the caller.
            synchronized (session) {
                this.send(session, text, results, retries + 1);
            }
        }
    }

    /**
     * Detects the Tomcat concurrent-write failure: the session implementation
     * lives in {@code org.apache.tomcat.websocket.} and the exception message
     * contains {@code [TEXT_FULL_WRITING]}.
     */
    private boolean isTomcatWebSocketBombed(Session session, IllegalStateException illegalStateException) {
        String message = illegalStateException.getMessage();
        return message != null
                && session.getClass().getName().startsWith("org.apache.tomcat.websocket.")
                && message.contains("[TEXT_FULL_WRITING]");
    }

    /**
     * When the restore queue is non-empty, uses the sessions of the most
     * recently accessed cache entry to enumerate all open sessions and
     * registers every session whose {@code "oam.websocket.SR"} user property
     * reports it as deserialized under that session's channel token. One queue
     * element is consumed for each live reference processed.
     */
    public void synchronizeSessionInstances() {
        Queue<String> queue = this.getRestoredQueue();
        if (queue.isEmpty()) {
            return;
        }
        Map<String, Collection<Reference<Session>>> latest = this.getSessionMap().getLatestAccessedItems(1);
        if (latest == null || latest.isEmpty()) {
            return;
        }
        Collection<Reference<Session>> references = latest.values().iterator().next();
        for (Reference<Session> ref : references) {
            Session session = ref != null ? ref.get() : null;
            if (session == null) {
                continue;
            }
            for (Session instance : session.getOpenSessions()) {
                WebsocketSessionClusterSerializedRestore restore =
                        (WebsocketSessionClusterSerializedRestore)instance.getUserProperties().get("oam.websocket.SR");
                if (restore == null || !restore.isDeserialized()) {
                    continue;
                }
                // Register the restored instance itself: the channel token was
                // read from this instance's properties, not from the session
                // used to enumerate the open sessions.
                this.addOrUpdateSession(restore.getChannelToken(), instance);
            }
            queue.poll();
        }
    }

    /**
     * @return the queue consulted by {@link #synchronizeSessionInstances()}
     */
    public Queue<String> getRestoredQueue() {
        return this.restoreQueue;
    }

    /**
     * Map key combining a user identifier and a channel name. Equality and
     * hash code are based on both values.
     *
     * NOTE(review): this is a non-static inner class declared Serializable, so
     * instances carry a reference to the enclosing manager; confirm whether
     * keys are ever serialized before relying on that.
     */
    public class UserChannelKey
    implements Serializable {
        private final Serializable user;
        private final String channel;

        public UserChannelKey(Serializable user, String channel) {
            this.user = user;
            this.channel = channel;
        }

        @Override
        public boolean equals(Object o) {
            if (this == o) {
                return true;
            }
            if (o == null || this.getClass() != o.getClass()) {
                return false;
            }
            UserChannelKey that = (UserChannelKey)o;
            return Objects.equals(this.user, that.user) && Objects.equals(this.channel, that.channel);
        }

        @Override
        public int hashCode() {
            return Objects.hash(this.user, this.channel);
        }
    }
}

