package com.ejianc.foundation.ai.config;

import com.alibaba.fastjson.JSONObject;
import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper;
import com.ejianc.foundation.ai.bean.KnowledgeEmbeddingPointsEntity;
import com.ejianc.foundation.ai.bean.ModelEntity;
import com.ejianc.foundation.ai.service.IExtDictService;
import com.ejianc.foundation.ai.service.IKnowledgeEmbeddingPointsService;
import com.ejianc.foundation.ai.service.IModelService;
import com.ejianc.framework.core.context.InvocationInfoProxy;
import com.ejianc.framework.skeleton.refer.util.ContextUtil;
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.model.embedding.onnx.allminilml6v2.AllMiniLmL6V2EmbeddingModel;
import dev.langchain4j.model.ollama.OllamaEmbeddingModel;
import dev.langchain4j.model.openai.OpenAiEmbeddingModel;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.model.qianfan.QianfanEmbeddingModel;
import dev.langchain4j.model.qianfan.QianfanEmbeddingModelNameEnum;
import dev.langchain4j.store.embedding.EmbeddingMatch;
import dev.langchain4j.store.embedding.EmbeddingSearchRequest;
import dev.langchain4j.store.embedding.EmbeddingSearchResult;
import dev.langchain4j.store.embedding.EmbeddingStore;
import dev.langchain4j.store.embedding.redis.RedisEmbeddingStore;
import org.apache.commons.lang.StringUtils;
import org.apache.lucene.analysis.Analyzer;
import org.apache.lucene.analysis.TokenStream;
import org.apache.lucene.analysis.tokenattributes.CharTermAttribute;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.wltea.analyzer.lucene.IKAnalyzer;

import java.io.IOException;
import java.io.StringReader;
import java.time.Duration;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;

/**
 * Embedding helper that lazily builds and caches, per Redis index name, the embedding model,
 * that model's vector dimension and the Redis-backed embedding store, and runs similarity
 * searches against them.
 *
 * <p>Thread-safety: the caches are concurrent maps and initialisation for a given index name is
 * serialised on a per-index lock object, so concurrent callers for the same index build the
 * model/store at most once.
 */
public class EjcAiEmbeding {
    private static final Logger logger = LoggerFactory.getLogger(EjcAiEmbeding.class);

    /** {@code ModelEntity.platform} value for Ollama. */
    private static final int PLATFORM_OLLAMA = 1;
    /** {@code ModelEntity.platform} value for Qianfan. */
    private static final int PLATFORM_QIANFAN = 2;
    /** {@code ModelEntity.platform} value for Alibaba Cloud Bailian (DashScope, OpenAI-compatible API). */
    private static final int PLATFORM_BAILIAN = 3;

    private static final String BAILIAN_COMPATIBLE_BASE_URL = "https://dashscope.aliyuncs.com/compatible-mode/v1";
    private static final String BAILIAN_EMBEDDING_MODEL_NAME = "text-embedding-v3";
    /** Vector dimension used for the local AllMiniLmL6V2 fallback model. */
    private static final int LOCAL_FALLBACK_DIMENSION = 384;
    /** Text embedded once to discover a remote model's vector dimension. */
    private static final String DIMENSION_PROBE_TEXT = "test";

    private String redisHost;
    private Integer redisPort;
    private String redisPassword;
    private IModelService modelService;

    private final ConcurrentMap<String, EmbeddingStore<TextSegment>> embeddingStoreMap = new ConcurrentHashMap<>();
    private final ConcurrentMap<String, EmbeddingModel> embeddingModelMap = new ConcurrentHashMap<>();
    private final ConcurrentMap<String, Integer> dimensionMap = new ConcurrentHashMap<>();
    /** One monitor per index name, so equal names always synchronise on the same object. */
    private final ConcurrentMap<String, Object> indexLocks = new ConcurrentHashMap<>();

    public EjcAiEmbeding() {}

    public EjcAiEmbeding(String redisHost, Integer redisPort, String redisPassword, IModelService modelService) {
        this.redisHost = redisHost;
        this.redisPort = redisPort;
        this.redisPassword = redisPassword;
        this.modelService = modelService;
    }

    /**
     * Returns the embedding model for {@code indexName}, building and caching it on first use.
     *
     * <p>The first row returned by {@code modelService.list} for {@code model_state = 1} and
     * {@code embedding_model_flag = 1} is used. NOTE(review): the query has no ORDER BY, so with
     * several active rows the choice may be arbitrary. If no such row exists, the local
     * {@link AllMiniLmL6V2EmbeddingModel} is used.
     *
     * <p>Failures while building or probing an Ollama model propagate to the caller. Failures for
     * Qianfan/Bailian are logged and yield {@code null}. Nothing is cached on failure, so the next
     * call retries.
     *
     * @param indexName Redis index name used as the cache key
     * @return the model, or {@code null} when a configured model cannot be built (unknown or
     *         missing platform, blank app config, or Qianfan/Bailian initialisation failure)
     */
    public EmbeddingModel getEmbeddingModel(String indexName) {
        EmbeddingModel cached = embeddingModelMap.get(indexName);
        if (cached != null) {
            return cached;
        }
        synchronized (lockFor(indexName)) {
            cached = embeddingModelMap.get(indexName);
            if (cached != null) {
                return cached;
            }
            // Look up the active embedding model configuration.
            QueryWrapper<ModelEntity> queryWrapper = new QueryWrapper<>();
            queryWrapper.eq("model_state", 1);
            queryWrapper.eq("embedding_model_flag", 1);
            List<ModelEntity> embeddingModelList = modelService.list(queryWrapper);
            if (embeddingModelList == null || embeddingModelList.isEmpty()) {
                EmbeddingModel localModel = new AllMiniLmL6V2EmbeddingModel();
                cacheModel(indexName, localModel, LOCAL_FALLBACK_DIMENSION);
                return localModel;
            }
            return initConfiguredModel(indexName, embeddingModelList.get(0));
        }
    }

    /**
     * Returns the Redis embedding store for {@code indexName}, creating and caching it on first use.
     * If the model dimension is not known yet, the embedding model is initialised first.
     *
     * @param indexName Redis index name
     * @return the cached or newly built store
     * @throws IllegalStateException if no embedding model (and hence no dimension) is available
     */
    public EmbeddingStore<TextSegment> getEmbeddingStore(String indexName) {
        EmbeddingStore<TextSegment> cached = embeddingStoreMap.get(indexName);
        if (cached != null) {
            return cached;
        }
        synchronized (lockFor(indexName)) {
            cached = embeddingStoreMap.get(indexName);
            if (cached != null) {
                return cached;
            }
            Integer dimension = dimensionMap.get(indexName);
            if (dimension == null) {
                // Re-entrant: getEmbeddingModel takes the same per-index monitor.
                getEmbeddingModel(indexName);
                dimension = dimensionMap.get(indexName);
            }
            if (dimension == null) {
                throw new IllegalStateException("Embedding dimension unknown for index '" + indexName
                        + "': no usable embedding model could be initialised");
            }
            EmbeddingStore<TextSegment> embeddingStore = RedisEmbeddingStore.builder()
                    .host(redisHost)
                    .port(redisPort)
                    .user("default")
                    .password(redisPassword)
                    .dimension(dimension)
                    .indexName(indexName)
                    .build();
            embeddingStoreMap.put(indexName, embeddingStore);
            return embeddingStore;
        }
    }

    /**
     * Runs a vector similarity search and keeps only hits that share a word with the query.
     *
     * <p>The flow is:
     * <ol>
     *   <li>Search {@code indexName} for segments similar to {@code searchText}.</li>
     *   <li>Tokenise {@code searchText} with IK.</li>
     *   <li>For each hit whose text contains at least one token, load its knowledge point.</li>
     *   <li>Keep the first point found, deduplicated by point id.</li>
     * </ol>
     *
     * @param indexName  Redis index name
     * @param searchText query text
     * @param matchScore minimum similarity score passed to the search request
     * @param maxResults maximum number of vector hits passed to the search request
     * @return map from {@code indexName + ":" + embeddingId} to the point content; empty if nothing matched
     * @throws IllegalStateException if no embedding model can be built for {@code indexName}
     */
    public Map<String, String> getMatchList(String indexName, String searchText, Double matchScore, Integer maxResults) {
        Map<String, String> matchMap = new HashMap<>();
        EmbeddingModel model = getEmbeddingModel(indexName);
        if (model == null) {
            throw new IllegalStateException("No usable embedding model configured for index '" + indexName + "'");
        }
        Embedding queryEmbedding = model.embed(searchText).content();
        EmbeddingSearchRequest embeddingSearchRequest = EmbeddingSearchRequest.builder()
                .queryEmbedding(queryEmbedding)
                .minScore(matchScore)
                .maxResults(maxResults)
                .build();
        EmbeddingSearchResult<TextSegment> searchResult = getEmbeddingStore(indexName).search(embeddingSearchRequest);
        List<EmbeddingMatch<TextSegment>> embeddingMatchList = searchResult.matches();
        if (embeddingMatchList.isEmpty()) {
            return matchMap;
        }

        List<String> words = tokenize(indexName, searchText);
        if (words.isEmpty()) {
            return matchMap;
        }

        IKnowledgeEmbeddingPointsService pointsService = null; // fetched lazily, only when a query is needed
        Set<Long> seenPointIds = new HashSet<>();
        for (EmbeddingMatch<TextSegment> embeddingMatch : embeddingMatchList) {
            TextSegment textSegment = embeddingMatch.embedded();
            if (textSegment == null || !containsAnyWord(textSegment.text(), words)) {
                continue;
            }
            QueryWrapper<KnowledgeEmbeddingPointsEntity> pointsWrapper = buildPointsQuery(embeddingMatch.embeddingId());
            if (pointsWrapper == null) {
                continue;
            }
            if (pointsService == null) {
                pointsService = ContextUtil.getBean("knowledgeEmbeddingPointsService", IKnowledgeEmbeddingPointsService.class);
            }
            List<KnowledgeEmbeddingPointsEntity> pointsList = pointsService.list(pointsWrapper);
            if (pointsList == null || pointsList.isEmpty()) {
                continue;
            }
            KnowledgeEmbeddingPointsEntity pointEntity = pointsList.get(0);
            // Several embeddings may resolve to the same point; keep only the first.
            if (seenPointIds.add(pointEntity.getId())) {
                matchMap.put(indexName + ":" + embeddingMatch.embeddingId(), pointEntity.getContent());
            }
        }
        return matchMap;
    }

    /** Returns the monitor used to serialise initialisation for {@code indexName}. */
    private Object lockFor(String indexName) {
        return indexLocks.computeIfAbsent(indexName, key -> new Object());
    }

    /** Builds, probes and caches the model described by {@code modelEntity}; null if unsupported or misconfigured. */
    private EmbeddingModel initConfiguredModel(String indexName, ModelEntity modelEntity) {
        Integer platform = modelEntity.getPlatform();
        if (platform == null) {
            logger.warn("Embedding model for index [{}] has no platform set; no model created", indexName);
            return null;
        }
        switch (platform) {
            case PLATFORM_OLLAMA:
                return initOllamaModel(indexName, modelEntity);
            case PLATFORM_QIANFAN:
            case PLATFORM_BAILIAN:
                return initAppConfigModel(indexName, modelEntity, platform);
            default:
                logger.warn("Unsupported embedding platform [{}] for index [{}]; no model created", platform, indexName);
                return null;
        }
    }

    /** Builds the Ollama model; exceptions propagate to the caller and nothing is cached. */
    private EmbeddingModel initOllamaModel(String indexName, ModelEntity modelEntity) {
        logger.info("Initialising Ollama embedding model [{}] at [{}] for index [{}]",
                modelEntity.getModelName(), modelEntity.getBaseUrl(), indexName);
        EmbeddingModel embeddingModel = OllamaEmbeddingModel.builder()
                .baseUrl(modelEntity.getBaseUrl())
                .modelName(modelEntity.getModelName())
                .timeout(Duration.ofMinutes(5))
                .build();
        return probeAndCache(indexName, embeddingModel);
    }

    /** Builds a Qianfan or Bailian model from the JSON app config; logs and returns null on failure. */
    private EmbeddingModel initAppConfigModel(String indexName, ModelEntity modelEntity, int platform) {
        String appConfig = modelEntity.getAppConfig();
        if (StringUtils.isBlank(appConfig)) {
            logger.warn("Embedding platform [{}] for index [{}] has no app config; no model created", platform, indexName);
            return null;
        }
        try {
            JSONObject appConfigJson = JSONObject.parseObject(appConfig);
            String apiKey = appConfigJson.getString("apiKey");
            EmbeddingModel embeddingModel;
            if (platform == PLATFORM_QIANFAN) {
                // EMBEDDING_V1 :  384
                embeddingModel = QianfanEmbeddingModel.builder()
                        .modelName(QianfanEmbeddingModelNameEnum.EMBEDDING_V1.getModelName())
                        .apiKey(apiKey)
                        .secretKey(appConfigJson.getString("secretKey"))
                        .build();
            } else {
                embeddingModel = OpenAiEmbeddingModel.builder()
                        .baseUrl(BAILIAN_COMPATIBLE_BASE_URL)
                        .modelName(BAILIAN_EMBEDDING_MODEL_NAME)
                        .apiKey(apiKey)
                        .build();
            }
            return probeAndCache(indexName, embeddingModel);
        } catch (RuntimeException e) {
            // Best effort, as before; the config itself is not logged because it holds credentials.
            logger.error("Failed to initialise embedding model on platform [{}] for index [{}]", platform, indexName, e);
            return null;
        }
    }

    /** Embeds a probe text to learn the vector dimension, then caches model and dimension together. */
    private EmbeddingModel probeAndCache(String indexName, EmbeddingModel embeddingModel) {
        Response<Embedding> probe = embeddingModel.embed(DIMENSION_PROBE_TEXT);
        cacheModel(indexName, embeddingModel, probe.content().dimension());
        return embeddingModel;
    }

    /** Stores the dimension before the model, so a visible cached model always has a known dimension. */
    private void cacheModel(String indexName, EmbeddingModel embeddingModel, int dimension) {
        dimensionMap.put(indexName, dimension);
        embeddingModelMap.put(indexName, embeddingModel);
    }

    /**
     * Splits {@code searchText} into words with IK.
     *
     * <p>Returns whatever was collected before an I/O error, if one occurs. The knowledge-base auth
     * header is set before the analyzer is created. NOTE(review): presumably the IK dictionary
     * extension reads it, which would explain creating the analyzer per call; confirm.
     */
    private List<String> tokenize(String indexName, String searchText) {
        List<String> words = new ArrayList<>();
        try {
            Map<String, String> authHeader = new HashMap<>();
            authHeader.put(AuthHeaderUtils.KNOWLEDGEBASECODE, indexName);
            AuthHeaderUtils.setAuthHeader(authHeader);
            try (Analyzer analyzer = new IKAnalyzer(true);
                 TokenStream tokenStream = analyzer.tokenStream("content", new StringReader(searchText))) {
                CharTermAttribute termAttr = tokenStream.addAttribute(CharTermAttribute.class);
                tokenStream.reset();
                while (tokenStream.incrementToken()) {
                    words.add(termAttr.toString());
                }
                tokenStream.end();
            }
        } catch (IOException e) {
            logger.warn("Tokenising search text failed for index [{}]; using {} word(s) collected so far",
                    indexName, words.size(), e);
        }
        return words;
    }

    /** True if {@code text} is non-blank and contains at least one of {@code words}. */
    private static boolean containsAnyWord(String text, List<String> words) {
        if (StringUtils.isBlank(text)) {
            return false;
        }
        for (String word : words) {
            if (text.contains(word)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Builds the query that resolves a vector hit to its knowledge point.
     *
     * <p>Ids containing "tag" carry the point id as the second "-"-separated segment and are matched
     * by id. Other ids are matched against the uuid column.
     *
     * @return the query, or null when the embedding id is null or a malformed tag id
     */
    private QueryWrapper<KnowledgeEmbeddingPointsEntity> buildPointsQuery(String embeddingId) {
        if (embeddingId == null) {
            return null;
        }
        QueryWrapper<KnowledgeEmbeddingPointsEntity> pointsWrapper = new QueryWrapper<>();
        if (embeddingId.contains("tag")) {
            String[] parts = embeddingId.split("-");
            if (parts.length < 2) {
                logger.warn("Skipping malformed tag embedding id [{}]", embeddingId);
                return null;
            }
            // Exact match: LIKE '%id%' on the primary key could also match other ids containing it.
            pointsWrapper.eq("id", parts[1]);
        } else {
            pointsWrapper.like("uuid", embeddingId);
        }
        return pointsWrapper;
    }
}
