AES implementation in Python

After implementing DES, the next obvious challenge was AES. I expected AES code to be simpler to write than DES code, because AES was designed to be implemented in either hardware or software, while DES’s design was geared towards hardware. This time, however, I decided to write an object-oriented API supporting the three different key sizes AES inherited from Rijndael (128, 192 and 256 bits). In addition to the basic ECB (Electronic Code Book) mode of operation, this implementation also supports CBC (Cipher Block Chaining) mode.

#!/usr/bin/python3
#
# Author: Joao H de A Franco (jhafranco@acm.org)
#
# Description: AES implementation in Python 3
# (sundAES)
#
# Date: 2013-06-02 (version 1.1)
#       2012-01-16 (version 1.0)
#
# License: Attribution-NonCommercial-ShareAlike 3.0 Unported
# (CC BY-NC-SA 3.0)
#===========================================================
import sys
from itertools import repeat
from functools import reduce
from copy import copy

# Public API of this module. setKey/encrypt/decrypt are methods of the AES
# class, not module-level names, so listing them here made
# "from <module> import *" fail with AttributeError.
__all__ = ["AES"]

def memoize(func):
    """Memoization decorator for one-argument functions.

    Results are cached in an unbounded dict keyed by the argument, so the
    argument must be hashable. functools.wraps keeps the wrapped function's
    __name__ and __doc__ (without it, e.g. x2.__doc__ was lost).
    """
    from functools import wraps
    memo = {}

    @wraps(func)
    def helper(x):
        if x not in memo:
            memo[x] = func(x)
        return memo[x]
    return helper

def mult(p1,p2):
    """Multiply two elements of GF(2^8) modulo x^8+x^4+x^3+x+1.

    Shift-and-add: for every set bit of p2 the current multiple of p1 is
    XORed into the result, and p1 is doubled, reducing by 0x1b whenever the
    doubling carries into bit 8. Only the low 8 bits are returned.
    """
    result, multiplicand, multiplier = 0, p1, p2
    while multiplier:
        result ^= multiplicand if multiplier & 0x01 else 0
        doubled = multiplicand << 1
        multiplicand = doubled ^ 0x1b if doubled & 0x100 else doubled
        multiplier >>= 1
    return result & 0xff

# Auxiliary one-parameter functions defined for memoization
# (to speed up multiplication in GF(2^8))

@memoize
def x2(y):
    """Return y times 2 in GF(2^8); results are cached by memoize."""
    product = mult(2, y)
    return product

@memoize
def x3(y):
    """Return y times 3 in GF(2^8); results are cached by memoize."""
    product = mult(3, y)
    return product

@memoize
def x9(y):
    """Return y times 9 in GF(2^8); results are cached by memoize."""
    product = mult(9, y)
    return product

@memoize
def x11(y):
    """Return y times 11 in GF(2^8); results are cached by memoize."""
    product = mult(11, y)
    return product

@memoize
def x13(y):
    """Return y times 13 in GF(2^8); results are cached by memoize."""
    product = mult(13, y)
    return product

@memoize
def x14(y):
    """Return y times 14 in GF(2^8); results are cached by memoize."""
    product = mult(14, y)
    return product

class AES:
    """Class definition for AES objects.

    AES with 128-, 192- or 256-bit keys in ECB or CBC mode. Create an object
    with a cipher mode (and optional padding scheme), call setKey(), then
    encrypt()/decrypt(). Plaintext may be a str, bytes or int; encrypt()
    returns a list of byte values, or an int when its input was an int.

    Invalid arguments print a message and terminate via sys.exit().
    """
    keySizeTable = {"SIZE_128":16,
                    "SIZE_192":24,
                    "SIZE_256":32}
    wordSizeTable = {"SIZE_128":44,
                     "SIZE_192":52,
                     "SIZE_256":60}
    numberOfRoundsTable = {"SIZE_128":10,
                           "SIZE_192":12,
                           "SIZE_256":14}
    cipherModeTable = {"MODE_ECB":1,
                       "MODE_CBC":2}
    # "PKCS5Padding" is accepted as an alias: here both names pad to the
    # 16-byte AES block size.
    paddingTable = {"NoPadding":0,
                    "PKCS7Padding":1,
                    "PKCS5Padding":1}
    # S-Box
    sBox = (0x63,0x7c,0x77,0x7b,0xf2,0x6b,0x6f,0xc5,
            0x30,0x01,0x67,0x2b,0xfe,0xd7,0xab,0x76,
            0xca,0x82,0xc9,0x7d,0xfa,0x59,0x47,0xf0,
            0xad,0xd4,0xa2,0xaf,0x9c,0xa4,0x72,0xc0,
            0xb7,0xfd,0x93,0x26,0x36,0x3f,0xf7,0xcc,
            0x34,0xa5,0xe5,0xf1,0x71,0xd8,0x31,0x15,
            0x04,0xc7,0x23,0xc3,0x18,0x96,0x05,0x9a,
            0x07,0x12,0x80,0xe2,0xeb,0x27,0xb2,0x75,
            0x09,0x83,0x2c,0x1a,0x1b,0x6e,0x5a,0xa0,
            0x52,0x3b,0xd6,0xb3,0x29,0xe3,0x2f,0x84,
            0x53,0xd1,0x00,0xed,0x20,0xfc,0xb1,0x5b,
            0x6a,0xcb,0xbe,0x39,0x4a,0x4c,0x58,0xcf,
            0xd0,0xef,0xaa,0xfb,0x43,0x4d,0x33,0x85,
            0x45,0xf9,0x02,0x7f,0x50,0x3c,0x9f,0xa8,
            0x51,0xa3,0x40,0x8f,0x92,0x9d,0x38,0xf5,
            0xbc,0xb6,0xda,0x21,0x10,0xff,0xf3,0xd2,
            0xcd,0x0c,0x13,0xec,0x5f,0x97,0x44,0x17,
            0xc4,0xa7,0x7e,0x3d,0x64,0x5d,0x19,0x73,
            0x60,0x81,0x4f,0xdc,0x22,0x2a,0x90,0x88,
            0x46,0xee,0xb8,0x14,0xde,0x5e,0x0b,0xdb,
            0xe0,0x32,0x3a,0x0a,0x49,0x06,0x24,0x5c,
            0xc2,0xd3,0xac,0x62,0x91,0x95,0xe4,0x79,
            0xe7,0xc8,0x37,0x6d,0x8d,0xd5,0x4e,0xa9,
            0x6c,0x56,0xf4,0xea,0x65,0x7a,0xae,0x08,
            0xba,0x78,0x25,0x2e,0x1c,0xa6,0xb4,0xc6,
            0xe8,0xdd,0x74,0x1f,0x4b,0xbd,0x8b,0x8a,
            0x70,0x3e,0xb5,0x66,0x48,0x03,0xf6,0x0e,
            0x61,0x35,0x57,0xb9,0x86,0xc1,0x1d,0x9e,
            0xe1,0xf8,0x98,0x11,0x69,0xd9,0x8e,0x94,
            0x9b,0x1e,0x87,0xe9,0xce,0x55,0x28,0xdf,
            0x8c,0xa1,0x89,0x0d,0xbf,0xe6,0x42,0x68,
            0x41,0x99,0x2d,0x0f,0xb0,0x54,0xbb,0x16)
    # Inverse S-Box
    invSBox = (0x52,0x09,0x6a,0xd5,0x30,0x36,0xa5,0x38,
               0xbf,0x40,0xa3,0x9e,0x81,0xf3,0xd7,0xfb,
               0x7c,0xe3,0x39,0x82,0x9b,0x2f,0xff,0x87,
               0x34,0x8e,0x43,0x44,0xc4,0xde,0xe9,0xcb,
               0x54,0x7b,0x94,0x32,0xa6,0xc2,0x23,0x3d,
               0xee,0x4c,0x95,0x0b,0x42,0xfa,0xc3,0x4e,
               0x08,0x2e,0xa1,0x66,0x28,0xd9,0x24,0xb2,
               0x76,0x5b,0xa2,0x49,0x6d,0x8b,0xd1,0x25,
               0x72,0xf8,0xf6,0x64,0x86,0x68,0x98,0x16,
               0xd4,0xa4,0x5c,0xcc,0x5d,0x65,0xb6,0x92,
               0x6c,0x70,0x48,0x50,0xfd,0xed,0xb9,0xda,
               0x5e,0x15,0x46,0x57,0xa7,0x8d,0x9d,0x84,
               0x90,0xd8,0xab,0x00,0x8c,0xbc,0xd3,0x0a,
               0xf7,0xe4,0x58,0x05,0xb8,0xb3,0x45,0x06,
               0xd0,0x2c,0x1e,0x8f,0xca,0x3f,0x0f,0x02,
               0xc1,0xaf,0xbd,0x03,0x01,0x13,0x8a,0x6b,
               0x3a,0x91,0x11,0x41,0x4f,0x67,0xdc,0xea,
               0x97,0xf2,0xcf,0xce,0xf0,0xb4,0xe6,0x73,
               0x96,0xac,0x74,0x22,0xe7,0xad,0x35,0x85,
               0xe2,0xf9,0x37,0xe8,0x1c,0x75,0xdf,0x6e,
               0x47,0xf1,0x1a,0x71,0x1d,0x29,0xc5,0x89,
               0x6f,0xb7,0x62,0x0e,0xaa,0x18,0xbe,0x1b,
               0xfc,0x56,0x3e,0x4b,0xc6,0xd2,0x79,0x20,
               0x9a,0xdb,0xc0,0xfe,0x78,0xcd,0x5a,0xf4,
               0x1f,0xdd,0xa8,0x33,0x88,0x07,0xc7,0x31,
               0xb1,0x12,0x10,0x59,0x27,0x80,0xec,0x5f,
               0x60,0x51,0x7f,0xa9,0x19,0xb5,0x4a,0x0d,
               0x2d,0xe5,0x7a,0x9f,0x93,0xc9,0x9c,0xef,
               0xa0,0xe0,0x3b,0x4d,0xae,0x2a,0xf5,0xb0,
               0xc8,0xeb,0xbb,0x3c,0x83,0x53,0x99,0x61,
               0x17,0x2b,0x04,0x7e,0xba,0x77,0xd6,0x26,
               0xe1,0x69,0x14,0x63,0x55,0x21,0x0c,0x7d)

    # Instance variables (class-level defaults; real values are set per
    # object by __init__ and setKey)
    wordSize = None
    w = None            # Round subkeys list (allocated per object by setKey)
    keyDefined = None   # Key definition flag
    numberOfRounds = None
    cipherMode = None
    padding = None      # Padding scheme
    ivEncrypt = None    # Initialization
    ivDecrypt = None    # vectors

    def __init__(self,mode,padding = "NoPadding"):
        """Create a new instance of an AES object.

        mode: "MODE_ECB" or "MODE_CBC".
        padding: "NoPadding", "PKCS7Padding" or its alias "PKCS5Padding".
        Unsupported values print a message and exit via sys.exit().
        """
        # Explicit checks instead of assert: asserts are stripped under
        # "python -O", which would silently skip this validation.
        if mode not in AES.cipherModeTable:
            print("Cipher mode not supported:",mode)
            sys.exit("ValueError")
        self.cipherMode = mode
        if padding not in AES.paddingTable:
            print("Padding scheme not supported:",padding)
            sys.exit("ValueError")
        self.padding = padding
        self.keyDefined = False

    def intToList(self,number):
        """Convert the low 16 bytes of a number into a 16-element list (MSB first)"""
        return [(number>>i)&0xff for i in reversed(range(0,128,8))]

    def intToList2(self,number):
        """Convert an integer into a big-endian byte list.

        The list is left-padded with zero bytes to a non-zero multiple of
        16 elements (0 becomes sixteen zero bytes).
        """
        lst = []
        while number:
            lst.append(number&0xff)
            number >>= 8
        m = len(lst)%16
        if m == 0 and len(lst) != 0:
            return lst[::-1]
        else:
            return list(bytes(16-m)) + lst[::-1]

    def listToInt(self,lst):
        """Convert a big-endian byte list into a number (0 for an empty list)"""
        return reduce(lambda x,y:(x<<8)+y,lst,0)

    def wordToState(self,wordList):
        """Convert list of 4 words into a 16-element state list"""
        # State is row-major: element 4*r+c is byte r of word c.
        return [(wordList[i]>>j)&0xff
                for j in reversed(range(0,32,8)) for i in range(4)]

    def listToState(self,list):
        """Convert a 16-element list into a 16-element state list (transpose)"""
        return [list[i+j] for j in range(4) for i in range(0,16,4)]

    stateToList = listToState # this function is an involution

    def subBytes(self,state):
        """SubBytes transformation"""
        return [AES.sBox[e] for e in state]

    def invSubBytes(self,state):
        """Inverse SubBytes transformation"""
        return [AES.invSBox[e] for e in state]

    def shiftRows(self,s):
        """ShiftRows transformation (row r rotated left by r positions)"""
        return s[:4]+s[5:8]+s[4:5]+s[10:12]+s[8:10]+s[15:]+s[12:15]

    def invShiftRows(self,s):
        """Inverse ShiftRows transformation (row r rotated right by r positions)"""
        return s[:4]+s[7:8]+s[4:7]+s[10:12]+s[8:10]+s[13:]+s[12:13]

    def mixColumns(self,s):
        """MixColumns transformation"""
        return [x2(s[i])^x3(s[i+4])^ s[i+8] ^ s[i+12] for i in range(4)]+ \
               [ s[i] ^x2(s[i+4])^x3(s[i+8])^ s[i+12] for i in range(4)]+ \
               [ s[i] ^ s[i+4] ^x2(s[i+8])^x3(s[i+12]) for i in range(4)]+ \
               [x3(s[i])^ s[i+4] ^ s[i+8] ^x2(s[i+12]) for i in range(4)]

    def invMixColumns(self,s):
        """Inverse MixColumns transformation"""
        return [x14(s[i])^x11(s[i+4])^x13(s[i+8])^ x9(s[i+12]) for i in range(4)]+ \
               [ x9(s[i])^x14(s[i+4])^x11(s[i+8])^x13(s[i+12]) for i in range(4)]+ \
               [x13(s[i])^ x9(s[i+4])^x14(s[i+8])^x11(s[i+12]) for i in range(4)]+ \
               [x11(s[i])^x13(s[i+4])^ x9(s[i+8])^x14(s[i+12]) for i in range(4)]

    def addRoundKey (self,subkey,state):
        """AddRoundKey transformation (element-wise XOR)"""
        return [i^j for i,j in zip(subkey,state)]

    xorLists = addRoundKey

    def rotWord(self,number):
        """Rotate a 32-bit word left by one byte"""
        return (((number&0xff000000)>>24) +
                ((number&0xff0000)<<8) +
                ((number&0xff00)<<8) +
                ((number&0xff)<<8))

    def subWord(self,key):
        """Substitute each byte of a 32-bit word using the S-box"""
        return ((AES.sBox[(key>>24)&0xff]<<24) +
                (AES.sBox[(key>>16)&0xff]<<16) +
                (AES.sBox[(key>>8)&0xff]<<8) +
                AES.sBox[key&0xff])

    def setKey(self,keySize,key,iv = None):
        """KeyExpansion transformation.

        keySize: "SIZE_128", "SIZE_192" or "SIZE_256".
        key: the cipher key as a non-negative int (most significant byte
             first) that fits in the selected key size.
        iv: 128-bit int initialization vector, mandatory in CBC mode.
        Invalid arguments print a message and exit via sys.exit().
        """
        rcon = (0x00,0x01,0x02,0x04,0x08,0x10,0x20,0x40,0x80,0x1B,0x36)
        if keySize not in AES.keySizeTable:
            print("Key size identifier not valid")
            sys.exit("ValueError")
        if not isinstance(key,int) or key < 0:
            print("Invalid key")
            sys.exit("ValueError")
        # Count key bytes from the bit length. The former estimate from the
        # hex-string length let odd-length oversized keys pass (e.g. a
        # 129-bit key has 33 hex digits, and 33//2 == 16 bytes).
        if (key.bit_length()+7)//8 > AES.keySizeTable[keySize]:
            print("Key size mismatch")
            sys.exit("ValueError")
        if not ((self.cipherMode == "MODE_CBC" and isinstance(iv,int)) or
                self.cipherMode == "MODE_ECB"):
            print("IV is mandatory for CBC mode")
            sys.exit("ValueError")

        if self.cipherMode == "MODE_CBC":
            temp = self.intToList(iv)
            # Separate copies: encryption and decryption chain independently.
            self.ivEncrypt = copy(temp)
            self.ivDecrypt = copy(temp)
        nk = AES.keySizeTable[keySize]//4   # key length in 32-bit words
        self.numberOfRounds = AES.numberOfRoundsTable[keySize]
        self.wordSize = AES.wordSizeTable[keySize]
        # Fresh schedule for this object. The former class-level list was
        # shared by every AES instance, so setting one object's key silently
        # changed the round keys of all the others.
        w = [None]*self.wordSize
        for index in range(nk):
            # Word 0 holds the most significant four bytes of the key.
            w[index] = (key>>(32*(nk-1-index)))&0xffffffff
        for index in range(nk,self.wordSize):
            temp = w[index-1]
            if index % nk == 0:
                temp = self.subWord(self.rotWord(temp))^(rcon[index//nk]<<24)
            elif nk > 6 and index % nk == 4:
                # Extra SubWord step that only 256-bit keys use.
                temp = self.subWord(temp)
            w[index] = w[index-nk]^temp
        self.w = w
        self.keyDefined = True
        return

    def getKey(self,operation):
        """Yield round subkeys (as states) in encryption or decryption order"""
        if operation == "encryption":
            for i in range(0,self.wordSize,4):
                yield self.wordToState(self.w[i:i+4])
        else: # operation = "decryption":
            for i in reversed(range(0,self.wordSize,4)):
                yield self.wordToState(self.w[i:i+4])

    def encryptBlock(self,plaintextBlock):
        """Encrypt a 16-byte block with key already defined"""
        key = self.getKey("encryption")
        state = self.listToState(plaintextBlock)
        state = self.addRoundKey(next(key),state)
        for _ in repeat(None,self.numberOfRounds - 1):
            state = self.subBytes(state)
            state = self.shiftRows(state)
            state = self.mixColumns(state)
            state = self.addRoundKey(next(key),state)
        # Final round omits MixColumns.
        state = self.subBytes(state)
        state = self.shiftRows(state)
        state = self.addRoundKey(next(key),state)
        return self.stateToList(state)

    def decryptBlock(self,ciphertextBlock):
        """Decrypt a 16-byte block with key already defined"""
        key = self.getKey("decryption")
        state = self.listToState(ciphertextBlock)
        state = self.addRoundKey(next(key),state)
        for _ in repeat(None,self.numberOfRounds - 1):
            state = self.invShiftRows(state)
            state = self.invSubBytes(state)
            state = self.addRoundKey(next(key),state)
            state = self.invMixColumns(state)
        # Final round omits InvMixColumns.
        state = self.invShiftRows(state)
        state = self.invSubBytes(state)
        state = self.addRoundKey(next(key),state)
        return self.stateToList(state)

    def padData(self,data):
        """Add PKCS7 padding to plaintext (or just add bytes to fill a block).

        str input becomes a list of code points and bytes input stays bytes.
        With NoPadding, block-aligned data is left untouched and a partial
        final block is filled with bytes equal to the fill count.
        """
        paddingLength = 16-(len(data)%16)
        if self.padding == "NoPadding":
            paddingLength %= 16
        if type(data) is bytes:
            return data+bytes([paddingLength]*paddingLength)
        else:
            return [ord(s) for s in data]+[paddingLength]*paddingLength

    def unpadData(self,byteList):
        """Remove PKCS7 padding (if a padding scheme is set) and decode to str.

        Raises ValueError when the padding is malformed. Previously a final
        byte of 0 made byteList[:-0] return an empty string silently.
        """
        if self.padding != "NoPadding":
            padLength = byteList[-1] if len(byteList) else 0
            # Valid padding is 1..16 copies of its own length.
            if not (1 <= padLength <= 16 and padLength <= len(byteList) and
                    all(e == padLength for e in byteList[-padLength:])):
                raise ValueError("Invalid PKCS7 padding")
            byteList = byteList[:-padLength]
        return "".join(chr(e) for e in byteList)

    def encrypt(self,input):
        """Encrypt plaintext passed as a string, bytes or an integer.

        Returns a list of byte values, or an int when input is an int.
        In CBC mode the last ciphertext block becomes the IV of the next call.
        """
        if not self.keyDefined:
            print("Key not defined")
            sys.exit("ValueError")

        if type(input) is int:
            inList = self.intToList2(input)
        else:
            inList = self.padData(input)
        outList = []
        if self.cipherMode == "MODE_CBC":
            outBlock = self.ivEncrypt
            for i in range(0,len(inList),16):
                auxList = self.xorLists(outBlock,inList[i:i+16])
                outBlock = self.encryptBlock(auxList)
                outList += outBlock
            self.ivEncrypt = outBlock
        else:
            for i in range(0,len(inList),16):
                outList += self.encryptBlock(inList[i:i+16])
        if type(input) is int:
            return self.listToInt(outList)
        else:
            return outList

    def decrypt(self,input):
        """Decrypt ciphertext passed as a byte list or as an integer.

        Returns an int for int input, otherwise the unpadded plaintext str.
        In CBC mode the last ciphertext block becomes the IV of the next call.
        """
        if not self.keyDefined:
            print("Key not defined")
            sys.exit("ValueError")
        if type(input) is int:
            inList = self.intToList2(input)
        else:
            inList = input
        outList = []
        if self.cipherMode == "MODE_CBC":
            oldInBlock = self.ivDecrypt
            for i in range(0,len(inList),16):
                newInBlock = inList[i:i+16]
                auxList = self.decryptBlock(newInBlock)
                outList += self.xorLists(oldInBlock,auxList)
                oldInBlock = newInBlock
            self.ivDecrypt = oldInBlock
        else:
            for i in range(0,len(inList),16):
                outList += self.decryptBlock(inList[i:i+16])
        if type(input) is int:
            return self.listToInt(outList)
        else:
            return self.unpadData(outList)

The code below informally verifies the correctness of the implementation using the test vectors described in NIST document SP800-38A, Recommendation for Block Cipher Modes of Operation – Methods and Techniques:

import pyAES

def intToList(number):
    """Split the low 128 bits of a number into 16 bytes, most significant first."""
    return [(number >> shift) & 0xff for shift in range(120, -8, -8)]

def intToText(hexNumber):
    """Turn a 128-bit number into a 16-character string, one char per byte."""
    return "".join(map(chr, intToList(hexNumber)))

def checkTestVector1(mode, keySize, key, plaintext, ciphertext, iv = None):
    """Check test vectors for single block encryption and decryption.

    Each (plaintext, ciphertext) pair of 128-bit integers is encrypted and
    decrypted in turn with one AES object, so CBC chaining state carries from
    one pair to the next. Failures are printed. Returns True only when every
    pair matches.
    """
    success = True
    obj = pyAES.AES(mode)
    obj.setKey(keySize, key, iv)
    # Computed before the loop so the summary line also works for empty input.
    code = "{0:3s}-{1:3s}".format(mode[5:], keySize[5:])
    for i, (p, c) in enumerate(zip(plaintext, ciphertext)):
        p_text = intToText(p)
        c_text = intToList(c)
        # Plain comparisons instead of assert: "python -O" strips asserts,
        # which would make every vector report success unchecked.
        if obj.encrypt(p_text) != c_text:
            print(code, "encryption #{:d} failed".format(i))
            success = False
        if obj.decrypt(c_text) != p_text:
            print(code, "decryption #{:d} failed".format(i))
            success = False
    if success:
        print(code, "encryption/decryption ok")
    return success

def checkTestVector2(mode, keySize, plaintext, key1, key2, iv1 = None, iv2 = None):
    """Round-trip a multi-block plaintext through two independently keyed objects.

    Both objects use "PKCS7Padding", the padding name defined in the AES
    class's paddingTable ("PKCS5Padding" made the class exit). Returns True
    when each object decrypts its own ciphertext back to plaintext and the
    two ciphertexts differ. The outcome is printed either way.
    """
    obj1 = pyAES.AES(mode, "PKCS7Padding")
    obj1.setKey(keySize, key1, iv1)
    obj2 = pyAES.AES(mode, "PKCS7Padding")
    obj2.setKey(keySize, key2, iv2)
    ciphertext1 = obj1.encrypt(plaintext)
    ciphertext2 = obj2.encrypt(plaintext)
    plaintext1 = obj1.decrypt(ciphertext1)
    plaintext2 = obj2.decrypt(ciphertext2)
    code = "{0:3s}-{1:3s}".format(mode[5:], keySize[5:])
    # Different keys must give different ciphertexts. This detects objects
    # that accidentally share one key schedule, which a round-trip check
    # alone cannot. Plain comparisons are used because asserts vanish
    # under "python -O".
    if (plaintext1 != plaintext or plaintext2 != plaintext or
            ciphertext1 == ciphertext2):
        print("Multi-block", code, "encryption/decryption failed")
        return False
    print("Multi-block", code, "encryption/decryption ok")
    return True

#===========================================================================

# Test vectors from NIST SP800-38A sections F.1.1 and F.1.2
def ECB_128_NOPAD():
    """AES-128 ECB without padding, checked against SP800-38A F.1.1/F.1.2."""
    blocksIn = [
        0x6bc1bee22e409f96e93d7e117393172a,
        0xae2d8a571e03ac9c9eb76fac45af8e51,
        0x30c81c46a35ce411e5fbc1191a0a52ef,
        0xf69f2445df4f9b17ad2b417be66c3710,
    ]
    blocksOut = [
        0x3ad77bb40d7a3660a89ecaf32466ef97,
        0xf5d3d58503b9699de785895a96fdbaaf,
        0x43b1cd7f598ece23881b00e3ed030688,
        0x7b0c785e27e8ad3f8223207104725dd4,
    ]
    return checkTestVector1("MODE_ECB", "SIZE_128",
                            0x2b7e151628aed2a6abf7158809cf4f3c,
                            plaintext=blocksIn, ciphertext=blocksOut)

# Test vectors from NIST SP800-38A sections F.1.3 and F.1.4
def ECB_192_NOPAD():
    """AES-192 ECB without padding, checked against SP800-38A F.1.3/F.1.4."""
    blocksIn = [
        0x6bc1bee22e409f96e93d7e117393172a,
        0xae2d8a571e03ac9c9eb76fac45af8e51,
        0x30c81c46a35ce411e5fbc1191a0a52ef,
        0xf69f2445df4f9b17ad2b417be66c3710,
    ]
    blocksOut = [
        0xbd334f1d6e45f25ff712a214571fa5cc,
        0x974104846d0ad3ad7734ecb3ecee4eef,
        0xef7afd2270e2e60adce0ba2face6444e,
        0x9a4b41ba738d6c72fb16691603c18e0e,
    ]
    return checkTestVector1("MODE_ECB", "SIZE_192",
                            0x8e73b0f7da0e6452c810f32b809079e562f8ead2522c6b7b,
                            plaintext=blocksIn, ciphertext=blocksOut)

# Test vectors from NIST SP800-38A sections F.1.5 and F.1.6
def ECB_256_NOPAD():
    """AES-256 ECB without padding, checked against SP800-38A F.1.5/F.1.6."""
    blocksIn = [
        0x6bc1bee22e409f96e93d7e117393172a,
        0xae2d8a571e03ac9c9eb76fac45af8e51,
        0x30c81c46a35ce411e5fbc1191a0a52ef,
        0xf69f2445df4f9b17ad2b417be66c3710,
    ]
    blocksOut = [
        0xf3eed1bdb5d2a03c064b5a7e3db181f8,
        0x591ccb10d410ed26dc5ba74a31362870,
        0xb6ed21b99ca6f4f9f153e7b1beafed1d,
        0x23304b7a39f9f3ff067d8d8f9e24ecc7,
    ]
    return checkTestVector1("MODE_ECB", "SIZE_256",
                            0x603deb1015ca71be2b73aef0857d77811f352c073b6108d72d9810a30914dff4,
                            plaintext=blocksIn, ciphertext=blocksOut)

# Test vectors from NIST SP800-38A sections F.2.1 and F.2.2
def CBC_128_NOPAD():
    """AES-128 CBC without padding, checked against SP800-38A F.2.1/F.2.2."""
    blocksIn = [
        0x6bc1bee22e409f96e93d7e117393172a,
        0xae2d8a571e03ac9c9eb76fac45af8e51,
        0x30c81c46a35ce411e5fbc1191a0a52ef,
        0xf69f2445df4f9b17ad2b417be66c3710,
    ]
    blocksOut = [
        0x7649abac8119b246cee98e9b12e9197d,
        0x5086cb9b507219ee95db113a917678b2,
        0x73bed6b8e3c1743b7116e69e22229516,
        0x3ff1caa1681fac09120eca307586e1a7,
    ]
    return checkTestVector1("MODE_CBC", "SIZE_128",
                            0x2b7e151628aed2a6abf7158809cf4f3c,
                            plaintext=blocksIn, ciphertext=blocksOut,
                            iv=0x000102030405060708090a0b0c0d0e0f)

# Test vectors from NIST SP800-38A sections F.2.3 and F.2.4
def CBC_192_NOPAD():
    """AES-192 CBC without padding, checked against SP800-38A F.2.3/F.2.4."""
    blocksIn = [
        0x6bc1bee22e409f96e93d7e117393172a,
        0xae2d8a571e03ac9c9eb76fac45af8e51,
        0x30c81c46a35ce411e5fbc1191a0a52ef,
        0xf69f2445df4f9b17ad2b417be66c3710,
    ]
    blocksOut = [
        0x4f021db243bc633d7178183a9fa071e8,
        0xb4d9ada9ad7dedf4e5e738763f69145a,
        0x571b242012fb7ae07fa9baac3df102e0,
        0x08b0e27988598881d920a9e64f5615cd,
    ]
    return checkTestVector1("MODE_CBC", "SIZE_192",
                            0x8e73b0f7da0e6452c810f32b809079e562f8ead2522c6b7b,
                            plaintext=blocksIn, ciphertext=blocksOut,
                            iv=0x000102030405060708090a0b0c0d0e0f)

# Test vectors from NIST SP800-38A sections F.2.5 and F.2.6
def CBC_256_NOPAD():
    """AES-256 CBC without padding, checked against SP800-38A F.2.5/F.2.6."""
    blocksIn = [
        0x6bc1bee22e409f96e93d7e117393172a,
        0xae2d8a571e03ac9c9eb76fac45af8e51,
        0x30c81c46a35ce411e5fbc1191a0a52ef,
        0xf69f2445df4f9b17ad2b417be66c3710,
    ]
    blocksOut = [
        0xf58c4c04d6e5f1ba779eabfb5f7bfbd6,
        0x9cfc4e967edb808d679f777bc6702c7d,
        0x39f23369a9d9bacfa530e26304231461,
        0xb2eb05e2c39be9fcda6c19078c6a9d1b,
    ]
    return checkTestVector1("MODE_CBC", "SIZE_256",
                            0x603deb1015ca71be2b73aef0857d77811f352c073b6108d72d9810a30914dff4,
                            plaintext=blocksIn, ciphertext=blocksOut,
                            iv=0x000102030405060708090a0b0c0d0e0f)

#===========================================================================

# Two multi-block encryptions in ECB mode with 128-bit key
def ECB_128_PKCS5():
    """Round-trip a multi-block message under two 128-bit keys in ECB mode."""
    message = "The quick brown fox jumps over the lazy dog"
    return checkTestVector2("MODE_ECB", "SIZE_128", message,
                            key1=0x000102030405060708090a0b0c0d0e0f,
                            key2=0xffeeddccbbaa99887766554433221100)

# Two multi-block encryptions in ECB mode with 192-bit key
def ECB_192_PKCS5():
    """Round-trip a multi-block message under two 192-bit keys in ECB mode."""
    message = "The quick brown fox jumps over the lazy dog"
    return checkTestVector2("MODE_ECB", "SIZE_192", message,
                            key1=0x000102030405060708090a0b0c0d0e0f1011121314151617,
                            key2=0x1716151413121110ffeeddccbbaa99887766554433221100)

# Two multi-block encryptions in ECB mode with 256-bit key
def ECB_256_PKCS5():
    """Round-trip a multi-block message under two 256-bit keys in ECB mode."""
    message = "The quick brown fox jumps over the lazy dog"
    return checkTestVector2("MODE_ECB", "SIZE_256", message,
                            key1=0x000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f,
                            key2=0x1f1e1d1c1b1a19181716151413121110ffeeddccbbaa99887766554433221100)

# Two multi-block encryptions in CBC mode with 128-bit key
def CBC_128_PKCS5():
    """Round-trip a multi-block message under two 128-bit keys in CBC mode."""
    message = "The quick brown fox jumps over the lazy dog"
    return checkTestVector2("MODE_CBC", "SIZE_128", message,
                            key1=0x000102030405060708090a0b0c0d0e0f,
                            key2=0xffeeddccbbaa99887766554433221100,
                            iv1=0x0123cdef456789ab0123cdef456789ab,
                            iv2=0xab0123cdef456789ab0123cdef456789)

# Two multi-block encryptions in CBC mode with 192-bit key
def CBC_192_PKCS5():
    """Round-trip a multi-block message under two 192-bit keys in CBC mode."""
    message = "The quick brown fox jumps over the lazy dog"
    return checkTestVector2("MODE_CBC", "SIZE_192", message,
                            key1=0x000102030405060708090a0b0c0d0e0f1011121314151617,
                            key2=0x1716151413121110ffeeddccbbaa99887766554433221100,
                            iv1=0x0123cdef456789ab0123cdef456789ab,
                            iv2=0xab0123cdef456789ab0123cdef456789)

# Two multi-block encryptions in CBC mode with 256-bit key
def CBC_256_PKCS5():
    """Round-trip a multi-block message under two 256-bit keys in CBC mode."""
    message = "The quick brown fox jumps over the lazy dog"
    return checkTestVector2("MODE_CBC", "SIZE_256", message,
                            key1=0x000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f,
                            key2=0x1f1e1d1c1b1a19181716151413121110ffeeddccbbaa99887766554433221100,
                            iv1=0x0123cdef456789ab0123cdef456789ab,
                            iv2=0xab0123cdef456789ab0123cdef456789)

def main():
    """Run every <MODE>_<SIZE>_<PAD> test function and report overall success."""
    testSuccess = True
    for pad in ["NOPAD", "PKCS5"]:
        for mode in ["ECB", "CBC"]:
            for size in ["128", "192", "256"]:
                # Look the test function up by name instead of eval()-ing a
                # constructed source string.
                testVector = globals()[mode + "_" + size + "_" + pad]
                # Non-short-circuiting '&' so later tests still run after a failure.
                testSuccess = testSuccess & testVector()
    if testSuccess:
        print("All tests passed!")

# Run the AES test-vector suite only when this file is executed as a script.
if __name__ == '__main__':
    main()

DES implementation in Python

After coding RC4, I thought it was time to face a more difficult challenge: DES (Data Encryption Standard). Although it was replaced by AES (Advanced Encryption Standard) more than a decade ago, it still has some appeal to me, because it was the first symmetric algorithm I learned (and also because its creation, 40 years ago, was surrounded by mystery). Before fighting the monster, I faced a lighter opponent, S-DES (Simplified DES), to get used to the awkward bit manipulation DES relies on. Since S-DES is just a toy cryptographic algorithm, it isn’t worth spending much time and space writing about it. So, without further ado, this is the DES Python code I wrote:

#!/usr/bin/python3
#
# Author: Joao H de A Franco (jhafranco@acm.org)
#
# Description: DES implementation in Python 3
#
# Note: only single DES in ECB mode (with PKCS5 padding)
#       is supported
#
# License: Attribution-NonCommercial-ShareAlike 3.0 Unported
#          (CC BY-NC-SA 3.0)
#===========================================================

# Round subkeys, one list per round. setKey() replaces each row with the
# 6-byte (48-bit) subkey for that round. Each placeholder row is a distinct
# list: 16*[[None]*8] made all 16 rows alias a single list.
subKeyList = [[None]*8 for _ in range(16)]

# Permutation tables. Entry k gives the 1-based input bit (MSB first) that
# lands at output position k (see permBitList/permByteList).

# Initial permutation IP, applied to the 64-bit input block.
IPtable = (58, 50, 42, 34, 26, 18, 10, 2,
           60, 52, 44, 36, 28, 20, 12, 4,
           62, 54, 46, 38, 30, 22, 14, 6,
           64, 56, 48, 40, 32, 24, 16, 8,
           57, 49, 41, 33, 25, 17,  9, 1,
           59, 51, 43, 35, 27, 19, 11, 3,
           61, 53, 45, 37, 29, 21, 13, 5,
           63, 55, 47, 39, 31, 23, 15, 7)

# Expansion E: maps the 32-bit right half to 48 bits (some bits appear
# twice) so it can be XORed with a 48-bit round subkey.
EPtable = (32,  1,  2,  3,  4,  5,
            4,  5,  6,  7,  8,  9,
            8,  9, 10, 11, 12, 13,
           12, 13, 14, 15, 16, 17,
           16, 17, 18, 19, 20, 21,
           20, 21, 22, 23, 24, 25,
           24, 25, 26, 27, 28, 29,
           28, 29, 30, 31, 32,  1)

# Permutation P, applied to the 32-bit S-box output in each round.
PFtable = (16,  7, 20, 21, 29, 12, 28, 17,
            1, 15, 23, 26,  5, 18, 31, 10,
            2,  8, 24, 14, 32, 27,  3,  9,
           19, 13, 30,  6, 22, 11,  4, 25)

# Final permutation, the inverse of IP, applied after the 16 rounds.
FPtable = (40, 8, 48, 16, 56, 24, 64, 32,
           39, 7, 47, 15, 55, 23, 63, 31,
           38, 6, 46, 14, 54, 22, 62, 30,
           37, 5, 45, 13, 53, 21, 61, 29,
           36, 4, 44, 12, 52, 20, 60, 28,
           35, 3, 43, 11, 51, 19, 59, 27,
           34, 2, 42, 10, 50, 18, 58, 26,
           33, 1, 41,  9, 49, 17, 57, 25)

# The eight DES S-boxes. Each maps a 6-bit group to a 4-bit value. The
# entries are stored row by row (16 per row), and getIndex() turns a 6-bit
# group into the flat 0..63 index.
# The placeholder below initially holds eight references to one list, but
# every slot is rebound to its own tuple right after, so the aliasing is harmless.
sBox = 8*[64*[0]]

sBox[0] = (14,  4, 13,  1,  2, 15, 11,  8,  3, 10,  6, 12,  5,  9,  0,  7,
            0, 15,  7,  4, 14,  2, 13,  1, 10,  6, 12, 11,  9,  5,  3,  8,
            4,  1, 14,  8, 13,  6,  2, 11, 15, 12,  9,  7,  3, 10,  5,  0,
           15, 12,  8,  2,  4,  9,  1,  7,  5, 11,  3, 14, 10,  0,  6, 13)

sBox[1] = (15,  1,  8, 14,  6, 11,  3,  4,  9,  7,  2, 13, 12,  0,  5, 10,
            3, 13,  4,  7, 15,  2,  8, 14, 12,  0,  1, 10,  6,  9, 11,  5,
            0, 14,  7, 11, 10,  4, 13,  1,  5,  8, 12,  6,  9,  3,  2, 15,
           13,  8, 10,  1,  3, 15,  4,  2, 11,  6,  7, 12,  0,  5, 14,  9)

sBox[2] = (10,  0,  9, 14,  6,  3, 15,  5,  1, 13, 12,  7, 11,  4,  2,  8,
           13,  7,  0,  9,  3,  4,  6, 10,  2,  8,  5, 14, 12, 11, 15,  1,
           13,  6,  4,  9,  8, 15,  3,  0, 11,  1,  2, 12,  5, 10, 14,  7,
            1, 10, 13,  0,  6,  9,  8,  7,  4, 15, 14,  3, 11,  5,  2, 12)

sBox[3] = ( 7, 13, 14,  3,  0,  6,  9, 10,  1,  2,  8,  5, 11, 12,  4, 15,
           13,  8, 11,  5,  6, 15,  0,  3,  4,  7,  2, 12,  1, 10, 14,  9,
           10,  6,  9,  0, 12, 11,  7, 13, 15,  1,  3, 14,  5,  2,  8,  4,
            3, 15,  0,  6, 10,  1, 13,  8,  9,  4,  5, 11, 12,  7,  2, 14)

sBox[4] = ( 2, 12,  4,  1,  7, 10, 11,  6,  8,  5,  3, 15, 13,  0, 14,  9,
           14, 11,  2, 12,  4,  7, 13,  1,  5,  0, 15, 10,  3,  9,  8,  6,
            4,  2,  1, 11, 10, 13,  7,  8, 15,  9, 12,  5,  6,  3,  0, 14,
           11,  8, 12,  7,  1, 14,  2, 13,  6, 15,  0,  9, 10,  4,  5,  3)

sBox[5] = (12,  1, 10, 15,  9,  2,  6,  8,  0, 13,  3,  4, 14,  7,  5, 11,
           10, 15,  4,  2,  7, 12,  9,  5,  6,  1, 13, 14,  0, 11,  3,  8,
            9, 14, 15,  5,  2,  8, 12,  3,  7,  0,  4, 10,  1, 13, 11,  6,
            4,  3,  2, 12,  9,  5, 15, 10, 11, 14,  1,  7,  6,  0,  8, 13)

sBox[6] = ( 4, 11,  2, 14, 15,  0,  8, 13,  3, 12,  9,  7,  5, 10,  6,  1,
           13,  0, 11,  7,  4,  9,  1, 10, 14,  3,  5, 12,  2, 15,  8,  6,
            1,  4, 11, 13, 12,  3,  7, 14, 10, 15,  6,  8,  0,  5,  9,  2,
            6, 11, 13,  8,  1,  4, 10,  7,  9,  5,  0, 15, 14,  2,  3, 12)

sBox[7] = (13,  2,  8,  4,  6, 15, 11,  1, 10,  9,  3, 14,  5,  0, 12,  7,
            1, 15, 13,  8, 10,  3,  7,  4, 12,  5,  6, 11,  0, 14,  9,  2,
            7, 11,  4,  1,  9, 12, 14,  2,  0,  6, 10, 13, 15,  3,  5,  8,
            2,  1, 14,  7,  4, 10,  8, 13, 15, 12,  9,  0,  3,  5,  6, 11)

def bit2Byte(bitList):
    """Pack a flat list of bits (most significant first) into a list of bytes.

    Only complete groups of eight bits are converted; leftover bits are ignored.
    """
    byteCount = len(bitList) // 8
    chunks = (bitList[8 * n:8 * n + 8] for n in range(byteCount))
    return [int("".join(str(bit) for bit in chunk), 2) for chunk in chunks]

def byte2Bit(byteList):
    """Unpack bytes into a flat list of bits, most significant bit first."""
    return [(byte >> shift) & 0x01
            for byte in byteList
            for shift in range(7, -1, -1)]

def permBitList(inputBitList,permTable):
    """Reorder bits: output position k receives input bit permTable[k] (1-based)."""
    return list(map(lambda position: inputBitList[position - 1], permTable))

def permByteList(inByteList,permTable):
    """Permute the bits of a byte list according to a permutation table.

    Output bit k (MSB first, across the whole list) is the input bit at
    1-based position permTable[k]. The output has len(permTable) // 8 bytes.
    """
    result = [0] * (len(permTable) >> 3)
    for dst, src in enumerate(permTable):
        src -= 1
        srcMask = 0x80 >> (src & 7)
        # Always touch result[dst >> 3], even when ORing 0.
        result[dst >> 3] |= (0x80 >> (dst & 7)) if inByteList[src >> 3] & srcMask else 0
    return result

def getIndex(inBitList):
    """Map a 6-bit group to its flat S-box index.

    The outer bits (first and last) select the row and the four inner bits
    select the column; the result is row * 16 + column.
    """
    row = (inBitList[0] << 1) + inBitList[5]
    column = (inBitList[1] << 3) + (inBitList[2] << 2) + \
             (inBitList[3] << 1) + inBitList[4]
    return (row << 4) + column

def padData(string):
    """Encode text as byte values and append PKCS5 padding to a multiple of 8."""
    fill = 8 - len(string) % 8
    data = list(map(ord, string))
    data.extend([fill] * fill)
    return data

def unpadData(byteList):
    """Drop the PKCS5 padding (its length is the last byte) and decode to text."""
    padLength = byteList[-1]
    kept = byteList[:-padLength]
    return "".join(map(chr, kept))

def setKey(keyByteList):
    """Generate all sixteen round subkeys from an 8-byte key.

    PC-1 selects 56 key bits. For each round, both 28-bit halves are
    rotated left by 1 or 2 positions, and PC-2 selects 48 bits. The 6-byte
    result is stored in the module-level subKeyList.
    """
    PC1table = (57, 49, 41, 33, 25, 17,  9,
                 1, 58, 50, 42, 34, 26, 18,
                10,  2, 59, 51, 43, 35, 27,
                19, 11,  3, 60, 52, 44, 36,
                63, 55, 47, 39, 31, 23, 15,
                 7, 62, 54, 46, 38, 30, 22,
                14,  6, 61, 53, 45, 37, 29,
                21, 13,  5, 28, 20, 12,  4)

    PC2table = (14, 17, 11, 24,  1,  5,  3, 28,
                15,  6, 21, 10, 23, 19, 12,  4,
                26,  8, 16,  7, 27, 20, 13,  2,
                41, 52, 31, 37, 47, 55, 30, 40,
                51, 45, 33, 48, 44, 49, 39, 56,
                34, 53, 46, 42, 50, 36, 29, 32)

    # Left-rotation amount applied to each 28-bit half, per round.
    rotations = (1, 1, 2, 2, 2, 2, 2, 2, 1, 2, 2, 2, 2, 2, 2, 1)

    def rotateHalves(bits, count):
        """Rotate bits 0-27 and bits 28-55 left by count, independently."""
        upper, lower = bits[:28], bits[28:]
        return upper[count:] + upper[:count] + lower[count:] + lower[:count]

    keyState = permBitList(byte2Bit(keyByteList), PC1table)
    for roundNumber, count in enumerate(rotations):
        keyState = rotateHalves(keyState, count)
        subKeyList[roundNumber] = bit2Byte(permBitList(keyState, PC2table))

def _feistelRounds(inputBlock, roundKeys):
    """Run IP, the 16-round Feistel network and FP over one 8-byte block.

    roundKeys yields the 6-byte round subkeys in the order they are applied;
    encryption and decryption differ only in that order.
    """
    inputData = permByteList(inputBlock,IPtable)
    leftPart,rightPart = inputData[:4],inputData[4:]
    for key in roundKeys:
        expRightPart = permByteList(rightPart,EPtable)
        # 48 bits: eight 6-bit groups, one per S-box.
        indexList = byte2Bit([i^j for i,j in zip(key,expRightPart)])
        sBoxOutput = 4*[0]
        for nBox in range(4):
            # Each output byte packs the 4-bit outputs of two consecutive S-boxes.
            nBox12 = 12*nBox
            leftIndex = getIndex(indexList[nBox12:nBox12+6])
            rightIndex = getIndex(indexList[nBox12+6:nBox12+12])
            sBoxOutput[nBox] = (sBox[2*nBox][leftIndex]<<4)+ \
                                sBox[2*nBox+1][rightIndex]
        aux = permByteList(sBoxOutput,PFtable)
        leftPart,rightPart = rightPart,[i^j for i,j in zip(aux,leftPart)]
    # The halves are swapped back before the final permutation.
    return permByteList(rightPart+leftPart,FPtable)


def encryptBlock(inputBlock):
    """Encrypt an 8-byte block with already defined key"""
    return _feistelRounds(inputBlock,(subKeyList[r] for r in range(16)))


def decryptBlock(inputBlock):
    """Decrypt an 8-byte block with already defined key"""
    # Same network with the subkeys applied in reverse order.
    return _feistelRounds(inputBlock,(subKeyList[15-r] for r in range(16)))

def encrypt(key, inString):
    """Encrypt plaintext with given key

    Recomputes the key schedule via setKey(key), pads the input with
    padData and encrypts each 8-byte slice independently with
    encryptBlock (no chaining, i.e. ECB). Returns the concatenation of
    the per-block outputs as one flat list.
    """
    setKey(key)
    padded = padData(inString)
    blockStarts = range(0, len(padded), 8)
    return [value for start in blockStarts
                  for value in encryptBlock(padded[start:start+8])]

def decrypt(key, inByteList):
    """Decrypt ciphertext with given key

    Recomputes the key schedule via setKey(key), decrypts each 8-byte
    slice of inByteList independently with decryptBlock, and passes
    the concatenated result through unpadData before returning it.
    """
    setKey(key)
    blockStarts = range(0, len(inByteList), 8)
    decrypted = [value for start in blockStarts
                       for value in decryptBlock(inByteList[start:start+8])]
    return unpadData(decrypted)

To verify this DES implementation, I also wrote a separate Python module (shown below) that implements an elegant test algorithm proposed by Ron Rivest back in 1985 in the note “Testing Implementations of DES”. I also added a simple multi-block test.

import sys
from pyDES import setKey, encryptBlock, decryptBlock, encrypt, decrypt

def sanityCheck1():
    """Tests single-block DES encryption & decryption
       using algorithm proposed by Ronald Rivest
       (http://people.csail.mit.edu/rivest/Destest.txt)

    Starting from x0, each of 16 steps uses the current value as both
    key and data: x[i+1] = E(x[i], x[i]) for even i and
    x[i+1] = D(x[i], x[i]) for odd i. The final value must equal x16.

    Returns True if it does, False otherwise.
    """
    x0 =  [0x94, 0x74, 0xb8, 0xe8, 0xc7, 0x3b, 0xca, 0x7d]
    x16 = [0x1b, 0x1a, 0x2d, 0xdb, 0x4c, 0x64, 0x24, 0x38]
    x = x0
    for i in range(16):
        setKey(x)
        # even i: x[i+1] = E(x[i], x[i]); odd i: x[i+1] = D(x[i], x[i])
        x = encryptBlock(x) if i % 2 == 0 else decryptBlock(x)
    # Compare directly: an assert would be stripped under "python -O",
    # making this check report success unconditionally
    return x == x16

def sanityCheck2():
    """Tests multi-block DES encryption and decryption

    Encrypts a sentence longer than one 8-byte block, decrypts the
    result with the same key and returns True if the round trip
    reproduces the original plaintext, False otherwise.
    """
    key = [0x0f, 0x15, 0x71, 0xc9, 0x47, 0xd9, 0xe8, 0x59]
    plaintext = "The quick brown fox jumps over the lazy dog"
    ciphertext = encrypt(key, plaintext)
    # Compare directly: an assert would be stripped under "python -O",
    # making this check report success unconditionally
    return decrypt(key, ciphertext) == plaintext

def main():
    """Run the DES sanity checks and exit with a status code

    Prints "All DES tests ok!" and exits with status 0 when every check
    passes; otherwise prints the name of each failing check and exits
    with status 1.
    """
    # Evaluate every check up front (no short-circuit) so that all
    # failures are reported, not just the first one
    checks = [("single-block (Rivest)", sanityCheck1()),
              ("multi-block", sanityCheck2())]
    failed = [name for name, passed in checks if not passed]
    for name in failed:
        print("DES {} test failed".format(name))
    if failed:
        sys.exit(1)
    print("All DES tests ok!")
    sys.exit()

if __name__ == '__main__':
    main()

RC4 implementation in Python

I started learning Python two months ago. To get the most out of the process, I decided to combine it with another interest of mine, cryptography, by implementing a very simple symmetric stream cipher, RC4. Here is the code:

#!/usr/bin/python3
#
# Author: Joao H de A Franco (jhafranco@acm.org)
#
# Description: RC4 implementation in Python 3
#
# License: Attribution-NonCommercial-ShareAlike 3.0 Unported
#          (CC BY-NC-SA 3.0)
#===========================================================

# Global variables: the RC4 permutation (state) and the two generator
# indices (p, q); setKey() rebinds all three before any use
state = 256 * [None]
p, q = None, None

def setKey(key):
    """RC4 Key Scheduling Algorithm (KSA)

    Rebuilds the global permutation `state` from `key` (a sequence of
    integers; the test module passes byte values) and resets the
    generator indices p and q to 0. An empty key is accepted: the key
    term is then simply omitted from the index update.
    """
    global p, q, state
    keyLength = len(key)
    state = list(range(256))
    p = q = 0
    j = 0
    for i in range(256):
        keyByte = key[i % keyLength] if keyLength else 0
        j = (j + state[i] + keyByte) % 256
        state[i], state[j] = state[j], state[i]

def byteGenerator():
    """RC4 Pseudo-Random Generation Algorithm (PRGA)

    Advances the global indices p and q, swaps the two addressed
    entries of the global `state` list in place and returns the next
    keystream byte. setKey() must have been called first.
    """
    global p, q, state
    s = state
    p = (p + 1) % 256
    q = (q + s[p]) % 256
    s[q], s[p] = s[p], s[q]
    outIndex = (s[p] + s[q]) % 256
    return s[outIndex]

def encrypt(inputString):
    """Encrypt input string returning a byte list

    Each character's code point is XORed with the next keystream byte
    from byteGenerator(), so setKey() must be called first.
    """
    cipherBytes = []
    for ch in inputString:
        code = ord(ch)
        cipherBytes.append(code ^ byteGenerator())
    return cipherBytes

def decrypt(inputByteList):
    """Decrypt input byte list returning a string

    Each value is XORed with the next keystream byte from
    byteGenerator() and turned back into a character, so setKey() must
    be called first with the key used for encryption.
    """
    plainChars = []
    for value in inputByteList:
        plainChars.append(chr(value ^ byteGenerator()))
    return "".join(plainChars)

To informally verify the correctness of this implementation, I wrote a separate Python module that checks it against Wikipedia’s RC4 test vectors:

import sys
import pyRC4

def main():
    """Verify the correctness of RC4 implementation using Wikipedia
       test vectors

    Every vector is checked in both directions (encryption, then
    decryption), re-keying before each direction. One line is printed
    per check; the process exits with status 1 if any check failed and
    with status 0 otherwise.
    """

    def hexToList(hexString):
        """Convert a hex string (two digits per byte) into a byte list"""
        # bytes.fromhex keeps leading zero bytes; formatting an int with
        # "{:02x}" would drop them and misalign the byte pairs
        return list(bytes.fromhex(hexString))

    def stringToList(inputString):
        """Convert a string into a byte list"""
        return [ord(c) for c in inputString]

    def report(direction, testNumber, ok):
        """Print the outcome of one encryption/decryption check"""
        print("RC4 {} test #{:d} {}".format(direction, testNumber,
                                            "ok!" if ok else "failed"))

    def test(key, plaintext, ciphertext, testNumber):
        """Run one test vector; return True only if both directions pass"""
        expected = hexToList(ciphertext)
        # Plain comparisons rather than assert: "python -O" strips
        # asserts, which would make every check report success
        pyRC4.setKey(stringToList(key))
        encryptOk = pyRC4.encrypt(plaintext) == expected
        report("encryption", testNumber, encryptOk)
        pyRC4.setKey(stringToList(key))
        decryptOk = pyRC4.decrypt(expected) == plaintext
        report("decryption", testNumber, decryptOk)
        return encryptOk and decryptOk

    # Test vectors definition section (ciphertexts as hex strings)
    testVectorList = [
        # Wikipedia test vector #1
        dict(key = "Key",
             plaintext = "Plaintext",
             ciphertext = "BBF316E8D940AF0AD3",
             testNumber = 1),
        # Wikipedia test vector #2
        dict(key = "Wiki",
             plaintext = "pedia",
             ciphertext = "1021BF0420",
             testNumber = 2),
        # Wikipedia test vector #3
        dict(key = "Secret",
             plaintext = "Attack at dawn",
             ciphertext = "45A01F645FC35B383552544B9BF5",
             testNumber = 3),
    ]

    # Testing section: run every vector (no short-circuit) before deciding
    results = [test(**vector) for vector in testVectorList]
    if all(results):
        print("All RC4 tests succeeded!")
    else:
        print("At least one RC4 test failed")
        sys.exit(1)
    sys.exit()

if __name__ == '__main__':
    main()