/*
 * Copyright 2006-2007 Sxip Identity Corporation
 */

package jp.sourceforge.tsukuyomi.openid.message;

import java.io.UnsupportedEncodingException;
import java.net.URLEncoder;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;

import jp.sourceforge.tsukuyomi.openid.message.ax.AxMessage;
import jp.sourceforge.tsukuyomi.openid.message.pape.PapeMessage;
import jp.sourceforge.tsukuyomi.openid.message.sreg.SRegMessage;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

/**
 * Base class for OpenID protocol messages.
 * <p>
 * A message wraps an ordered {@link ParameterList} of key/value pairs. The
 * class also tracks which protocol extensions are present in the message
 * (extension type URI -> alias). It can encode the message in key-value form
 * or URL-encoded form. It keeps a static registry of extension implementation
 * classes, which {@link #getExtension(String)} uses to materialize extensions.
 * 
 * @author Marius Scurtescu, Johnny Bufu
 */
public class Message {
	private static final Log LOG = LogFactory.getLog(Message.class);
	private static final boolean DEBUG = LOG.isDebugEnabled();

	// message constants
	public static final String MODE_IDRES = "id_res";
	public static final String MODE_CANCEL = "cancel";
	public static final String MODE_SETUP_NEEDED = "setup_needed";
	public static final String OPENID2_NS = "http://specs.openid.net/auth/2.0";

	/** Prefix of the parameters declaring an extension namespace. */
	private static final String NS_PREFIX = "openid.ns.";

	private ParameterList params;

	// numeric suffix used when generating "extN" aliases in addExtension()
	private int extCounter;

	// extension type URI -> extension alias : extensions present in the message
	private Map<String, String> extAliases;

	// extension type URI -> MessageExtension : extracted extension objects,
	// filled lazily by getExtension()
	private Map<String, MessageExtension> extensions;

	// the URL where this message should be sent, where applicable
	// should remain null for received messages (created from param lists)
	protected String destinationUrl;

	// type URI -> implementation class (instantiated reflectively as a
	// MessageExtensionFactory) : supported extensions
	private static Map<String, Class<?>> extensionFactories =
		new HashMap<String, Class<?>>();

	static {
		extensionFactories.put(AxMessage.OPENID_NS_AX, AxMessage.class);
		extensionFactories.put(SRegMessage.OPENID_NS_SREG, SRegMessage.class);
		extensionFactories.put(PapeMessage.OPENID_NS_PAPE, PapeMessage.class);
	}

	/**
	 * Creates an empty message with no parameters and no extensions.
	 */
	protected Message() {
		params = new ParameterList();
		extCounter = 0;
		extAliases = new HashMap<String, String>();
		extensions = new HashMap<String, MessageExtension>();
	}

	/**
	 * Creates a message backed by the given parameter list.
	 * <p>
	 * Each "openid.ns.&lt;alias&gt;" parameter registers its value (the
	 * extension type URI) under that alias.
	 * 
	 * @param params
	 *            the parameters of the message; used directly, not copied
	 */
	protected Message(ParameterList params) {
		this();
		this.params = params;

		// simple registration is a special case; we support only:
		// SREG1.0 (no namespace, "sreg" alias hardcoded) in OpenID1 messages
		// SREG1.1 (namespace, any possible alias) in OpenID2 messages
		boolean hasSReg10 = false;

		for (Parameter parameter : this.params.getParameters()) {
			String key = parameter.getKey();

			// "openid.ns.<alias>" = <extension type URI>
			if (key.startsWith(NS_PREFIX) && key.length() > NS_PREFIX.length()) {
				extAliases.put(parameter.getValue(), key.substring(NS_PREFIX
					.length()));
			}

			if (key.startsWith("openid.sreg.")) {
				hasSReg10 = true;
			}
		}

		// only do the workaround for OpenID1 messages
		if (hasSReg10 && !hasParameter("openid.ns")) {
			extAliases.put(SRegMessage.OPENID_NS_SREG, "sreg");
		}

		extCounter = extAliases.size();
	}

	/**
	 * Creates a new, empty message and validates it.
	 * 
	 * @return the new message
	 * @throws MessageException
	 *             if {@link #isValid()} returns false
	 */
	public static Message createMessage() throws MessageException {
		Message message = new Message();

		if (!message.isValid()) {
			throw new MessageException(
				"Invalid set of parameters for the requested message type");
		}

		if (DEBUG) {
			LOG.debug("Created message:\n" + message.keyValueFormEncoding());
		}

		return message;
	}

	/**
	 * Creates a message from a parameter list and validates it.
	 * 
	 * @param params
	 *            the message parameters
	 * @return the new message
	 * @throws MessageException
	 *             if {@link #isValid()} returns false
	 */
	public static Message createMessage(ParameterList params)
			throws MessageException {
		Message message = new Message(params);

		if (!message.isValid()) {
			throw new MessageException(
				"Invalid set of parameters for the requested message type");
		}

		if (DEBUG) {
			LOG.debug("Created message from parameter list:\n"
				+ message.keyValueFormEncoding());
		}

		return message;
	}

	protected Parameter getParameter(String name) {
		return params.getParameter(name);
	}

	public String getParameterValue(String name) {
		return params.getParameterValue(name);
	}

	public boolean hasParameter(String name) {
		return params.hasParameter(name);
	}

	/**
	 * Sets (adds or replaces, as defined by ParameterList.set) a parameter.
	 */
	protected void set(String name, String value) {
		params.set(new Parameter(name, value));
	}

	protected List<Parameter> getParameters() {
		return params.getParameters();
	}

	/**
	 * Returns a new map of parameter key -> value, in parameter-list order.
	 * Changes to the returned map do not affect this message.
	 */
	public Map<String, String> getParameterMap() {
		Map<String, String> map = new LinkedHashMap<String, String>();

		for (Parameter p : params.getParameters()) {
			map.put(p.getKey(), p.getValue());
		}

		return map;
	}

	/**
	 * Checks that every parameter is individually valid and that all fields
	 * listed by {@link #getRequiredFields()} are present.
	 * 
	 * @return true if the message is valid; problems are logged as warnings
	 */
	public boolean isValid() {
		List<String> requiredFields = getRequiredFields();

		for (Parameter param : params.getParameters()) {
			if (!param.isValid()) {
				LOG.warn("Invalid parameter: " + param);
				return false;
			}
		}

		if (requiredFields == null) {
			return true;
		}

		for (String required : requiredFields) {
			if (!hasParameter(required)) {
				LOG.warn("Required parameter missing: " + required);
				return false;
			}
		}

		return true;
	}

	/**
	 * Returns the names of parameters that must be present for the message
	 * to be valid, or null if there are none. Intended to be overridden.
	 */
	public List<String> getRequiredFields() {
		return null;
	}

	/**
	 * Encodes the message as one "key:value\n" line per parameter.
	 */
	public String keyValueFormEncoding() {
		StringBuilder allParams = new StringBuilder();

		for (Parameter parameter : params.getParameters()) {
			allParams.append(parameter.getKey());
			allParams.append(':');
			allParams.append(parameter.getValue());
			allParams.append('\n');
		}

		return allParams.toString();
	}

	/**
	 * Encodes the message as UTF-8 URL-encoded "key=value" pairs joined by
	 * '&amp;'. Keys not already starting with "openid." get that prefix.
	 * 
	 * @return the encoded parameters (empty string if there are none)
	 */
	public String wwwFormEncoding() {
		StringBuilder allParams = new StringBuilder();

		try {
			for (Parameter parameter : params.getParameters()) {
				// All of the keys in the request message MUST be prefixed
				// with "openid."
				if (!parameter.getKey().startsWith("openid.")) {
					allParams.append("openid.");
				}

				allParams.append(URLEncoder.encode(parameter.getKey(), "UTF-8"));
				allParams.append('=');
				allParams.append(URLEncoder.encode(parameter.getValue(), "UTF-8"));
				allParams.append('&');
			}
		} catch (UnsupportedEncodingException e) {
			// every Java platform is required to support UTF-8
			throw new IllegalStateException("UTF-8 encoding not supported", e);
		}

		// remove the trailing '&'
		if (allParams.length() > 0) {
			allParams.deleteCharAt(allParams.length() - 1);
		}

		return allParams.toString();
	}

	/**
	 * Gets the URL where the message should be sent, where applicable.
	 * 
	 * @param httpGet
	 *            If true, the wwwFormEncoding() is appended to the destination
	 *            URL; the return value should be used with a GET-redirect. If
	 *            false, the verbatim destination URL is returned, which should
	 *            be used with a FORM POST redirect.
	 * @return the destination URL, with or without the encoded parameters
	 * @throws IllegalStateException
	 *             if no destination URL is set (e.g. a received message)
	 * @see #wwwFormEncoding()
	 */
	public String getDestinationUrl(boolean httpGet) {
		if (destinationUrl == null) {
			throw new IllegalStateException("Destination URL not set; "
				+ "is this a received message?");
		}

		if (!httpGet) {
			// should send the keyValueFormEncoding in POST data
			return destinationUrl;
		}

		// append wwwFormEncoding to the destination URL, extending an
		// existing query string if there is one
		boolean hasQuery = destinationUrl.indexOf('?') >= 0;
		String initialChar = hasQuery ? "&" : "?";

		return destinationUrl + initialChar + wwwFormEncoding();
	}

	// ------------ extensions implementation ------------

	/**
	 * Adds a new extension factory.
	 * 
	 * @param clazz
	 *            The implementation class for the extension factory, must
	 *            implement {@link MessageExtensionFactory}.
	 * @throws MessageException
	 *             if the class cannot be instantiated
	 */
	public static void addExtensionFactory(
			Class<? extends MessageExtensionFactory> clazz)
			throws MessageException {
		MessageExtensionFactory extensionFactory;
		try {
			extensionFactory = clazz.newInstance();
		} catch (InstantiationException e) {
			throw new MessageException(
				"Cannot instantiate message extension factory class: "
					+ clazz.getName());
		} catch (IllegalAccessException e) {
			throw new MessageException(
				"Cannot instantiate message extension factory class: "
					+ clazz.getName());
		}

		if (DEBUG) {
			LOG.debug("Adding extension factory for "
				+ extensionFactory.getTypeUri());
		}

		extensionFactories.put(extensionFactory.getTypeUri(), clazz);
	}

	/**
	 * Returns true if there is an extension factory available for extension
	 * identified by the specified Type URI, or false otherwise.
	 * 
	 * @param typeUri
	 *            The Type URI that identifies an extension.
	 */
	public static boolean hasExtensionFactory(String typeUri) {
		return extensionFactories.containsKey(typeUri);
	}

	/**
	 * Gets a new MessageExtensionFactory for the specified Type URI if an
	 * implementation is registered, or null otherwise (including when the
	 * registered class cannot be instantiated; that failure is logged).
	 * 
	 * @param typeUri
	 *            The Type URI that identifies an extension.
	 * @see MessageExtensionFactory
	 */
	public static MessageExtensionFactory getExtensionFactory(String typeUri) {
		if (!hasExtensionFactory(typeUri)) {
			return null;
		}

		try {
			Class<?> extensionClass = extensionFactories.get(typeUri);
			return (MessageExtensionFactory) extensionClass.newInstance();
		} catch (Exception e) {
			LOG.error("Error getting extension factory for " + typeUri, e);
			return null;
		}
	}

	/**
	 * Returns true if the message has parameters for the specified extension
	 * type URI.
	 * 
	 * @param typeUri
	 *            The URI that identifies the extension.
	 */
	public boolean hasExtension(String typeUri) {
		return extAliases.containsKey(typeUri);
	}

	/**
	 * Gets a read-only view of the extension Type URIs present in the message.
	 */
	public Set<String> getExtensions() {
		return Collections.unmodifiableSet(extAliases.keySet());
	}

	/**
	 * Retrieves the extension alias that will be used for the extension
	 * identified by the supplied extension type URI.
	 * <p>
	 * If the message contains no parameters for the specified extension, null
	 * will be returned.
	 * 
	 * @param extensionTypeUri
	 *            The URI that identifies the extension
	 * @return The extension alias associated with the extension specified by
	 *         the Type URI, or null
	 */
	public String getExtensionAlias(String extensionTypeUri) {
		return extAliases.get(extensionTypeUri);
	}

	/**
	 * Adds a set of extension-specific parameters to a message.
	 * <p>
	 * The parameter names must NOT contain the "openid.&lt;extension_alias&gt;"
	 * prefix; it will be generated dynamically, ensuring there are no conflicts
	 * between extensions.
	 * 
	 * @param extension
	 *            A MessageExtension containing parameters to be added to the
	 *            message
	 * @throws MessageException
	 *             if an extension with the same type URI is already present
	 */
	public void addExtension(MessageExtension extension)
			throws MessageException {
		String typeUri = extension.getTypeUri();

		if (hasExtension(typeUri)) {
			throw new MessageException("Extension already present: " + typeUri);
		}

		String alias = nextExtensionAlias();

		// use the hardcoded "sreg" alias for SREG, for seamless interoperation
		// between SREG10/OpenID1 and SREG11/OpenID2
		if (SRegMessage.OPENID_NS_SREG.equals(typeUri)) {
			alias = "sreg";
		}

		extAliases.put(typeUri, alias);

		if (DEBUG) {
			LOG.debug("Adding extension; type URI: "
				+ typeUri
				+ " alias: "
				+ alias);
		}

		set(NS_PREFIX + alias, typeUri);

		String baseKey = "openid." + alias;
		for (Parameter param : extension.getParameters().getParameters()) {
			// an empty extension key maps to the bare "openid.<alias>" key
			String paramName =
				param.getKey().length() > 0
					? baseKey + "." + param.getKey()
					: baseKey;

			set(paramName, param.getValue());
		}

		if (this instanceof AuthSuccess
			&& ((AuthSuccess) this).getSignExtensions().contains(extension)) {
			((AuthSuccess) this).buildSignedList();
		}
	}

	/**
	 * Generates the next "extN" alias, skipping any alias already mapped in
	 * this message (e.g. one declared by a received message).
	 */
	private String nextExtensionAlias() {
		String alias;
		do {
			alias = "ext" + Integer.toString(++extCounter);
		} while (extAliases.containsValue(alias));
		return alias;
	}

	/**
	 * Retrieves the parameters associated with a protocol extension, specified
	 * by the given extension type URI.
	 * <p>
	 * The "openid.ns.&lt;extension_alias&gt;" parameter is NOT included in the
	 * returned list. Also, the returned parameter names will have the
	 * "openid.&lt;extension_alias&gt;." prefix removed; the bare
	 * "openid.&lt;extension_alias&gt;" parameter is returned with an empty name.
	 * 
	 * @param extensionTypeUri
	 *            The type URI that identifies the extension
	 * @return A ParameterList with all parameters associated with the specified
	 *         extension (empty if the extension is not present)
	 */
	private ParameterList getExtensionParams(String extensionTypeUri) {
		ParameterList extensionParams = new ParameterList();

		if (!hasExtension(extensionTypeUri)) {
			return extensionParams;
		}

		String bareKey = "openid." + getExtensionAlias(extensionTypeUri);
		String prefix = bareKey + ".";

		for (Parameter param : getParameters()) {
			String key = param.getKey();
			String paramName = null;

			if (key.startsWith(prefix)) {
				paramName = key.substring(prefix.length());
			} else if (key.equals(bareKey)) {
				paramName = "";
			}

			if (paramName != null) {
				extensionParams.set(new Parameter(paramName, param.getValue()));
			}
		}

		return extensionParams;
	}

	/**
	 * Gets a MessageExtension for the specified Type URI.
	 * <p>
	 * The returned object will contain the parameters from the message
	 * belonging to the specified extension. It is created on first request
	 * and cached for later calls.
	 * 
	 * @param typeUri
	 *            The Type URI that identifies an extension.
	 * @throws MessageException
	 *             if no factory is registered for the type URI, or the
	 *             registered factory class cannot be instantiated
	 */
	public MessageExtension getExtension(String typeUri)
			throws MessageException {
		if (!extensions.containsKey(typeUri)) {
			if (!hasExtensionFactory(typeUri)) {
				throw new MessageException("Cannot instantiate extension: "
					+ typeUri);
			}

			MessageExtensionFactory extensionFactory =
				getExtensionFactory(typeUri);
			if (extensionFactory == null) {
				// registered class could not be instantiated (already logged)
				throw new MessageException("Cannot instantiate extension: "
					+ typeUri);
			}

			if (DEBUG) {
				LOG.debug("Extracting " + typeUri + " extension from message...");
			}

			// "openid.mode" may be absent; only checkid_* modes (OpenID
			// authentication requests) set the flag
			String mode = getParameterValue("openid.mode");
			boolean isCheckidMode = mode != null && mode.startsWith("checkid_");

			MessageExtension extension =
				extensionFactory.getExtension(
					getExtensionParams(typeUri),
					isCheckidMode);

			extensions.put(typeUri, extension);
		}

		return extensions.get(typeUri);
	}
}
