package com.github.aidensuen.mongo.spring;

import com.github.aidensuen.mongo.session.Configuration;
import com.github.aidensuen.mongo.session.ExecutorType;
import com.github.aidensuen.mongo.session.MongoSession;
import com.github.aidensuen.mongo.session.MongoSessionFactory;
import com.mongodb.client.MongoDatabase;
import com.mongodb.client.result.DeleteResult;
import com.mongodb.client.result.UpdateResult;
import org.springframework.core.NestedExceptionUtils;
import org.springframework.dao.support.PersistenceExceptionTranslator;
import org.springframework.data.domain.Pageable;
import org.springframework.data.mongodb.core.MongoExceptionTranslator;
import org.springframework.data.mongodb.core.aggregation.AggregationResults;
import org.springframework.util.Assert;

import java.lang.reflect.InvocationHandler;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.lang.reflect.Proxy;
import java.lang.reflect.UndeclaredThrowableException;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.function.Function;

/**
 * {@link MongoSession} facade that routes every data operation through a JDK dynamic proxy.
 *
 * <p>On each proxied call the proxy obtains a session from {@code MongoSessionUtils.getMongoSession}
 * (using this template's factory, executor type and exception translator) and invokes the requested
 * method on it. When an exception translator is configured, any {@link RuntimeException} thrown by
 * that session is passed through
 * {@link PersistenceExceptionTranslator#translateExceptionIfPossible(RuntimeException)} before being
 * rethrown. Whether the obtained session is bound to an ongoing Spring transaction is decided by
 * {@code MongoSessionUtils} (not visible here).
 *
 * <p>{@link #getConfiguration()} and {@link #getMongoDao(Class)} do not go through the proxy.
 */
public class MongoSessionTemplate implements MongoSession {

    private final MongoSessionFactory mongoSessionFactory;
    private final ExecutorType executorType;
    private final MongoSession mongoSessionProxy;
    /** May be {@code null}, in which case exceptions are rethrown untranslated. */
    private final PersistenceExceptionTranslator exceptionTranslator;

    /**
     * Creates a template using the factory configuration's default executor type and a
     * {@link MongoExceptionTranslator}.
     *
     * @param mongoSessionFactory factory used to obtain sessions; must not be {@code null}
     */
    public MongoSessionTemplate(MongoSessionFactory mongoSessionFactory) {
        this(mongoSessionFactory, mongoSessionFactory.getConfiguration().getDefaultExecutorType());
    }

    /**
     * Creates a template with the given executor type and a {@link MongoExceptionTranslator}.
     *
     * @param mongoSessionFactory factory used to obtain sessions; must not be {@code null}
     * @param executorType        executor type passed when obtaining sessions; must not be {@code null}
     */
    public MongoSessionTemplate(MongoSessionFactory mongoSessionFactory, ExecutorType executorType) {
        this(mongoSessionFactory, executorType, new MongoExceptionTranslator());
    }

    /**
     * Creates a fully configured template.
     *
     * @param mongoSessionFactory factory used to obtain sessions; must not be {@code null}
     * @param executorType        executor type passed when obtaining sessions; must not be {@code null}
     * @param exceptionTranslator translator applied to runtime exceptions thrown by the underlying
     *                            session; may be {@code null} to disable translation
     * @throws IllegalArgumentException if {@code mongoSessionFactory} or {@code executorType} is null
     */
    public MongoSessionTemplate(MongoSessionFactory mongoSessionFactory, ExecutorType executorType, PersistenceExceptionTranslator exceptionTranslator) {
        Assert.notNull(mongoSessionFactory, "Property 'mongoSessionFactory' is required");
        Assert.notNull(executorType, "Property 'executorType' is required");
        this.mongoSessionFactory = mongoSessionFactory;
        this.executorType = executorType;
        this.exceptionTranslator = exceptionTranslator;
        this.mongoSessionProxy = (MongoSession) Proxy.newProxyInstance(
                MongoSessionFactory.class.getClassLoader(),
                new Class[]{MongoSession.class},
                new MongoSessionInterceptor());
    }

    /** @return the factory this template obtains sessions from */
    public MongoSessionFactory getMongoSessionFactory() {
        return mongoSessionFactory;
    }

    /** @return the executor type passed when obtaining sessions */
    public ExecutorType getExecutorType() {
        return executorType;
    }

    /** @return the exception translator, or {@code null} if translation is disabled */
    public PersistenceExceptionTranslator getExceptionTranslator() {
        return exceptionTranslator;
    }

    // ---------------------------------------------------------------------------------------------
    // MongoSession operations: each delegates to the session proxy (see MongoSessionInterceptor).
    // ---------------------------------------------------------------------------------------------

    @Override
    public MongoDatabase getDb() {
        return this.mongoSessionProxy.getDb();
    }

    /** Reads the configuration directly from the factory, without going through the proxy. */
    @Override
    public Configuration getConfiguration() {
        return this.mongoSessionFactory.getConfiguration();
    }

    @Override
    public <T> T save(String statement, T objectToSave) {
        return this.mongoSessionProxy.save(statement, objectToSave);
    }

    @Override
    public <T> T insert(String statement, T objectToSave) {
        return this.mongoSessionProxy.insert(statement, objectToSave);
    }

    @Override
    public <T> Collection<T> insert(String statement, Collection<? extends T> batchToSave) {
        return this.mongoSessionProxy.insert(statement, batchToSave);
    }

    @Override
    public DeleteResult remove(String statement, Object parameter) {
        return this.mongoSessionProxy.remove(statement, parameter);
    }

    @Override
    public <T> T findOne(String statement, Object parameter) {
        return this.mongoSessionProxy.findOne(statement, parameter);
    }

    @Override
    public <T> List<T> find(String statement, Object parameter) {
        return this.mongoSessionProxy.find(statement, parameter);
    }

    @Override
    public <T> List<T> find(String statement, Object parameter, Pageable pageable) {
        return this.mongoSessionProxy.find(statement, parameter, pageable);
    }

    @Override
    public <T, R> List<R> find(String statement, Object parameter, Function<T, R> converter) {
        return this.mongoSessionProxy.find(statement, parameter, converter);
    }

    @Override
    public <T, R> List<R> find(String statement, Object parameter, Pageable pageable, Function<T, R> converter) {
        return this.mongoSessionProxy.find(statement, parameter, pageable, converter);
    }

    @Override
    public <K, V> Map<K, V> findMap(String statement, Object parameter, String mapKey) {
        return this.mongoSessionProxy.findMap(statement, parameter, mapKey);
    }

    @Override
    public <K, V> Map<K, V> findMap(String statement, Object parameter, String mapKey, Pageable pageable) {
        return this.mongoSessionProxy.findMap(statement, parameter, mapKey, pageable);
    }

    @Override
    public long count(String statement, Object parameter) {
        return this.mongoSessionProxy.count(statement, parameter);
    }

    @Override
    public boolean exists(String statement, Object parameter) {
        return this.mongoSessionProxy.exists(statement, parameter);
    }

    @Override
    public UpdateResult updateFirst(String statement, Object parameter) {
        return this.mongoSessionProxy.updateFirst(statement, parameter);
    }

    @Override
    public UpdateResult updateMulti(String statement, Object parameter) {
        return this.mongoSessionProxy.updateMulti(statement, parameter);
    }

    @Override
    public UpdateResult upsert(String statement, Object parameter) {
        return this.mongoSessionProxy.upsert(statement, parameter);
    }

    @Override
    public <O> AggregationResults<O> aggregate(String statement, Object parameter) {
        return this.mongoSessionProxy.aggregate(statement, parameter);
    }

    /** Looks up a DAO from the configuration, bound to this template rather than to the proxy. */
    @Override
    public <T> T getMongoDao(Class<T> type) {
        return this.getConfiguration().getMongoDao(type, this);
    }

    /**
     * Strips the reflective wrappers added by {@link Method#invoke} ({@link InvocationTargetException})
     * and by dynamic proxies ({@link UndeclaredThrowableException}), returning the exception actually
     * thrown by the target. Unlike a root-cause search, this keeps the thrown exception's own type,
     * so e.g. a driver exception that carries a cause is still offered to the translator.
     *
     * @param throwable the exception caught around a reflective call
     * @return the innermost non-wrapper exception
     */
    private static Throwable unwrapThrowable(Throwable throwable) {
        Throwable current = throwable;
        while (true) {
            Throwable next = null;
            if (current instanceof InvocationTargetException) {
                next = ((InvocationTargetException) current).getTargetException();
            } else if (current instanceof UndeclaredThrowableException) {
                next = ((UndeclaredThrowableException) current).getUndeclaredThrowable();
            }
            if (next == null) {
                return current;
            }
            current = next;
        }
    }

    /**
     * Invocation handler behind {@link #mongoSessionProxy}: obtains a session per call, invokes the
     * method on it, and translates/rethrows any failure.
     */
    private class MongoSessionInterceptor implements InvocationHandler {

        @Override
        public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
            MongoSession mongoSession = MongoSessionUtils.getMongoSession(
                    MongoSessionTemplate.this.mongoSessionFactory,
                    MongoSessionTemplate.this.executorType,
                    MongoSessionTemplate.this.exceptionTranslator);
            try {
                return method.invoke(mongoSession, args);
            } catch (Throwable throwable) {
                Throwable unwrapped = unwrapThrowable(throwable);
                PersistenceExceptionTranslator translator = MongoSessionTemplate.this.exceptionTranslator;
                if (translator != null && unwrapped instanceof RuntimeException) {
                    Throwable translated = translator.translateExceptionIfPossible((RuntimeException) unwrapped);
                    if (translated != null) {
                        unwrapped = translated;
                    }
                }
                // Always propagate: a failure must never be handed back to the caller as a return value.
                throw unwrapped;
            }
        }
    }
}
