/*
 * Decompiled with CFR 0.152.
 */
package org.apache.shenyu.plugin.ai.token.limiter;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.channels.Channels;
import java.nio.channels.WritableByteChannel;
import java.nio.charset.StandardCharsets;
import java.time.Duration;
import java.util.Objects;
import java.util.Optional;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Consumer;
import java.util.zip.GZIPInputStream;
import org.apache.commons.lang3.StringUtils;
import org.apache.shenyu.common.dto.RuleData;
import org.apache.shenyu.common.dto.SelectorData;
import org.apache.shenyu.common.dto.convert.rule.AiTokenLimiterHandle;
import org.apache.shenyu.common.enums.AiTokenLimiterEnum;
import org.apache.shenyu.common.enums.PluginEnum;
import org.apache.shenyu.plugin.ai.common.strategy.AiModel;
import org.apache.shenyu.plugin.ai.token.limiter.handler.AiTokenLimiterPluginHandler;
import org.apache.shenyu.plugin.api.ShenyuPluginChain;
import org.apache.shenyu.plugin.api.result.ShenyuResultEnum;
import org.apache.shenyu.plugin.api.result.ShenyuResultWrap;
import org.apache.shenyu.plugin.api.utils.WebFluxResultUtils;
import org.apache.shenyu.plugin.base.AbstractShenyuPlugin;
import org.apache.shenyu.plugin.base.utils.CacheKeyUtils;
import org.reactivestreams.Publisher;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.core.io.buffer.DataBuffer;
import org.springframework.data.redis.core.ReactiveRedisTemplate;
import org.springframework.http.HttpCookie;
import org.springframework.http.HttpStatus;
import org.springframework.http.HttpStatusCode;
import org.springframework.http.server.reactive.ServerHttpRequest;
import org.springframework.http.server.reactive.ServerHttpResponse;
import org.springframework.http.server.reactive.ServerHttpResponseDecorator;
import org.springframework.util.Assert;
import org.springframework.web.server.ServerWebExchange;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import reactor.util.annotation.NonNull;

/**
 * Shenyu plugin that limits how many AI completion tokens a caller may consume within a time window.
 *
 * <p>Before forwarding a request, the caller's accumulated token count is read from Redis; once it has
 * reached the configured limit the request is rejected with {@code 429 Too Many Requests}. Otherwise the
 * response is wrapped so that, after the body has been written, the completion tokens reported by the
 * exchange's {@link AiModel} are added to the caller's Redis counter.
 *
 * <p>NOTE(review): the limit check and the increment are not atomic and happen at different times
 * (before forwarding vs. after the response body finishes), so concurrent in-flight requests can push
 * usage above the configured limit.
 */
public class AiTokenLimiterPlugin extends AbstractShenyuPlugin {

    private static final Logger LOG = LoggerFactory.getLogger(AiTokenLimiterPlugin.class);

    private static final String REDIS_KEY_PREFIX = "SHENYU:AI:TOKENLIMIT:";

    /** Exchange attribute holding an optional callback that receives the status of a rejected response. */
    private static final String METRICS_RATE_LIMITER_ATTR = "metricsRateLimiter";

    /** Exchange attribute holding the {@link AiModel} used to count completion tokens. */
    private static final String AI_MODEL_ATTR = "ai_model";

    /** Exchange attribute holding the request's context path, used as the fallback limit key. */
    private static final String CONTEXT_PATH_ATTR = "contextPath";

    private static final String CONTENT_ENCODING_HEADER = "Content-Encoding";

    /**
     * Rejects the request when the caller's token budget is exhausted; otherwise forwards it with a
     * response decorator that records the completion tokens once the body has been written.
     *
     * @param exchange the current exchange
     * @param chain    the plugin chain to continue with
     * @param selector the matched selector (unused)
     * @param rule     the matched rule, whose cache key selects the {@link AiTokenLimiterHandle}
     * @return completion signal of the rejection write or of the downstream chain
     */
    protected Mono<Void> doExecute(final ServerWebExchange exchange, final ShenyuPluginChain chain,
                                   final SelectorData selector, final RuleData rule) {
        AiTokenLimiterHandle aiTokenLimiterHandle =
                AiTokenLimiterPluginHandler.CACHED_HANDLE.get().obtainHandle(CacheKeyUtils.INST.getKey(rule));
        if (Objects.isNull(aiTokenLimiterHandle)) {
            // No limiter configured for this rule: pass through untouched.
            return chain.execute(exchange);
        }
        ReactiveRedisTemplate<String, ?> reactiveRedisTemplate = obtainRedisTemplate();
        Assert.notNull(reactiveRedisTemplate, "reactiveRedisTemplate is null");
        Long timeWindowSeconds = aiTokenLimiterHandle.getTimeWindowSeconds();
        String cacheKey = REDIS_KEY_PREFIX + this.getCacheKey(exchange,
                aiTokenLimiterHandle.getAiTokenLimitType(), aiTokenLimiterHandle.getKeyName());
        AiStatisticServerHttpResponse statisticResponse = new AiStatisticServerHttpResponse(exchange,
                exchange.getResponse(),
                tokens -> this.recordTokensUsage(reactiveRedisTemplate, cacheKey, tokens, timeWindowSeconds));
        return this.isAllowed(reactiveRedisTemplate, cacheKey, aiTokenLimiterHandle.getTokenLimit())
                .flatMap(allowed -> {
                    if (!allowed) {
                        return this.rejectRequest(exchange);
                    }
                    ServerWebExchange mutatedExchange = exchange.mutate().response(statisticResponse).build();
                    return chain.execute(mutatedExchange);
                });
    }

    /**
     * Looks up the Redis template registered for this plugin.
     *
     * <p>The handle cache's value type is not visible here, so the lookup result is narrowed through
     * {@code Object}; only {@code String} keys are ever passed to the template and stored values are
     * only read via {@code toString()}, so the unchecked cast is safe under erasure.
     *
     * @return the template, or {@code null} when none is registered
     */
    @SuppressWarnings("unchecked")
    private static ReactiveRedisTemplate<String, ?> obtainRedisTemplate() {
        Object template = AiTokenLimiterPluginHandler.REDIS_CACHED_HANDLE.get()
                .obtainHandle(PluginEnum.AI_TOKEN_LIMITER.getName());
        return (ReactiveRedisTemplate<String, ?>) template;
    }

    /**
     * Writes the "run out of tokens" error with status 429 and notifies the optional metrics callback.
     *
     * @param exchange the current exchange
     * @return completion signal of the error write
     */
    private Mono<Void> rejectRequest(final ServerWebExchange exchange) {
        exchange.getResponse().setStatusCode(HttpStatus.TOO_MANY_REQUESTS);
        Consumer<HttpStatusCode> metricsRecorder = exchange.getAttribute(METRICS_RATE_LIMITER_ATTR);
        Optional.ofNullable(metricsRecorder)
                .ifPresent(recorder -> recorder.accept(exchange.getResponse().getStatusCode()));
        Object error = ShenyuResultWrap.error(exchange, ShenyuResultEnum.RUN_OUT_OF_TOKENS);
        return WebFluxResultUtils.result(exchange, error);
    }

    /**
     * Checks whether the counter stored at {@code cacheKey} is still below {@code tokenLimit}.
     * A missing key counts as zero usage; a {@code null} limit means "unlimited".
     *
     * @param reactiveRedisTemplate template used to read the counter
     * @param cacheKey              Redis key of the caller's counter
     * @param tokenLimit            maximum number of tokens, or {@code null} for no limit
     * @return {@code true} if the request may proceed
     */
    private Mono<Boolean> isAllowed(final ReactiveRedisTemplate<String, ?> reactiveRedisTemplate,
                                    final String cacheKey, final Long tokenLimit) {
        if (Objects.isNull(tokenLimit)) {
            // Avoids an NPE from unboxing a missing limit.
            return Mono.just(true);
        }
        return reactiveRedisTemplate.opsForValue().get(cacheKey)
                .map(Object::toString)
                .defaultIfEmpty("0")
                .map(currentTokens -> Long.parseLong(currentTokens) < tokenLimit);
    }

    /**
     * Builds the per-caller part of the Redis key according to the configured limit type.
     * Blank or missing values collapse to {@code ""}, i.e. those callers share one counter.
     *
     * @param exchange       the current exchange
     * @param tokenLimitType name of an {@link AiTokenLimiterEnum} constant
     * @param keyName        header/parameter/cookie name for the corresponding limit types
     * @return the key suffix, never {@code null}
     */
    private String getCacheKey(final ServerWebExchange exchange, final String tokenLimitType, final String keyName) {
        String key = this.resolveKey(exchange, AiTokenLimiterEnum.getByName(tokenLimitType), keyName);
        return StringUtils.isBlank(key) ? "" : key;
    }

    /** Extracts the raw (possibly null or blank) limit key for the given limit type. */
    private String resolveKey(final ServerWebExchange exchange, final AiTokenLimiterEnum limitType,
                              final String keyName) {
        if (Objects.isNull(limitType)) {
            // Switching on null would throw; fall back to the same key as the default arm.
            return exchange.getAttribute(CONTEXT_PATH_ATTR);
        }
        ServerHttpRequest request = exchange.getRequest();
        return switch (limitType) {
            case IP -> Objects.requireNonNull(request.getRemoteAddress()).getHostString();
            case URI -> request.getURI().getPath();
            case HEADER -> request.getHeaders().getFirst(keyName);
            case PARAMETER -> request.getQueryParams().getFirst(keyName);
            case COOKIE -> {
                HttpCookie cookie = request.getCookies().getFirst(keyName);
                yield Objects.nonNull(cookie) ? cookie.getValue() : "";
            }
            default -> exchange.getAttribute(CONTEXT_PATH_ATTR);
        };
    }

    /**
     * Adds {@code tokens} to the counter at {@code cacheKey}, then sets the key's TTL to
     * {@code windowSeconds}. Runs fire-and-forget; failures are logged rather than propagated.
     *
     * <p>NOTE(review): the TTL is reset on every recording, so the window restarts from the most
     * recent usage rather than from the first one — confirm this is the intended semantics.
     *
     * @param reactiveRedisTemplate template used to update the counter
     * @param cacheKey              Redis key of the caller's counter
     * @param tokens                number of completion tokens to add
     * @param windowSeconds         TTL applied to the counter after the increment
     */
    private void recordTokensUsage(final ReactiveRedisTemplate<String, ?> reactiveRedisTemplate,
                                   final String cacheKey, final Long tokens, final Long windowSeconds) {
        reactiveRedisTemplate.opsForValue().increment(cacheKey, tokens)
                .flatMap(currentValue -> reactiveRedisTemplate.expire(cacheKey, Duration.ofSeconds(windowSeconds)))
                .subscribe(ignored -> { },
                        error -> LOG.error("Failed to record token usage for key {}", cacheKey, error));
    }

    public int getOrder() {
        return PluginEnum.AI_TOKEN_LIMITER.getCode();
    }

    public String named() {
        return PluginEnum.AI_TOKEN_LIMITER.getName();
    }

    /**
     * Response decorator that copies every written body byte and, when the body publisher terminates
     * (complete, error or cancel), hands the completion-token count of the captured body to a recorder.
     */
    static class AiStatisticServerHttpResponse extends ServerHttpResponseDecorator {

        private final ServerWebExchange exchange;

        private final ServerHttpResponse serverHttpResponse;

        private final Consumer<Long> tokensRecorder;

        AiStatisticServerHttpResponse(final ServerWebExchange exchange, final ServerHttpResponse delegate,
                                      final Consumer<Long> tokensRecorder) {
            super(delegate);
            this.exchange = exchange;
            this.serverHttpResponse = delegate;
            this.tokensRecorder = tokensRecorder;
        }

        @NonNull
        public Mono<Void> writeWith(@NonNull final Publisher<? extends DataBuffer> body) {
            return super.writeWith(this.appendResponse(body));
        }

        /** Taps the body stream: copies each buffer's bytes and records tokens when the stream ends. */
        @NonNull
        private Flux<? extends DataBuffer> appendResponse(final Publisher<? extends DataBuffer> body) {
            BodyWriter writer = new BodyWriter();
            return Flux.from(body)
                    .doOnNext(buffer -> copyReadableBytes(buffer, writer))
                    .doFinally(signal -> this.recordTokens(writer));
        }

        /**
         * Copies the readable bytes of {@code buffer} into {@code writer} without consuming them:
         * read-only views are written so the original buffer's positions stay untouched for downstream.
         */
        private static void copyReadableBytes(final DataBuffer buffer, final BodyWriter writer) {
            try (DataBuffer.ByteBufferIterator iterator = buffer.readableByteBuffers()) {
                iterator.forEachRemaining(byteBuffer -> writer.write(byteBuffer.asReadOnlyBuffer()));
            }
        }

        /** Decodes the captured body, counts its completion tokens and passes them to the recorder. */
        private void recordTokens(final BodyWriter writer) {
            byte[] rawBody = writer.toByteArray();
            AiModel aiModel = this.exchange.getAttribute(AI_MODEL_ATTR);
            if (Objects.isNull(aiModel)) {
                // Throwing here would only be dropped by Reactor; make the skip explicit instead.
                LOG.warn("No AI model found on the exchange; skipping token usage recording");
                return;
            }
            long tokens = aiModel.getCompletionTokens(this.decodeBody(rawBody));
            this.tokensRecorder.accept(tokens);
        }

        /**
         * Converts the captured bytes to UTF-8 text, gunzipping the whole body first when the response
         * is gzip-encoded. A gzip stream may span several buffers, so it is only decoded once complete.
         * If decompression fails the raw bytes are decoded as-is.
         */
        private String decodeBody(final byte[] rawBody) {
            if (rawBody.length == 0 || !this.isGzipEncoded()) {
                return new String(rawBody, StandardCharsets.UTF_8);
            }
            try {
                return new String(this.decompressGzip(rawBody), StandardCharsets.UTF_8);
            } catch (IOException e) {
                LOG.error("Failed to decompress gzipped response", e);
                return new String(rawBody, StandardCharsets.UTF_8);
            }
        }

        /** Returns whether the delegate response declares a {@code Content-Encoding} containing "gzip". */
        private boolean isGzipEncoded() {
            String encoding = this.serverHttpResponse.getHeaders().getFirst(CONTENT_ENCODING_HEADER);
            return Objects.nonNull(encoding) && encoding.contains("gzip");
        }

        /** Gunzips a complete gzip byte stream. */
        private byte[] decompressGzip(final byte[] compressed) throws IOException {
            try (GZIPInputStream gzipInputStream = new GZIPInputStream(new ByteArrayInputStream(compressed))) {
                return gzipInputStream.readAllBytes();
            }
        }
    }

    /**
     * In-memory sink for response body bytes. Writes after {@link #toByteArray()}/{@link #output()} or
     * after a write failure are ignored.
     */
    static class BodyWriter {

        private final ByteArrayOutputStream stream = new ByteArrayOutputStream();

        private final WritableByteChannel channel = Channels.newChannel(this.stream);

        private final AtomicBoolean isClosed = new AtomicBoolean(false);

        BodyWriter() {
        }

        /** Appends the remaining bytes of {@code buffer}; on I/O failure the writer stops accepting data. */
        void write(final ByteBuffer buffer) {
            if (this.isClosed.get()) {
                return;
            }
            try {
                this.channel.write(buffer);
            } catch (IOException e) {
                this.isClosed.compareAndSet(false, true);
                LOG.error("Parse Failed.", e);
            }
        }

        boolean isEmpty() {
            return this.stream.size() == 0;
        }

        /** Closes the writer and returns every byte written so far. */
        byte[] toByteArray() {
            this.isClosed.compareAndSet(false, true);
            try {
                return this.stream.toByteArray();
            } finally {
                this.closeResources();
            }
        }

        /** Closes the writer and returns the captured bytes decoded as UTF-8. */
        String output() {
            return new String(this.toByteArray(), StandardCharsets.UTF_8);
        }

        private void closeResources() {
            try {
                this.stream.close();
            } catch (IOException e) {
                LOG.error("Close stream error: ", e);
            }
            try {
                this.channel.close();
            } catch (IOException e) {
                LOG.error("Close channel error: ", e);
            }
        }
    }
}

