Thursday, February 28, 2008

Java Encryption

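I was working on a project that required some encryption, which gave me some time to play around with the javax.crypto API. Here is a small utility that walks a directory tree and encrypts matching files in place; text files are first scrambled (each line, and the order of the lines, is reversed) before encryption.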
package org.foo.bar; 
import java.io.BufferedReader;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.File;
import java.io.FileFilter;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.OutputStream;
import java.io.StringWriter;
import java.security.SecureRandom;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.Locale;
import java.util.Set;

import javax.crypto.Cipher;
import javax.crypto.CipherOutputStream;
import javax.crypto.KeyGenerator;
import javax.crypto.SecretKey;
import javax.crypto.SecretKeyFactory;
import javax.crypto.spec.PBEKeySpec;
import javax.crypto.spec.SecretKeySpec;


public class FcUtil {
  private static final String UTF8 = "utf-8";
  private static SecretKey getKey(String alg, String pass) {
    try {
      KeyGenerator keyGen = KeyGenerator.getInstance("Rijndael");
      keyGen.init(new SecureRandom(pass.getBytes(UTF8)));
      return keyGen.generateKey();
    } catch (Exception e) {
      throw new RuntimeException(e);
    }
  }
  private static Cipher getCipher(String alg, String pass) {
    try {
      Cipher cipher = Cipher.getInstance(alg);
      cipher.init(Cipher.ENCRYPT_MODE, getKey(alg, pass));
      return cipher;
    } catch (Exception e) {
      throw new RuntimeException(e);
    }
  }
  private static void encrypt(Cipher cipher, InputStream is, OutputStream os) {
    try {
      byte[] buffer = new byte[1024 * 1024];
      CipherOutputStream cos = new CipherOutputStream(os, cipher);
      int read = -1;
      int count = 10;
      while ( (read = is.read(buffer) ) > 0 ) {
        cos.write(buffer, 0, read);
        if ( --count < 0 ) {
          count = 10;
          cos.flush();
        }
      } // endwhile
      is.close();
      cos.close();
    } catch (Exception e) {
      throw new RuntimeException(e);
    }
  }
  private static String rev(String str) {
    char[] chars = str.toCharArray();
    int l = chars.length, n = l >> 1;
    char c;
    for (int i = 0; i < n; i++ ) {
      c = chars[(b = l-i)];
      chars[b] = chars[i];
      chars[i] = c;
    }
    return new String(chars);
  }
  private static byte[] processText(File file) {
    try {
      BufferedReader br = new BufferedReader(
          new InputStreamReader(new FileInputStream(file), UTF8));
      ArrayList<String> list = new ArrayList<String>();
      StringBuilder b = new StringBuilder();
      String buf = null;
      while ( (buf = br.readLine()) != null ) {
        b.delete(0, b.length());
        list.add(b.insert(0, buf).reverse().toString());
      }
      Collections.reverse(list);
      br.close();
      int count = 10;
      StringWriter sw = new StringWriter();
      for (String s : list) {
        sw.write(s);
        sw.write('\n');
        if ( --count < 0 ) {
          count = 10;
          sw.flush();
        }
      }
      sw.close();
      return sw.getBuffer().toString().getBytes(UTF8);
    } catch (Exception e) {
      throw new RuntimeException(e);
    }
  }
  private static ByteArrayInputStream processText(InputStream is) {
    try {
      BufferedReader br = new BufferedReader(
          new InputStreamReader(is, UTF8));
      ArrayList<String> list = new ArrayList<String>();
      StringBuilder b = new StringBuilder();
      String buf = null;
      while ( (buf = br.readLine()) != null ) {
        b.delete(0, b.length());
        list.add(b.insert(0, buf).reverse().toString());
      }
      Collections.reverse(list);
      br.close();
      int count = 10;
      StringWriter sw = new StringWriter();
      for (String s : list) {
        sw.write(s);
        sw.write('\n');
        if ( --count < 0 ) {
          count = 10;
          sw.flush();
        }
      }
      sw.close();
      return new ByteArrayInputStream(sw.getBuffer().toString().getBytes(UTF8));
    } catch (Exception e) {
      throw new RuntimeException(e);
    }
  }
  private static ByteArrayInputStream readFully(File file) {
    try {
      FileInputStream fis = new FileInputStream(file);
      ByteArrayOutputStream baos = new ByteArrayOutputStream();
      byte[] buffer = new byte[1024 * 1024];
      int read = -1;
      while ( (read = fis.read(buffer)) > 0 ) {
        baos.write(buffer, 0, read);
      }
      fis.close();
      baos.close();
      return new ByteArrayInputStream(baos.toByteArray());
    } catch (Exception e) {
      throw new RuntimeException(e);
    }
  }
  private static void transferFully(InputStream is, OutputStream os) {
    try {
      byte[] buffer = new byte[1024 * 1024];
      int read = -1;
      int count = 10;
      while ( (read = is.read(buffer)) > 0 ) {
        os.write(buffer, 0, read);
        if ( --count < 0 ) {
          count = 10;
          os.flush();
        }
      }
      is.close();
      os.close();
    } catch (Exception e) {
      throw new RuntimeException(e);
    }
  }
  private static void process(Cipher cipher, File file) {
    if ( file.isDirectory() ) {
      return;
    }
    HashSet<String> textExt = new HashSet<String>() {{
      add("txt");
      add("java");
      add("properties");
      add("xml");
      add("bat");
      add("sh");
      add("jsp");
      add("html");
      add("tld");
      add("js");
      add("css");
      add("dtd");
      add("xsd");
      add("sql");
      add("ddl");
    }};
    try {
      String ext = file.getName().substring(file.getName().lastIndexOf('.')+1).toLowerCase();
      if ( textExt.contains(ext) ) {

        // encrypt
        ByteArrayInputStream bis = new ByteArrayInputStream(processText(file));
        encrypt(cipher, bis, new FileOutputStream(file));

        // decrypt
//        ByteArrayOutputStream baos = new ByteArrayOutputStream();
//        encrypt(cipher, readFully(file), baos);
//        transferFully(processText(new ByteArrayInputStream(
//            baos.toByteArray())), new FileOutputStream(file));
      } else {
        // encrypt & decrypt
        ByteArrayInputStream bis = readFully(file);
        encrypt(cipher, bis, new FileOutputStream(file));
      }
    } catch (Exception e) {
      throw new RuntimeException(e);
    }    
  }
  private static void handleDir(Cipher cipher, File dir) {
    long start = System.currentTimeMillis();
    FileFilter fileFilter = new FileFilter() {
      HashSet<String> encExt = new HashSet<String>() {{
        add("txt");
        add("java");
        add("properties");
        add("xml");
        add("bat");
        add("sh");
        add("jsp");
        add("html");
        add("tld");
        add("js");
        add("css");
        add("dtd");
        add("xsd");
        add("sql");
        add("ddl");
        
        add("doc");
        add("xls");
        add("ppt");
        add("pps");
        add("pdf");
        add("pst");
      }};
      public boolean accept(File pathname) {
        String ext = pathname.getName().substring(pathname.getName().lastIndexOf('.')+1).toLowerCase();
        return pathname.isFile() && encExt.contains(ext);
      }
    };
    FileFilter dirFilter = new FileFilter() {
      HashSet<String> exclude = new HashSet<String>() {{
        add(".svn");
        add("cvs");
        add(".hg");
      }};
      public boolean accept(File pathname) {
        return pathname.isDirectory() && !exclude.contains(pathname.getName().toLowerCase()) && pathname.getName().indexOf('.') != 0;
      }
      
    };
    File[] files = dir.listFiles(fileFilter);
    for (File file : files) {
//      System.out.println(file.getAbsolutePath());
      process(cipher, file);
    }
    File[] dirs = dir.listFiles(dirFilter);
    handleDirs(cipher, dirs);
//    System.out.println(dir.getAbsolutePath() + "[" + (System.currentTimeMillis() - start) + "ms]");
  }
  private static void handleDirs(Cipher cipher, File[] dirs) {
    for (File dir : dirs) {
      handleDir(cipher, dir);
    }
  }
  public static void main(String[] args) {
    final String alg = "Rijndael";
    if ( args.length != 2 ) {
      System.out.println("error");
      return;
    }
    final String key = args[0];
    
    try {
      Cipher cipher = getCipher(alg, key);
//      cipher.init(Cipher.DECRYPT_MODE, getKey(alg, key));
      final File file = new File(args[1]);
      handleDir(cipher, file);
    } catch (Exception e) {
      e.printStackTrace();
    }
  }
}

No comments: