package com.yonyou.uap.tenant.web.filter.security;

import java.io.IOException;
import java.util.Map;

import javax.servlet.Filter;
import javax.servlet.FilterChain;
import javax.servlet.FilterConfig;
import javax.servlet.ServletException;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import javax.sql.DataSource;

import org.apache.commons.collections.MapUtils;
import org.apache.commons.lang.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.context.ApplicationContext;
import org.springframework.jdbc.core.JdbcTemplate;

import com.yonyou.iuap.security.rest.common.Credential;
import com.yonyou.iuap.security.rest.common.SignProp;
import com.yonyou.iuap.security.rest.exception.UAPSecurityException;
import com.yonyou.iuap.security.rest.factory.ServerVerifyFactory;
import com.yonyou.iuap.security.rest.utils.PostParamsHelper;
import com.yonyou.iuap.security.rest.utils.SignPropGenerator;

import uap.web.cache.CacheManager;
import uap.web.core.ContextHolder;

public class SSORestFulSecurityFilter implements Filter {
	private static final Logger logger = LoggerFactory.getLogger(SSORestFulSecurityFilter.class);
	private static final long DEFAULT_EXPIRED = 300000L;

	public void init(FilterConfig filterConfig) throws ServletException {
	}

	public void doFilter(ServletRequest request, ServletResponse response, FilterChain filterChain)
			throws IOException, ServletException {
		if ((request instanceof HttpServletRequest)) {
			HttpServletRequest httpReq = (HttpServletRequest) request;
			String sign = httpReq.getHeader("sign");
			String appid = httpReq.getHeader("appId");

			if (StringUtils.isEmpty(sign)) {
				sign = httpReq.getParameter("sign");
			}
			if (StringUtils.isEmpty(appid)) {
				appid = httpReq.getParameter("appId");
			}

			if ((sign == null) || (appid == null)) {
				HttpServletResponse rp = (HttpServletResponse) response;
				rp.setStatus(400);
				rp.addHeader("restful validate error",
						" 400 , Method Not Allowed,please check restful called paramters ! ");

				rp.getWriter().write("Method Not Allowed,please check restful called paramters !");

				return;
			}

			String ts = httpReq.getParameter("ts");
			if (StringUtils.isNumeric(ts)) {
				long sendTs = Long.parseLong(ts);
				if (System.currentTimeMillis() - sendTs > DEFAULT_EXPIRED) {
					HttpServletResponse rp = (HttpServletResponse) response;
					rp.setStatus(400);
					rp.addHeader("restful validate error", " 400 , 请求超时");
					logger.error("restful 签名超超时");
					rp.getWriter().write("restful validate over time");
					return;
				}
			}

			boolean passed = validatorURL(httpReq);
			if (!passed) {
				HttpServletResponse rp = (HttpServletResponse) response;
				rp.setStatus(400);
				rp.addHeader("restful validate error",
						" 400 , Method Not Allowed,please check restful called paramters ! ");

				rp.getWriter().write("Method Not Allowed,please check restful called paramters !");

				return;
			}
		}
		filterChain.doFilter(request, response);
	}

	public void destroy() {
	}

	public Boolean validatorURL(HttpServletRequest httpReq) {
		String sign = httpReq.getHeader("sign");

		if (StringUtils.isEmpty(sign)) {
			sign = httpReq.getParameter("sign");
		}

		String appid = httpReq.getParameter("appId");
		sign = StringUtils.isEmpty(sign) ? httpReq.getHeader("sign") : sign;

		appid = StringUtils.isEmpty(appid) ? httpReq.getHeader("appId") : appid;

		if ((StringUtils.isNotEmpty(sign)) && (StringUtils.isNotEmpty(appid))) {
			try {
				String url = httpReq.getRequestURL().toString();
				if (StringUtils.isNotBlank(httpReq.getQueryString())) {
					url = url + "?" + httpReq.getQueryString();
				}
				SignProp prop = SignPropGenerator.genSignProp(url);

				if (httpReq.getMethod().endsWith("POST")) {
					prop.setPostParamsStr(PostParamsHelper.genParamsStrByReqeust(httpReq));

					prop.setContentLength(httpReq.getContentLength());
				}

				RestServerVirifyFactory factory = new RestServerVirifyFactory();
				return factory.getVerifier(appid).verify(sign, prop);
			} catch (UAPSecurityException e) {
				logger.error("restful摘要计算失败!", e);
			} catch (Exception e) {
				logger.error("restful摘要计算失败!", e);
			}
		}
		return Boolean.FALSE;
	}

	class RestServerVirifyFactory extends ServerVerifyFactory {
		private Logger log = LoggerFactory.getLogger(ServerVerifyFactory.class);

		RestServerVirifyFactory() {
		}

		protected Credential genCredential(String appId) {
			try {
				ApplicationContext context = ContextHolder.getContext();
				CacheManager cacheManager = (CacheManager) context.getBean("cacheManager");
				Credential redis_credential = (Credential) cacheManager.get("temp_sso_trust_" + appId);
				if (redis_credential != null) {
					return redis_credential;
				}

				context.getBean("dataSource");
				JdbcTemplate jdbcTemplate = new JdbcTemplate((DataSource) context.getBean("dataSource"));
				String sql = " select * from pub_security_client where client_id = ? ";
				Map<String,Object> map = jdbcTemplate.queryForMap(sql, new Object[] { appId });

				String clientId = MapUtils.getString(map, "client_id");
				String clientKey = MapUtils.getString(map, "client_key");
				String expiredTs = MapUtils.getString(map, "expired_ts");
				Credential credential = new Credential(clientId, clientKey, expiredTs);
				cacheManager.setex("temp_sso_trust_" + appId, credential, 3600);

				return credential;
			} catch (Exception e) {
				this.log.error(e.getMessage(), e);
			}
			return null;
		}
	}
}