/* 
 * Copyright 2014 by AVM GmbH <info@avm.de>
 *
 * This software contains free software; you can redistribute it and/or modify 
 * it under the terms of the GNU General Public License ("License") as 
 * published by the Free Software Foundation  (version 3 of the License). 
 * This software is distributed in the hope that it will be useful, but 
 * WITHOUT ANY WARRANTY; without even the implied warranty of 
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the copy of the 
 * License you received along with this software for more details.
 */

package de.avm.android.security;

import java.math.BigInteger;
import java.security.KeyPairGenerator;
import java.security.KeyStore;
import java.security.SecureRandom;
import java.security.spec.AlgorithmParameterSpec;
import java.util.Calendar;
import java.util.Date;
import java.util.GregorianCalendar;
import javax.crypto.Cipher;
import javax.crypto.IllegalBlockSizeException;
import javax.crypto.KeyGenerator;
import javax.crypto.SecretKey;
import javax.crypto.spec.IvParameterSpec;
import javax.security.auth.x500.X500Principal;

import android.annotation.SuppressLint;
import android.content.Context;
import android.text.TextUtils;
import android.util.Base64;

/**
 * {@link CipherWrapper} backed by an RSA key pair held in the Android key store
 * ("AndroidKeyStore" provider) under the alias passed to the constructor.
 * <p>
 * Encryption is hybrid: every call to {@link #encrypt(String)} generates a fresh
 * random AES session key, encrypts the UTF-8 bytes of the plain text with it
 * (AES/CBC/PKCS5Padding with a random IV), and wraps the session key with the
 * public key of the key store entry (RSA/ECB/PKCS1Padding). The result is the
 * Base64 encoding of a serialized {@link Container}.
 * <p>
 * NOTE(review): the CBC ciphertext carries no MAC, so tampering is only detected
 * indirectly (padding/decoding errors); adding integrity protection would need a
 * new container version.
 */
@SuppressLint({ "InlinedApi", "NewApi" })
public class CipherWrapperAks extends CipherWrapper 
{
	/** Current (and only supported) container format version. */
	private static final byte VERSION = 1;
	/** Requested RSA modulus size in bits for newly generated key pairs. */
	private static final int AKS_KEY_LEN = 2048;
	/**
	 * Length in bytes of the RSA-wrapped session key inside a container. An RSA
	 * ciphertext is as long as the modulus, so this follows from AKS_KEY_LEN
	 * (256 bytes for RSA-2048, which is what version 1 containers store).
	 */
	private static final int WRAPPED_KEY_LEN = AKS_KEY_LEN / 8;
	/** AES session key size in bits (not to be confused with WRAPPED_KEY_LEN). */
	private static final int SESSION_KEY_LEN = 256;
	private static final String SESSION_KEY_ALGORITHM = "AES";
	private static final String CIPHER_SESSIONKEY = "RSA/ECB/PKCS1Padding";
	private static final String CIPHER_ENCRYPTION = "AES/CBC/PKCS5Padding";
	private static final String PLAIN_CHARSET = "UTF-8";
	private static final String GEN_SPEC_CLASS = "android.security.KeyPairGeneratorSpec";

	/** Key store entry whose public key wraps and whose private key unwraps session keys. */
	private final KeyStore.PrivateKeyEntry mKeyEntry;

	@Override
	public Type getType()
	{
		return Type.AKS;
	}

	/**
	 * Opens the RSA key pair stored under {@code alias}, generating it first if
	 * the key store does not contain the alias yet.
	 *
	 * @param context Android context, must not be null
	 * @param alias key store alias, must not be null or empty
	 * @throws IllegalArgumentException if {@code context} is null or {@code alias} is empty
	 * @throws IllegalStateException if the alias does not resolve to a private key entry
	 * @throws Exception on key store access or key generation failures
	 */
	protected CipherWrapperAks(Context context, String alias) throws Exception
	{
		super(context, alias);

		if (context == null)
			throw new IllegalArgumentException("Argument context must not be null.");
		if (TextUtils.isEmpty(alias))
			throw new IllegalArgumentException("Argument alias must not be null or empty.");

		final KeyStore keyStore = KeyStore.getInstance("AndroidKeyStore");
		keyStore.load(null);

		if (!keyStore.containsAlias(alias))
			generateKeyPair(context, alias);

		// Even if we just generated the key, always read it back to ensure we
		// can read it successfully. Fail here with a clear message instead of a
		// ClassCastException now or a NullPointerException in encrypt/decrypt.
		final KeyStore.Entry entry = keyStore.getEntry(alias, null);
		if (!(entry instanceof KeyStore.PrivateKeyEntry))
			throw new IllegalStateException("Key store entry \"" + alias +
					"\" is missing or is not a private key entry.");
		mKeyEntry = (KeyStore.PrivateKeyEntry) entry;
	}

	/**
	 * Generates an RSA key pair in the Android key store under {@code alias},
	 * with a self-signed certificate valid from now for 100 years.
	 */
	private void generateKeyPair(Context context, String alias) throws Exception
	{
		final Calendar startDate = new GregorianCalendar();
		final Calendar endDate = new GregorianCalendar();
		endDate.add(Calendar.YEAR, 100);

		final AlgorithmParameterSpec spec = getKeyPairGeneratorSpec(context, alias,
				startDate, endDate);
		// RSA, AKS_KEY_LEN bits where the platform lets us choose the size
		final KeyPairGenerator gen = KeyPairGenerator.getInstance("RSA", "AndroidKeyStore");
		gen.initialize(spec);
		gen.generateKeyPair();
	}

	/**
	 * Builds a {@code KeyPairGeneratorSpec} via reflection.
	 * Get rid of this method, if target SDK is >= 18.
	 *
	 * @throws ClassNotFoundException if the platform has no KeyPairGeneratorSpec.Builder
	 * @throws Exception on any other reflection failure
	 */
	private AlgorithmParameterSpec getKeyPairGeneratorSpec(Context context, String alias,
			Calendar startDate, Calendar endDate) throws Exception
	{
		// Reflection names nested classes as Outer$Inner.
		final Class<?> klass = Class.forName(GEN_SPEC_CLASS + "$Builder");
		final Object builder = klass.getDeclaredConstructor(Context.class).newInstance(context);

		klass.getMethod("setAlias", String.class).invoke(builder, alias);
		klass.getMethod("setSubject", X500Principal.class)
				.invoke(builder, new X500Principal("CN=" + alias));
		klass.getMethod("setSerialNumber", BigInteger.class).invoke(builder, BigInteger.ONE);
		klass.getMethod("setStartDate", Date.class).invoke(builder, startDate.getTime());
		klass.getMethod("setEndDate", Date.class).invoke(builder, endDate.getTime());
		try
		{
			klass.getMethod("setKeySize", int.class).invoke(builder, AKS_KEY_LEN);
		}
		catch (NoSuchMethodException ignored)
		{
			// Deliberately ignored: where setKeySize is not implemented the platform
			// default key size is used. NOTE(review): presumably 2048 bits; if not,
			// Container.getBytes() rejects the wrapped key instead of writing
			// data that could not be parsed back.
		}

		return AlgorithmParameterSpec.class.cast(klass.getMethod("build").invoke(builder));
	}

	/**
	 * Encrypts {@code plain} with a fresh AES session key that is wrapped with the
	 * key store's public key.
	 *
	 * @return Base64 encoded container, or "" if {@code plain} is null or empty
	 */
	@Override
	public String encrypt (String plain) throws Exception
	{
		if (TextUtils.isEmpty(plain))
			return "";

		// session key
		// Do *not* seed secureRandom! Automatically seeded from system entropy.
		final SecureRandom secureRandom = new SecureRandom();
		final KeyGenerator keyGenerator = KeyGenerator.getInstance(SESSION_KEY_ALGORITHM);
		keyGenerator.init(SESSION_KEY_LEN, secureRandom);
		final SecretKey sessionKey = keyGenerator.generateKey();

		final Container container = new Container(VERSION);

		// encrypt plain text with the session key and a random IV of one block
		Cipher cipher = Cipher.getInstance(CIPHER_ENCRYPTION);
		container.mParams = new byte[cipher.getBlockSize()];
		secureRandom.nextBytes(container.mParams);
		cipher.init(Cipher.ENCRYPT_MODE, sessionKey, new IvParameterSpec(container.mParams));
		container.mData = cipher.doFinal(plain.getBytes(PLAIN_CHARSET));

		// wrap session key with the key store's public key
		cipher = Cipher.getInstance(CIPHER_SESSIONKEY);
		cipher.init(Cipher.WRAP_MODE, mKeyEntry.getCertificate().getPublicKey());
		container.mKey = cipher.wrap(sessionKey);

		return Base64.encodeToString(container.getBytes(), Base64.DEFAULT);
	}

	/**
	 * Decrypts a string produced by {@link #encrypt(String)}.
	 *
	 * @return the plain text, or "" if {@code encrypted} is null or empty
	 * @throws IllegalArgumentException if the container is malformed or of an unknown version
	 */
	@Override
	public String decrypt (String encrypted) throws Exception
	{
		if (TextUtils.isEmpty(encrypted))
			return "";

		final Container container = new Container(Base64.decode(encrypted, Base64.DEFAULT));

		// unwrap session key with the key store's private key
		Cipher cipher = Cipher.getInstance(CIPHER_SESSIONKEY);
		cipher.init(Cipher.UNWRAP_MODE, mKeyEntry.getPrivateKey());
		final SecretKey sessionKey = (SecretKey) cipher.unwrap(container.mKey,
				SESSION_KEY_ALGORITHM, Cipher.SECRET_KEY);

		// decrypt plain text
		cipher = Cipher.getInstance(CIPHER_ENCRYPTION);
		cipher.init(Cipher.DECRYPT_MODE, sessionKey, new IvParameterSpec(container.mParams));
		return new String(cipher.doFinal(container.mData), PLAIN_CHARSET);
	}

	/**
	 * Serialized form of an encrypted value. Version 1 layout:
	 * <pre>
	 * version (1 byte) | wrapped session key (WRAPPED_KEY_LEN bytes)
	 *   | IV length (1 byte, unsigned) | IV | ciphertext (rest)
	 * </pre>
	 */
	private static class Container
	{
		public byte mVersion = 0;
		/** RSA-wrapped session key. */
		public byte[] mKey = null;
		/** Cipher parameters (the CBC IV). */
		public byte[] mParams = null;
		/** AES ciphertext. */
		public byte[] mData = null;

		/** Creates an empty container of the given version, to be filled and serialized. */
		public Container(byte version)
		{
			mVersion = version;
		}

		/**
		 * Parses a serialized container.
		 *
		 * @throws IllegalArgumentException if the data is too short, truncated or of an unknown version
		 */
		public Container(byte[] container)
		{
			try
			{
				mVersion = container[0];
				int next = 1;
				switch (mVersion)
				{
					case VERSION:
						// must hold at least the wrapped key plus the IV length byte
						if (container.length <= (1 + WRAPPED_KEY_LEN))
							throw new IllegalArgumentException("Content is too short");

						mKey = new byte[WRAPPED_KEY_LEN];
						System.arraycopy(container, next, mKey, 0, WRAPPED_KEY_LEN);
						next += WRAPPED_KEY_LEN;

						// unsigned: getBytes() writes lengths up to 255
						final int lenParams = container[next++] & 0xFF;
						mParams = new byte[lenParams];
						System.arraycopy(container, next, mParams, 0, lenParams);
						next += lenParams;

						mData = new byte[container.length - next];
						System.arraycopy(container, next, mData, 0, mData.length);
						break;

					default:
						throw new IllegalArgumentException("Invalid content version");
				}
			}
			catch (ArrayIndexOutOfBoundsException e)
			{
				throw new IllegalArgumentException("Content is corrupted", e);
			}
		}

		/**
		 * Serializes this container.
		 *
		 * @return the serialized bytes, or null if mVersion is not a supported version
		 * @throws IllegalBlockSizeException if the wrapped key is not WRAPPED_KEY_LEN bytes
		 *         (it could not be parsed back) or the cipher params exceed 255 bytes
		 */
		public byte[] getBytes() throws IllegalBlockSizeException
		{
			if (mVersion != VERSION)
				return null;

			// Validate first so an unparseable container is never produced.
			if (mKey.length != WRAPPED_KEY_LEN)
				throw new IllegalBlockSizeException("Wrapped session key must be " +
						WRAPPED_KEY_LEN + " bytes, got " + mKey.length);
			if (mParams.length > 255)
				throw new IllegalBlockSizeException("Cipher params must not be larger than 255");

			final byte[] result = new byte[1 + mKey.length + 1 + mParams.length + mData.length];
			int next = 0;
			result[next++] = mVersion;
			System.arraycopy(mKey, 0, result, next, mKey.length);
			next += mKey.length;
			result[next++] = (byte) mParams.length; // read back as unsigned
			System.arraycopy(mParams, 0, result, next, mParams.length);
			next += mParams.length;
			System.arraycopy(mData, 0, result, next, mData.length);
			return result;
		}
	}
}
