package org.jetlinks.supports.cluster.redis;

import lombok.Setter;
import lombok.extern.slf4j.Slf4j;
import org.jetlinks.core.VisitCount;
import org.jetlinks.core.cluster.ClusterQueue;
import org.jetlinks.core.utils.Reactors;
import org.reactivestreams.Publisher;
import org.springframework.data.redis.connection.lettuce.LettuceConnectionFactory;
import org.springframework.data.redis.core.ReactiveRedisOperations;
import org.springframework.data.redis.core.ReactiveRedisTemplate;
import org.springframework.data.redis.core.script.RedisScript;
import reactor.core.publisher.Flux;
import reactor.core.publisher.FluxSink;
import reactor.core.publisher.Mono;

import javax.annotation.Nonnull;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.atomic.*;

/**
 * {@link ClusterQueue} backed by a Redis list stored under the key {@code id}.
 *
 * <p>Producers LPUSH onto the list. Consumers pop from the right end (oldest, {@code Mod.FIFO})
 * or the left end (newest, {@code Mod.LIFO}). When this instance has local subscribers, an added
 * item may be handed directly to one of them (round-robin) instead of going through Redis. The
 * chance of this is controlled by {@link #setLocalConsumerPercent(float)}.
 *
 * <p>Batch polling and push-and-publish use Lua scripts. Scripts are skipped when disabled via the
 * {@code jetlinks.cluster.redus.queue.batch.enabled} system property, or when the Lettuce
 * connection factory is cluster-aware.
 */
@SuppressWarnings("all")
@Slf4j
public class RedisClusterQueue<T> extends VisitCount implements ClusterQueue<T> {

    private static final AtomicReferenceFieldUpdater<RedisClusterQueue, Boolean> POLLING =
            AtomicReferenceFieldUpdater.newUpdater(RedisClusterQueue.class, Boolean.class, "polling");

    private static final AtomicIntegerFieldUpdater<RedisClusterQueue> ROUND_ROBIN =
            AtomicIntegerFieldUpdater.newUpdater(RedisClusterQueue.class, "roundRobin");

    /** Redis key of the backing list. */
    private final String id;

    protected final ReactiveRedisOperations<String, T> operations;

    /** Guard flag (updated through POLLING) so that only one batch poll runs at a time. */
    private volatile Boolean polling = false;

    /** Round-robin cursor (updated through ROUND_ROBIN) used to choose a local subscriber. */
    private volatile int roundRobin = 0;

    /** Upper bound on the number of items fetched from Redis by a single poll. */
    private int maxBatchSize = 32;

    /** Probability (0..1) that an added item is handed to a local subscriber instead of Redis. */
    private volatile float localConsumerPercent = 1F;

    /** Last demand signalled by a subscriber; reused by {@link #tryPoll()}. */
    private long lastRequestSize = maxBatchSize;

    /**
     * Set once this instance has produced data. While set, downstream requests no longer trigger
     * a poll directly. Volatile because it is written by producers and read from request callbacks.
     */
    private volatile boolean hasLocalProducer;

    private Mod mod = Mod.FIFO;

    long lastEmptyTime = 0;

    private final List<FluxSink<T>> subscribers = new CopyOnWriteArrayList<>();

    @Override
    public void setLocalConsumerPercent(float localConsumerPercent) {
        this.localConsumerPercent = localConsumerPercent;
    }

    /**
     * LIFO batch pop: returns the {@code KEYS[2]} left-most (newest) elements of list
     * {@code KEYS[1]} and trims them off. LRANGE's stop index is inclusive, hence {@code count-1}.
     * The count is passed as a key because scripts are only used outside Redis Cluster.
     */
    private static final RedisScript<List> lifoPollScript = RedisScript.of(
            String.join("\n"
                    , "local count = tonumber(KEYS[2]);"
                    , "local val = redis.call('lrange',KEYS[1],0,count-1);"
                    , "redis.call('ltrim',KEYS[1],count,-1);"
                    , "return val;")
            , List.class
    );

    /**
     * FIFO batch pop: returns up to {@code KEYS[2]} right-most (oldest) elements of list
     * {@code KEYS[1]} and trims them off. When the request covers the whole list, the list is
     * returned in full and the key is deleted. The result is ordered newest-to-oldest; callers
     * reverse it.
     */
    private static final RedisScript<List> fifoPollScript = RedisScript.of(
            String.join("\n"
                    , "local size = redis.call('llen',KEYS[1]);"
                    , "if size == 0 then"
                    , "return nil"
                    , "end"
                    , "local index = size - tonumber(KEYS[2]);"
                    , "if index <= 0 then"
                    , "local all = redis.call('lrange',KEYS[1],0,-1);"
                    , "redis.call('del',KEYS[1]);"
                    , "return all;"
                    , "end"
                    , "local val = redis.call('lrange',KEYS[1],index,-1);"
                    , "redis.call('ltrim',KEYS[1],0,index-1);"
                    , "return val;")
            , List.class
    );

    /** LPUSHes ARGV[1] onto KEYS[1], then publishes ARGV[2] on channel 'queue:data:produced'. */
    private static final RedisScript<Long> pushAndPublish = RedisScript.of(
            "local val = redis.call('lpush',KEYS[1],ARGV[1]);" +
                    "redis.call('publish','queue:data:produced',ARGV[2]);" +
                    "return val;"
            , Long.class
    );

    @Setter
    private boolean useScript = "true".equals(System.getProperty("jetlinks.cluster.redus.queue.batch.enabled", "true"));

    public RedisClusterQueue(String id, ReactiveRedisTemplate<String, T> operations) {
        this.id = id;
        this.operations = operations;
        if (useScript && operations.getConnectionFactory() instanceof LettuceConnectionFactory) {
            // Redis Cluster: do not use Lua scripts
            LettuceConnectionFactory factory = (LettuceConnectionFactory) operations.getConnectionFactory();
            useScript = !factory.isClusterAware();
        }
    }

    /** Polls again using the last demand seen from a subscriber. */
    protected void tryPoll() {
        doPoll(lastRequestSize);
    }

    /**
     * Delivers the items of {@code batch}, in iteration order, to local subscribers.
     *
     * @return the items that could not be delivered because no local subscriber remained
     *         (empty when everything was delivered)
     */
    private List<T> pushLocally(Collection<T> batch) {
        Iterator<T> iterator = batch.iterator();
        while (iterator.hasNext()) {
            T item = iterator.next();
            if (!push(item)) {
                List<T> undelivered = new ArrayList<>();
                undelivered.add(item);
                iterator.forEachRemaining(undelivered::add);
                return undelivered;
            }
        }
        return Collections.emptyList();
    }

    /**
     * Hands {@code data} to one local subscriber, choosing round-robin when there are several.
     *
     * @return {@code false} if there is no local subscriber
     */
    private boolean push(T data) {
        for (; ; ) {
            int size = subscribers.size();
            if (size == 0) {
                return false;
            }
            int index = 0;
            if (size > 1) {
                index = ROUND_ROBIN.incrementAndGet(this);
                if (index >= size) {
                    ROUND_ROBIN.set(this, index = 0);
                }
            }
            FluxSink<T> sink;
            try {
                sink = subscribers.get(index);
            } catch (IndexOutOfBoundsException ignored) {
                // a subscriber was disposed between size() and get(); retry with the new size
                continue;
            }
            sink.next(data);
            return true;
        }
    }

    /**
     * Hands one polled item to a local subscriber. If none remains, pushes the item back onto the
     * Redis list.
     *
     * @return a Mono emitting the item when it was delivered locally, empty when it was re-queued
     */
    private Mono<T> deliverOrRequeue(T item) {
        if (push(item)) {
            return Mono.just(item);
        }
        return operations
                .opsForList()
                .leftPush(id, item)
                .then(Mono.empty());
    }

    /**
     * Pulls up to {@code size} items from Redis, in batches of at most {@code maxBatchSize}, and
     * delivers them to local subscribers. Only one poll runs at a time. A further batch is
     * requested while items keep arriving and demand remains.
     */
    private void doPoll(long size) {
        if (!hasLocalConsumer()) {
            return;
        }
        visit();
        if (!POLLING.compareAndSet(this, false, true)) {
            // another poll is already in flight
            return;
        }
        AtomicLong total = new AtomicLong(size);
        long pollSize = Math.min(total.get(), maxBatchSize);

        pollBatch((int) pollSize)
                .flatMap(this::deliverOrRequeue)
                .count()
                .subscribe(
                        r -> {
                            // Release the guard here rather than in doFinally: doFinally would run
                            // after this callback and clear the flag while the follow-up poll
                            // started below is still in flight.
                            POLLING.set(this, false);
                            if (r > 0 && total.addAndGet(-r) > 0) { // keep polling
                                log.trace("poll datas[{}] from redis [{}] ", r, id);
                                doPoll(total.get());
                            } else {
                                lastEmptyTime = System.currentTimeMillis();
                            }
                        },
                        err -> {
                            POLLING.set(this, false);
                            log.warn("poll data from redis [{}] error", id, err);
                        });
    }

    /** Hook for subclasses; does nothing here. */
    protected void stopPoll() {

    }

    /**
     * Registers a local subscriber. Items are delivered either directly by local producers or by
     * polling Redis.
     */
    @Nonnull
    @Override
    public Flux<T> subscribe() {
        return Flux
                .<T>create(sink -> {
                    subscribers.add(sink);
                    sink.onDispose(() -> {
                        subscribers.remove(sink);
                    });
                    doPoll(sink.requestedFromDownstream());
                })
                .doOnRequest(i -> {
                    if (!hasLocalProducer) {
                        doPoll(lastRequestSize = i);
                    }
                });
    }

    @Override
    public void stop() {
        stopPoll();
    }

    @Override
    public boolean hasLocalConsumer() {
        return subscribers.size() > 0;
    }

    /** Returns the current length of the backing Redis list. */
    @Override
    public Mono<Integer> size() {
        visit();
        return operations.opsForList()
                         .size(id)
                         .map(Number::intValue);
    }

    @Override
    public void setPollMod(Mod mod) {
        this.mod = mod;
    }

    /**
     * Pops a single item. Producers LPUSH, so LIFO pops from the left (newest) and FIFO pops from
     * the right (oldest).
     */
    @Nonnull
    @Override
    public Mono<T> poll() {
        visit();
        return mod == Mod.LIFO
                ? operations.opsForList().leftPop(id)
                : operations.opsForList().rightPop(id);
    }

    /**
     * Pops up to {@code size} items in consumption order. Sizes of 1 or less, or disabled scripts,
     * fall back to a single {@link #poll()}. Null entries in a script result are dropped.
     */
    private Flux<T> pollBatch(int size) {
        if (size <= 1 || !useScript) {
            return poll().flux();
        }
        List<String> keys = Arrays.asList(id, String.valueOf(size));
        Flux<List> batches = mod == Mod.FIFO
                // the FIFO script returns newest-to-oldest; reverse so the oldest comes first
                ? operations.execute(fifoPollScript, keys).doOnNext(Collections::reverse)
                : operations.execute(lifoPollScript, keys);
        return batches.flatMapIterable(list -> {
            List<T> items = new ArrayList<>(list.size());
            for (Object o : list) {
                if (o != null) {
                    items.add((T) o);
                }
            }
            return items;
        });
    }

    private ReactiveRedisOperations getOperations() {
        return operations;
    }

    /**
     * Decides whether an added item should go to a local subscriber. This requires at least one
     * subscriber and then succeeds with probability {@code localConsumerPercent}.
     */
    private boolean isLocalConsumer() {
        return subscribers.size() > 0 && (localConsumerPercent == 1F || ThreadLocalRandom
                .current()
                .nextFloat() < localConsumerPercent);
    }

    @Override
    public Mono<Boolean> add(T data) {
        visit();
        return doAdd(data);
    }

    /**
     * Delivers {@code data} locally when possible. Otherwise LPUSHes it to Redis and publishes the
     * queue id on channel {@code queue:data:produced}.
     *
     * @return a Mono emitting {@code true} once the item is handed off
     */
    private Mono<Boolean> doAdd(T data) {
        hasLocalProducer = true;
        if (isLocalConsumer() && push(data)) {
            return Reactors.ALWAYS_TRUE;
        }
        if (!useScript) {
            return this
                    .operations
                    .opsForList()
                    .leftPush(id, data)
                    .then(getOperations().convertAndSend("queue:data:produced", id))
                    // convertAndSend emits a Long; return TRUE so callers really receive a Boolean
                    .then(Reactors.ALWAYS_TRUE);
        }
        return getOperations()
                .execute(pushAndPublish, Arrays.asList(id), Arrays.asList(data, id))
                .then(Reactors.ALWAYS_TRUE);
    }

    @Override
    public Mono<Boolean> add(Publisher<T> publisher) {
        hasLocalProducer = true;
        visit();
        return Flux
                .from(publisher)
                .flatMap(this::doAdd)
                .then(Reactors.ALWAYS_TRUE);
    }

    /**
     * Adds each non-empty collection emitted by {@code publisher}. Items are delivered locally
     * when possible. Only items that could not be delivered locally are LPUSHed to Redis, so no
     * item is delivered twice.
     */
    @Override
    public Mono<Boolean> addBatch(Publisher<? extends Collection<T>> publisher) {
        hasLocalProducer = true;
        visit();
        return Flux
                .from(publisher)
                // LPUSH with no values is a Redis error, so skip empty batches
                .filter(batch -> !batch.isEmpty())
                .flatMap(batch -> {
                    Collection<T> pending;
                    if (isLocalConsumer()) {
                        pending = pushLocally(batch);
                    } else {
                        pending = batch;
                    }
                    if (pending.isEmpty()) {
                        return Reactors.ALWAYS_ONE;
                    }
                    return this.operations
                            .opsForList()
                            .leftPushAll(id, pending)
                            .then(getOperations().convertAndSend("queue:data:produced", id));
                })
                .then(Reactors.ALWAYS_TRUE);
    }
}
