/*
 * The contents of this file are subject to the terms
 * of the Common Development and Distribution License
 * (the License).  You may not use this file except in
 * compliance with the License.
 *
 * You can obtain a copy of the license at
 * https://glassfish.dev.java.net/public/CDDLv1.0.html or
 * glassfish/bootstrap/legal/CDDLv1.0.txt.
 * See the License for the specific language governing
 * permissions and limitations under the License.
 *
 * When distributing Covered Code, include this CDDL
 * Header Notice in each file and include the License file
 * at glassfish/bootstrap/legal/CDDLv1.0.txt.
 * If applicable, add the following below the CDDL Header,
 * with the fields enclosed by brackets [] replaced by
 * you own identifying information:
 * "Portions Copyrighted [year] [name of copyright owner]"
 *
 * Copyright (c) Ericsson AB, 2004-2007. All rights reserved.
 */
package org.jvnet.glassfish.comms.clb.core.sip;

import com.ericsson.ssa.sip.Header;
import com.ericsson.ssa.sip.MultiLineHeader;
import com.ericsson.ssa.sip.SipServletMessageImpl;
import com.ericsson.ssa.sip.SipServletResponseImpl;
import com.ericsson.ssa.sip.ViaImpl;
import com.ericsson.ssa.sip.dns.SipTransports;
import com.ericsson.ssa.sip.dns.TargetTuple;

import org.jvnet.glassfish.comms.clb.core.CLBConstants;
import org.jvnet.glassfish.comms.clb.core.ConsistentHashRequest;
import org.jvnet.glassfish.comms.clb.core.EndPoint;
import org.jvnet.glassfish.comms.clb.core.Router;
import org.jvnet.glassfish.comms.clb.core.ServerInstance;
import org.jvnet.glassfish.comms.clb.core.common.chr.StickyHashKeyExtractor;
import org.jvnet.glassfish.comms.clb.core.util.LoadbalancerUtil;
import org.jvnet.glassfish.comms.util.LogUtil;

import java.io.ByteArrayInputStream;
import java.io.IOException;

import java.net.InetAddress;
import java.net.UnknownHostException;

import java.security.cert.CertificateEncodingException;
import java.security.cert.CertificateException;
import java.security.cert.CertificateFactory;
import java.security.cert.X509Certificate;

import java.util.ArrayList;
import java.util.List;
import java.util.ListIterator;
import java.util.StringTokenizer;
import java.util.logging.Level;

import javax.servlet.sip.SipServletRequest;


/**
 * This class implements the SIP load balancer routing logic. It implements both
 * the front-end logic (incoming requests and responses) and back-end logic
 * (both incoming and outgoing requests and responses)
 */
class SipLoadBalancerIncomingHandler {
    private static LogUtil logger = new LogUtil(LogUtil.CLB_LOG_DOMAIN);

    /** The name of the attribute for X.509 certificate on responses. */
    public static final String RESPONSE_CERTIFICATE_ATTR = "javax.servlet.response.X509Certificate";

    /** The name of the attribute for X.509 certificate on requests. */
    public static final String REQUEST_CERTIFICATE_ATTR = "javax.servlet.request.X509Certificate";

    /** The URI parameter name for the route to the Back-End. */
    public static final String BE_ROUTE_PARAM = "beroute";

    /** The request attribute name for the encoded connection ID. */
    public static final String CONNID_ATTR = SipLoadBalancerIncomingHandler.class.getName() +
        ".CONN_ID";

    /** The URI parameter name for the encoded connection ID. */
    public static final String CONNID_PARAM = "connid";

    /**
     * The request attribute name for the flag indicating that the Via was set
     * by the load balancer.
     */
    public static final String FE_LB_PARAM = "felb";
    private Socket localSipTcpSocket; // Socket for traffic between F-E to B-E and responses

    /**
     * Creates an instance and associates it with the specified hash key
     * extractor and server instance lookup.
     *
     */
    public SipLoadBalancerIncomingHandler(Socket localSipTcpSocket) {
        this.localSipTcpSocket = localSipTcpSocket;
    }

    /**
     * Handle a request that as been received by this instance. It shall either
     * be served by this instance, or be proxied to another instance.
     *
     * @param req the request; the request may be modified (headers added or
     *                changed)
     * @param serverSelector the server selector
     * @return a possible new connection; null means continue on this server
     * @throws SipRoutingException thrown in case request was malformed
     */
    public Connection handleIncomingRequest(ConsistentHashRequest req,
        Router router) throws SipRoutingException {
        if (logger.isLoggable(Level.FINE)) {
            logger.logMsg(Level.FINE,
                "Handle incoming request: " + req.getSipRequest());
        }

        ServerInstance serverInstance = router.selectInstance(req);
        String hashkey = req.getHashKey();

        SipServletMessageImpl request = (SipServletMessageImpl) req.getSipRequest();
        request.setBeKey(hashkey);

        if (logger.isLoggable(Level.FINE)) {
            logger.logMsg(Level.FINE,
                "Hash key: " + hashkey + "; server: " +
                serverInstance.getName());
        }

        if (serverInstance == null) {
            throw new SipRoutingException("Could not find a server");
        }

        try {
            if (logger.isLoggable(Level.FINE)) {
                try {
                    logger.logMsg(Level.FINE,
                        "Check if this is the right server, local: " +
                        getLocalSipTcpSocket() +
                        "), selected from consistent hash: " +
                        getServerAddress(serverInstance.getEndPoint(
                                request.getTransport())) + ")");
                } catch (Exception e) {
                    // Ignore
                }
            }

            if (serverInstance.getEndPoint(CLBConstants.SIP_PROTOCOL).isLocal()) {
                // Continue on this instance
                saveIncomingConnection(request);
                decodeClientCert(request);
                decodeRemote(request);

                if (logger.isLoggable(Level.FINE)) {
                    logger.logMsg(Level.FINE,
                        "Continue on this instance; request: " + request);
                }

                return null;
            } else {
                // Proxy to the other instance
                Socket serverAddress = getServerAddress(serverInstance.getEndPoint(
                            CLBConstants.SIP_PROTOCOL));
                pushVia(request, getLocalSipTcpSocket());
                encodeRemote(request);
                encodeClientCert(request);
                encodeBeKeyHeader(request, hashkey);

                if (logger.isLoggable(Level.FINE)) {
                    logger.logMsg(Level.FINE,
                        "Proxy to other instance (" + serverAddress +
                        "); request: " + request);
                }

                saveIncomingConnection(request);

                return new Connection(SipTransports.TCP_PROT, null,
                    serverAddress);
            }
        } catch (UnknownHostException e) {
            throw new SipRoutingException("Could not extract host: " + e);
        }
    }

    private void saveIncomingConnection(SipServletMessageImpl msg) {
        msg.setAttribute(CONNID_ATTR,
            new Connection(msg.getRemote().getProtocol(),
                new Socket(msg.getLocal().getHostName(),
                    msg.getLocal().getPort()),
                new Socket(msg.getRemote().getIP(), msg.getRemote().getPort())));
    }

    private Socket getServerAddress(EndPoint serverAddress)
        throws UnknownHostException {
        return new Socket(LoadbalancerUtil.getNumericIpAddress(
                InetAddress.getByName(serverAddress.getHost())),
            serverAddress.getPort());
    }

    /**
     * Handle a response that as been received by this instance. It shall either
     * be served by this instance, or be re-routed to another instance.
     *
     * @param response the response; the response may be modified (headers added
     *                or changed)
     * @return a possible list of new connection which shall be tried in the
     *         order specified in the list; if null continue on the current
     *         server.
     * @throws SipRoutingException thrown in case response was corrupt, the
     *                 caller shall do no further processing but just drop the
     *                 response.
     */
    public Connection handleIncomingResponse(SipServletResponseImpl response)
        throws SipRoutingException {
        if (logger.isLoggable(Level.FINE)) {
            logger.logMsg(Level.FINE, "Handle incoming response: " + response);
        }

        // Get topmost header
        Header viaHeader = response.getRawHeader(Header.VIA);

        if (viaHeader != null) {
            ListIterator<String> viaIterator;
            viaIterator = viaHeader.getValues();

            if (viaIterator.hasNext()) {
                ViaImpl via = new ViaImpl(viaIterator.next());

                // Extract connection
                Connection connection = getConnection(via);

                // Extract 'felb' flag
                String frontEndLb = via.getParameter(FE_LB_PARAM);

                if (frontEndLb != null) {
                    // This response was received from the back-end,
                    // pop Via, extract connection and forward it.
                    viaHeader.setReadOnly(false);
                    viaIterator.remove();
                    viaHeader.setReadOnly(true);

                    if (logger.isLoggable(Level.FINE)) {
                        logger.logMsg(Level.FINE,
                            "Response from back-end forward to client (" +
                            connection + "); response: " + response);
                    }

                    saveIncomingConnection(response);

                    return connection;
                } else {
                    // This response was received from an external party and
                    // shall possibly be forwarded to a back-end
                    Socket serverAddress = getServerAddressForResponse(response);

                    if (serverAddress != null) {
                        return handleResponseFromExternalParty(response,
                            connection, serverAddress);
                    } else {
                        // Should never happen, drop response
                        throw new SipRoutingException(
                            "No BERoute on response, shall never happen; drop the response!");
                    }
                }
            }
        }

        // No Via? The response must be corrupt, drop it!
        throw new SipRoutingException(
            "No Via on response, shall never happen; drop the response!");
    }

    // ---------------- Internal methods --------------------
    private Socket getServerAddressForResponse(SipServletResponseImpl response)
        throws SipRoutingException {
        // Get topmost Via
        Header viaHeader = response.getRawHeader(Header.VIA);

        if (viaHeader != null) {
            ListIterator<String> viaIterator;
            viaIterator = viaHeader.getValues();

            if (viaIterator.hasNext()) {
                ViaImpl via = new ViaImpl(viaIterator.next());

                return SipLoadBalancerBackend.decodeBeRoute(via.getParameter(
                        BE_ROUTE_PARAM));
            }
        }

        return null;
    }

    private Connection getConnection(ViaImpl via) {
        String connectionIdentity = via.getParameter(CONNID_PARAM);
        Connection connection = null;

        try {
            connection = (connectionIdentity != null)
                ? Connection.getFromEncoded(connectionIdentity) : null;
        } catch (ConnectionParseException e) {
            // The connection could not be decoded
            logger.warning(
                "Error when parsing CONNID parameter of topmost Via: " + e);
        } catch (IOException e) {
            // The connection could not be decoded
            logger.warning(
                "Error when parsing CONNID parameter of topmost Via: " + e);
        }

        return connection;
    }

    private Connection handleResponseFromExternalParty(
        SipServletResponseImpl response, Connection connection,
        Socket serverAddress) throws SipRoutingException {
        // Note, it does not matter that the response actually arrived on another socket.
        // We just want to match it against the socket encoded in 'beroute' to see that 
        // this is the instance that sent the request.
        if (isEqual(getLocalSipTcpSocket(), serverAddress)) {
            // On the right server, continue on it
            if (logger.isLoggable(Level.FINE)) {
                logger.logMsg(Level.FINE,
                    "Response from outside but we are on the right server continue");
            }

            response.setAttribute(CONNID_ATTR, connection); // This is the connection extracted from the Via on the incoming response
            decodeRemote(response);
            decodeClientCert(response);

            return null;
        } else {
            // The response shall be re-routed to another server
            if (logger.isLoggable(Level.FINE)) {
                logger.logMsg(Level.FINE,
                    "Response from outside but we are on wrong server (" +
                    getLocalSipTcpSocket() + ") re-route to: " + serverAddress);
            }

            encodeRemote(response);
            encodeClientCert(response);

            return new Connection(SipTransports.TCP_PROT, null, serverAddress);
        }
    }

    private void pushVia(SipServletMessageImpl request, Socket localAddress) {
        // This Via is used to route an incoming response if the request is proxied and the 
        // response goes via UDP or if the TCP connection via which the request was sent 
        // is broken.
        Header viaHeader = request.getRawHeader(Header.VIA);

        if (viaHeader == null) {
            viaHeader = new MultiLineHeader(Header.VIA, true);
            request.addHeader(viaHeader);
        }

        ViaImpl via = new ViaImpl("SIP", SipTransports.TCP_PROT.name(),
                localAddress.getHostName(), localAddress.getPort());
        // Set flag indicating that this Via was pushed by the front-end
        via.setParameter(FE_LB_PARAM, null);
        // Encode information about the connection between the client and the
        // front-end
        via.setParameter(CONNID_PARAM,
            new Connection(request.getRemote().getProtocol(),
                new Socket(request.getLocal().getHostName(),
                    request.getLocal().getPort()),
                new Socket(request.getRemote().getIP(),
                    request.getRemote().getPort())).getEncodedValue());
        viaHeader.setValue(via.toString(), true);
    }

    private void encodeClientCert(SipServletMessageImpl message)
        throws SipRoutingException {
        // Encode a possible client certificate
        X509Certificate[] clientCerts = message.getCertificate();

        if (clientCerts != null) {
            for (X509Certificate clientCert : clientCerts) {
                try {
                    message.addHeader(Header.PROXY_AUTH_CERT_HEADER,
                        LoadbalancerUtil.encodeParameter(
                            clientCert.getEncoded(), true));
                } catch (CertificateEncodingException e) {
                    throw new SipRoutingException("Could not encode client certificate.",
                        e);
                }
            }
        }
    }

    private void decodeClientCert(SipServletMessageImpl message)
        throws SipRoutingException {
        ListIterator<String> headers = message.getHeaders(Header.PROXY_AUTH_CERT_HEADER);
        List<X509Certificate> certs = new ArrayList<X509Certificate>();

        while (headers.hasNext()) {
            String header = headers.next();
            message.removeHeader(Header.PROXY_AUTH_CERT_HEADER);

            try {
                CertificateFactory cf = CertificateFactory.getInstance("X.509");
                byte[] clientCertBytes = LoadbalancerUtil.decodeParameterToBytes(header,
                        true);
                ByteArrayInputStream bais = new ByteArrayInputStream(clientCertBytes);
                certs.add((X509Certificate) cf.generateCertificate(bais));
            } catch (CertificateException e) {
                throw new SipRoutingException("Could not decode " +
                    Header.PROXY_AUTH_CERT_HEADER, e);
            } catch (IOException e) {
                throw new SipRoutingException("Could not decode " +
                    Header.PROXY_AUTH_CERT_HEADER, e);
            }
        }

        if (certs.size() > 0) {
            if (message instanceof SipServletRequest) {
                message.setAttribute(REQUEST_CERTIFICATE_ATTR,
                    certs.toArray(new X509Certificate[certs.size()]));
            } else {
                message.setAttribute(RESPONSE_CERTIFICATE_ATTR,
                    certs.toArray(new X509Certificate[certs.size()]));
            }
        }
    }

    private void encodeRemote(SipServletMessageImpl msg) {
        msg.setHeader(Header.PROXY_REMOTE_HEADER,
            msg.getRemote().getProtocol().name() + ":" +
            msg.getRemote().getIP() + ":" + msg.getRemote().getPort());
    }

    private void decodeRemote(SipServletMessageImpl msg)
        throws SipRoutingException {
        String remote = msg.getHeader(Header.PROXY_REMOTE_HEADER);

        if (remote != null) {
            msg.removeHeader(Header.PROXY_REMOTE_HEADER);

            StringTokenizer st = new StringTokenizer(remote, ":", false);
            String[] addrArr = new String[3];
            int i = 0;

            while (st.hasMoreTokens() && (i < addrArr.length)) {
                addrArr[i++] = st.nextToken();
            }

            // Now addrArr[0] = transport, addrArr[1]=IP address and
            // addrArr[2]=port
            if (i == addrArr.length) {
                try {
                    msg.setRemote(new TargetTuple(SipTransports.getTransport(
                                addrArr[0]), addrArr[1],
                            Integer.parseInt(addrArr[2])));
                } catch (NumberFormatException e) {
                    throw new SipRoutingException("Malformed " +
                        Header.PROXY_REMOTE_HEADER, e);
                } catch (Exception e) {
                    throw new SipRoutingException("Malformed " +
                        Header.PROXY_REMOTE_HEADER, e);
                }
            }
        }
    }

    private boolean isEqual(Socket s1, Socket s2) {
        return s1.getHostName().equals(s2.getHostName());
    }

    private void encodeBeKeyHeader(SipServletMessageImpl msg, String bekey) {
        msg.setHeader(StickyHashKeyExtractor.PROXY_BEKEY_HEADER, bekey);
    }

    private Socket getLocalSipTcpSocket() {
        return localSipTcpSocket;
    }
}
