' Module: RSAES_OAEP.bas 

' This module contains the generic RSAES_OAEP functions
' (rsaes_oaep_Encrypt and rsaes_oaep_Decrypt),
' the mask generation function MGF1, and byte manipulation functions.
' It uses functions from the CryptoSys (tm) PKI Toolkit available from
' <www.cryptosys.net/pki/>.

' Ref: PKCS #1: RSA Cryptography Specifications Version 2.1 (PKCS-1v2-1)
' RSAES = RSA Encryption Scheme; RSAEP = RSA Encryption Primitive
' OAEP = Optimal Asymmetric Encryption Padding

'***************** COPYRIGHT NOTICE *******************
' This code was originally written by David Ireland and
' is copyright (C) 2005 DI Management Services Pty Ltd
' <www.di-mgt.com.au>.
' Provided "as is". No warranties. Use at own risk.
' It is not to be altered or distributed,
' except as part of an application.
' You are free to use it in any application,
' provided this copyright notice is left unchanged.
'************** END OF COPYRIGHT NOTICE ***************

Option Explicit
' Length in bytes of a SHA-1 digest; used as hLen throughout this module, since every
' function here uses SHA-1 as the OAEP/MGF1 hash.
Public Const SHA1_BYTE_LEN As Long = 20

'***************************************
' GENERAL PURPOSE RSAES-OAEP FUNCTIONS *
'***************************************
Public Function rsaes_oaep_Encrypt(abMessage() As Byte, abSeed() As Byte, strPublicKey As String) As Variant
' Encrypts a message with RSAES-OAEP (PKCS #1 v2.1) using the RSA public key passed as a string
' and a seed supplied by the caller.
' Uses all the default RSAES-OAEP values; namely, SHA-1 as the hash function, MGF1 with SHA-1
' as the mask generation function, and the empty string for the encoding parameters, P.
' Parameters:
'   abMessage    - the message M; at most k - 2*hLen - 2 bytes long (k = key length in bytes,
'                  hLen = 20). May be empty.
'   abSeed       - the seed; must be exactly hLen (20) bytes. It should be freshly generated
'                  random bytes for every encryption.
'   strPublicKey - the RSA public key string (e.g. as returned by rsaReadPublicKey)
' Returns the k-byte ciphertext as a byte array in a Variant. On error a message box is shown
' and a one-element byte array containing a single zero byte is returned.
    Dim lngRet As Long
    Dim abDummy(0) As Byte
    Dim abDB() As Byte
    Dim abLHash() As Byte
    Dim abDbMask() As Byte
    Dim abMaskedDB() As Byte
    Dim abSeedMask() As Byte
    Dim abMaskedSeed() As Byte
    Dim abBlock() As Byte
    ' All lengths are in octets (i.e. 8-bit bytes)
    Dim nhLen As Long   ' length of hash = 20 for SHA-1
    Dim nkLen As Long   ' length, k, of RSA key modulus
    Dim nmLen As Long   ' length of message, M
    Dim npsLen As Long  ' length of padding string, PS
    Dim ndbLen As Long  ' length of data block, DB
    Dim i As Long
    Dim iOffset As Long

    ' Value returned on any error path
    rsaes_oaep_Encrypt = abDummy

    ' Compute Hash(L, the empty string) in byte array format
    ' --we know the resulting digest is going to be 20 bytes long
    nhLen = SHA1_BYTE_LEN
    ReDim abLHash(nhLen - 1)
    lngRet = HASH_Bytes(abLHash(0), nhLen, 0, 0, PKI_HASH_SHA1)
    Debug.Print "lHash: " & cnvHexStrFromBytes(abLHash)

    ' Compute lengths
    nmLen = BytesLength(abMessage)
    nkLen = RSA_KeyBytes(strPublicKey)
    Debug.Print "Key is " & nkLen & " bytes long"
    Debug.Print "Message is " & nmLen & " bytes long"

    ' Guard against an unreadable key
    ' (assumes RSA_KeyBytes returns a non-positive value on error -- TODO confirm)
    If nkLen <= 0 Then
        MsgBox "Invalid RSA public key", vbCritical
        Exit Function
    End If

    ' The seed must be exactly hLen bytes: XorBytes would otherwise silently truncate
    ' maskedSeed and the encoded block would have the wrong length
    If BytesLength(abSeed) <> nhLen Then
        MsgBox "Seed must be " & nhLen & " bytes long", vbCritical
        Exit Function
    End If

    ndbLen = nkLen - 1 - nhLen
    npsLen = ndbLen - nmLen - nhLen - 1
    ' Catch mLen > k - 2hLen - 2 (this also rejects keys too short for OAEP with SHA-1)
    If npsLen < 0 Then
        MsgBox "Message is too long", vbCritical
        Exit Function
    End If

    ' Construct DB = lHash || PS || 0x01 || M
    ' --ReDim zero-fills the array, so the npsLen zero bytes of PS need no explicit writes
    ReDim abDB(ndbLen - 1)
    For i = 0 To nhLen - 1
        abDB(i) = abLHash(i)
    Next
    iOffset = nhLen + npsLen
    ' A single 0x01 byte separates PS from M
    abDB(iOffset) = &H1
    iOffset = iOffset + 1
    ' And then the message itself
    For i = 0 To nmLen - 1
        abDB(iOffset + i) = abMessage(i)
    Next
    Debug.Print "DB: " & cnvHexStrFromBytes(abDB)

    ' Compute dbMask = MGF(seed, length(DB))
    abDbMask = MGF1(abSeed, ndbLen)
    Debug.Print "dbMask:   " & cnvHexStrFromBytes(abDbMask)

    ' Compute maskedDB = DB XOR dbMask
    abMaskedDB = XorBytes(abDB, abDbMask)
    Debug.Print "maskedDB: " & cnvHexStrFromBytes(abMaskedDB)

    ' Compute seedMask = MGF(maskedDB, length(seed))
    abSeedMask = MGF1(abMaskedDB, nhLen)
    Debug.Print "dbSeedMask: " & cnvHexStrFromBytes(abSeedMask)

    ' Compute maskedSeed = seed XOR seedMask
    abMaskedSeed = XorBytes(abSeed, abSeedMask)
    Debug.Print "maskedSeed: " & cnvHexStrFromBytes(abMaskedSeed)

    ' Build EM = 00 || maskedSeed || maskedDB  (k bytes in total)
    ' --be careful, EM is different in the examples and in PKCS#1
    ' --in PKCS#1, EM does not include the leading zero.
    ReDim abBlock(0)
    abBlock(0) = 0
    abBlock = AppendBytes(abBlock, abMaskedSeed)
    abBlock = AppendBytes(abBlock, abMaskedDB)
    Debug.Print "EM: " & cnvHexStrFromBytes(abBlock)

    ' Encrypt in place using the RSA public key; do not return the block if this failed
    lngRet = RSA_RawPublic(abBlock(0), nkLen, strPublicKey, 0)
    If lngRet <> 0 Then
        MsgBox "Encryption error", vbCritical
        Exit Function
    End If
    Debug.Print "CT: " & cnvHexStrFromBytes(abBlock)

    ' Return ciphertext block
    rsaes_oaep_Encrypt = abBlock
End Function

Public Function rsaes_oaep_Decrypt(abCipher() As Byte, strPrivateKey As String) As Variant
' Decrypts an RSAES-OAEP (PKCS #1 v2.1) ciphertext using the RSA private key passed as a string.
' Uses all the default RSAES-OAEP values; namely, SHA-1 as the hash function, MGF1 with SHA-1
' as the mask generation function, and the empty string for the encoding parameters, P.
' Parameters:
'   abCipher      - the ciphertext C; must be exactly k bytes long (k = key length in bytes)
'   strPrivateKey - the RSA private key string (e.g. as returned by readPrivateKeyInfo)
' Returns the message M as a byte array in a Variant (a zero-length array if M is empty).
' On error a "Decryption error" message box is shown and a one-element byte array containing
' a single zero byte is returned.
' All decoding failures are reported with the same message. The padding checks are accumulated
' rather than exiting at the first failure, so the caller is not told which check failed.
' (See the note on distinguishing error conditions in PKCS #1 v2.1, section 7.1.2.)
    Dim abDummy(0) As Byte
    Dim abMessage() As Byte
    Dim lngRet As Long
    Dim abSeed() As Byte
    Dim abDB() As Byte
    Dim abPHash() As Byte
    Dim abDbMask() As Byte
    Dim abMaskedDB() As Byte
    Dim abSeedMask() As Byte
    Dim abMaskedSeed() As Byte
    Dim abBlock() As Byte
    Dim nhLen As Long   ' length of hash = 20 for SHA-1
    Dim nkLen As Long   ' length, k, of RSA key modulus
    Dim nmLen As Long   ' length of message, M
    Dim ndbLen As Long  ' length of data block, DB
    Dim i As Long
    Dim iOffset As Long
    Dim iSep As Long    ' index in DB of the 0x01 separator, or -1 if not found
    Dim nBad As Long    ' non-zero if any decoding check failed

    ' Value returned on any error path
    rsaes_oaep_Decrypt = abDummy
    nhLen = SHA1_BYTE_LEN

    nkLen = RSA_KeyBytes(strPrivateKey)
    Debug.Print "RSA key is " & nkLen & " bytes long (" & RSA_KeyBits(strPrivateKey) & " bits)"

    ' C must be exactly k bytes, and k must be at least 2hLen + 2 for OAEP.
    ' --this also stops RSA_RawPrivate from accessing bytes beyond the end of abBlock
    If nkLen < 2 * nhLen + 2 Or BytesLength(abCipher) <> nkLen Then
        MsgBox "Decryption error", vbCritical
        Exit Function
    End If

    abBlock = abCipher

    ' Decrypt in place
    lngRet = RSA_RawPrivate(abBlock(0), nkLen, strPrivateKey, 0)
    If lngRet <> 0 Then
        MsgBox "Decryption error", vbCritical
        Exit Function
    End If

    Debug.Print "EM: " & cnvHexStrFromBytes(abBlock)

    ' Now we decode according to EME-OAEP-DECODE
    ' Split EM = Y || maskedSeed || maskedDB, where Y is a single byte that must be zero
    ndbLen = nkLen - nhLen - 1
    ReDim abMaskedSeed(nhLen - 1)
    ReDim abMaskedDB(ndbLen - 1)
    For i = 0 To nhLen - 1
        abMaskedSeed(i) = abBlock(1 + i)
    Next
    Debug.Print "maskedSeed: " & cnvHexStrFromBytes(abMaskedSeed)
    For i = 0 To ndbLen - 1
        abMaskedDB(i) = abBlock(1 + nhLen + i)
    Next
    Debug.Print "maskedDB:   " & cnvHexStrFromBytes(abMaskedDB)

    ' Let seedMask = MGF(maskedDB, hLen)
    abSeedMask = MGF1(abMaskedDB, nhLen)
    Debug.Print "seedMask: " & cnvHexStrFromBytes(abSeedMask)

    ' Let seed = maskedSeed \xor seedMask
    abSeed = XorBytes(abMaskedSeed, abSeedMask)
    Debug.Print "seed:     " & cnvHexStrFromBytes(abSeed)

    ' Let dbMask = MGF(seed, ||EM|| - hLen)
    abDbMask = MGF1(abSeed, ndbLen)
    Debug.Print "dbMask: " & cnvHexStrFromBytes(abDbMask)

    ' Let DB = maskedDB \xor dbMask
    abDB = XorBytes(abMaskedDB, abDbMask)
    Debug.Print "DB:     " & cnvHexStrFromBytes(abDB)

    ' Let pHash = Hash(P), an octet string of length hLen
    ' where P is the empty string in this implementation
    ReDim abPHash(nhLen - 1)
    lngRet = HASH_Bytes(abPHash(0), nhLen, 0, 0, PKI_HASH_SHA1)
    Debug.Print "pHash:  " & cnvHexStrFromBytes(abPHash)

    ' Separate DB = pHash' || PS || 01 || M and check, without exiting early, that:
    '   Y is zero, pHash' equals pHash, and PS (all zeros) is followed by a 0x01 byte
    nBad = abBlock(0)
    For i = 0 To nhLen - 1
        nBad = nBad Or (abDB(i) Xor abPHash(i))
    Next
    iSep = -1
    For i = nhLen To ndbLen - 1
        If iSep < 0 Then
            If abDB(i) = &H1 Then
                iSep = i
            ElseIf abDB(i) <> 0 Then
                ' First non-zero byte after pHash' is not the 0x01 separator
                nBad = nBad Or 1
            End If
        End If
    Next
    If iSep < 0 Then nBad = nBad Or 1
    If nBad <> 0 Then
        MsgBox "Decryption error", vbCritical
        Exit Function
    End If

    ' The remainder of DB is M, the message (possibly empty)
    iOffset = iSep + 1
    nmLen = ndbLen - iOffset
    Debug.Print "M is " & nmLen & " bytes long"
    If nmLen > 0 Then
        ReDim abMessage(nmLen - 1)
        For i = 0 To nmLen - 1
            abMessage(i) = abDB(iOffset + i)
        Next
        Debug.Print "M:  " & cnvHexStrFromBytes(abMessage)
    Else
        ' A valid empty message: assigning "" gives a zero-length byte array
        ' (ReDim abMessage(-1) would raise a run-time error)
        abMessage = ""
    End If

    ' Output M.
    rsaes_oaep_Decrypt = abMessage

End Function

'**************************************************
' EXAMPLE TO ENCRYPT AND DECRYPT USING RSAES-OAEP *
'**************************************************
Public Function Test_Encrypt()
' Example: encrypts "Hello world!" with RSAES-OAEP using a freshly generated seed and the public
' key file in TEST_KEY_PATH. It then decrypts the result with the matching private key and
' reports in the Immediate window whether the recovered message matches the original.
    Dim abMessage() As Byte
    Dim abSeed() As Byte
    Dim abResult() As Byte
    Dim abCheck() As Byte
    Dim strPublicKey As String
    Dim strPubKeyFile As String
    Dim strPrivateKey As String
    Dim strPriKeyFile As String
    ' IMPORTANT: change this to suit your system
    Const TEST_KEY_PATH As String = "C:\Test\"

    strPubKeyFile = TEST_KEY_PATH & "pubkeyex1.bin"
    strPriKeyFile = TEST_KEY_PATH & "prikeyex1.bin"

    ' Convert ANSI text to bytes
    abMessage = StrConv("Hello world!", vbFromUnicode)

    ' Generate a seed
    ReDim abSeed(SHA1_BYTE_LEN - 1)
    Call RNG_Bytes(abSeed(0), SHA1_BYTE_LEN, "", 0)

    Debug.Print "M (ansi): " & StrConv(abMessage, vbUnicode)
    Debug.Print "M (hex):  " & cnvHexStrFromBytes(abMessage)
    Debug.Print "seed:     " & cnvHexStrFromBytes(abSeed)

    strPublicKey = rsaReadPublicKey(strPubKeyFile)
    If Len(strPublicKey) = 0 Then
        MsgBox "Cannot read RSA key file '" & strPubKeyFile & "'", vbCritical
        Exit Function
    End If

    ' Encrypt using freshly-generated seed
    abResult = rsaes_oaep_Encrypt(abMessage, abSeed, strPublicKey)
    ' rsaes_oaep_Encrypt returns a single zero byte on error; a real ciphertext is k bytes long
    If BytesLength(abResult) <= 1 Then
        Debug.Print "Encryption failed"
        Exit Function
    End If

    ' Now decrypt with private key and compare with original message
    strPrivateKey = readPrivateKeyInfo(strPriKeyFile)
    If Len(strPrivateKey) = 0 Then
        MsgBox "Cannot read RSA key file '" & strPriKeyFile & "'", vbCritical
        Exit Function
    End If

    abCheck = rsaes_oaep_Decrypt(abResult, strPrivateKey)

    ' Clear the private key for security
    Call WIPE_String(strPrivateKey, Len(strPrivateKey))

    Debug.Print "PT: " & cnvHexStrFromBytes(abCheck)
    Debug.Print "OK: " & cnvHexStrFromBytes(abMessage)
    Debug.Print "ANSI: " & StrConv(abCheck, vbUnicode)

    ' Actually perform the comparison rather than leaving it to the reader
    If BytesAreEqual(abCheck, abMessage) Then
        Debug.Print "Decrypted message matches the original"
    Else
        Debug.Print "ERROR: decrypted message does not match the original"
    End If

End Function

'*******************************
' FUNCTION TO READ PRIVATE KEY *
'*******************************
' Like rsaReadPrivateKey but for an UNencrypted private key info file
Public Function readPrivateKeyInfo(strKeyFile As String) As String
' Reads an unencrypted private key info file using RSA_ReadPrivateKeyInfo.
' Parameters:
'   strKeyFile - path of the key file
' Returns the key as a base64 string or an empty string on error.
' It uses two calls: the first obtains the required string length, the second fills a buffer
' of that length.
    Dim lngKeyLen As Long
    Dim lngRet As Long
    Dim strKey As String
    ' How long is key string?
    lngKeyLen = RSA_ReadPrivateKeyInfo("", 0, strKeyFile, 0)
    If lngKeyLen <= 0 Then
        Exit Function
    End If
    ' Pre-dimension the string to receive data
    strKey = String(lngKeyLen, " ")
    ' Read in the Private Key
    lngRet = RSA_ReadPrivateKeyInfo(strKey, lngKeyLen, strKeyFile, 0)
    If lngRet <= 0 Then
        ' The read failed: do not hand back the blank-filled buffer as if it were a key
        Exit Function
    End If
    readPrivateKeyInfo = strKey

End Function

'***************************
' MASK GENERATION FUNCTION *
'***************************
Public Function MGF1(abZ() As Byte, nMaskLen As Long) As Variant
' Mask generation function MGF1 from PKCS-1v2, using SHA-1 as the hash.
' Produces the first nMaskLen bytes of Hash(Z || C0) || Hash(Z || C1) || ...
' where Ci is the 4-byte big-endian counter i (see I2OSP4).
' Parameters:
'   abZ      - the seed Z
'   nMaskLen - the required mask length in bytes
' Returns the mask as a byte array in a Variant, or Empty if nMaskLen <= 0.
    Const nchLen As Long = SHA1_BYTE_LEN   ' Known constant for SHA-1
    Dim abDigest(nchLen - 1) As Byte
    Dim abResult() As Byte
    Dim abBlock() As Byte
    Dim abCnt() As Byte
    Dim nBlockLen As Long
    Dim i As Long
    Dim iCount As Long
    Dim nBlocks As Long
    Dim iOut As Long    ' number of mask bytes written so far

    If nMaskLen <= 0 Then Exit Function

    ' How many blocks do we need? = ceil(l/hLen)
    nBlocks = (nMaskLen + nchLen - 1) \ nchLen

    ' We need to compute Hash(Z || C) where C is a 4-byte representation of count
    nBlockLen = BytesLength(abZ) + 4

    ' Copy input Z into our message block
    abBlock = abZ
    ReDim Preserve abBlock(nBlockLen - 1)

    ' Size the output once and copy only the bytes needed. Appending a whole digest each
    ' time would re-copy the growing result on every iteration, and the tail would then
    ' have to be truncated.
    ReDim abResult(nMaskLen - 1)
    iOut = 0

    For iCount = 0 To nBlocks - 1
        abCnt = I2OSP4(iCount)
        Debug.Print "   mgf: C=" & cnvHexStrFromBytes(abCnt)
        ' Set block = (Z || C) by just changing the last 4 bytes
        For i = 0 To 3
            abBlock(nBlockLen - 4 + i) = abCnt(i)
        Next
        Debug.Print "   mgf: (Z || C)=" & cnvHexStrFromBytes(abBlock)
        ' Compute Hash(Z || C)
        Call HASH_Bytes(abDigest(0), nchLen, abBlock(0), nBlockLen, PKI_HASH_SHA1)
        Debug.Print "   mgf: Hash=" & cnvHexStrFromBytes(abDigest)
        ' Copy into the output, stopping once nMaskLen bytes have been produced
        For i = 0 To nchLen - 1
            If iOut >= nMaskLen Then Exit For
            abResult(iOut) = abDigest(i)
            iOut = iOut + 1
        Next
    Next

    ' Return the byte array as a variant
    MGF1 = abResult
End Function

Public Function test_MGF1()
' Prints the 25-byte MGF1 mask of the seed 0x012345ff to the Immediate window.
    Dim abMask() As Byte
    Dim abInput() As Byte

    abInput = cnvBytesFromHexStr("012345ff")
    abMask = MGF1(abInput, 25)

    Debug.Print "MGF1=" & cnvHexStrFromBytes(abMask)
End Function

'******************************
' BYTE MANIPULATION FUNCTIONS *
'******************************

Public Function AppendBytes(ByRef abBytes() As Byte, abToAdd() As Byte) As Variant
' Returns a new byte array, as a Variant, containing abBytes followed by abToAdd.
' Neither input is modified. Either or both inputs may be empty (unallocated).
' When abToAdd is empty, a copy of abBytes is returned; this includes the case where
' both are empty, which previously raised a run-time error in ReDim.
    Dim abResult() As Byte
    Dim nHead As Long     ' bytes already in abBytes
    Dim nTail As Long     ' bytes to append
    Dim iFirst As Long    ' lower bound of abToAdd
    Dim i As Long

    nHead = BytesLength(abBytes)
    nTail = BytesLength(abToAdd)
    ' Copy input
    abResult = abBytes
    If nTail > 0 Then
        ' Resize and append, without assuming abToAdd starts at index 0
        ReDim Preserve abResult(nHead + nTail - 1)
        iFirst = LBound(abToAdd)
        For i = 0 To nTail - 1
            abResult(nHead + i) = abToAdd(iFirst + i)
        Next
    End If

    AppendBytes = abResult
End Function

Public Function BytesLength(abBytes() As Byte) As Long
' Returns the number of elements in abBytes, or 0 if the array has not been allocated.
    Dim nUpper As Long
    ' UBound raises a run-time error on an unallocated array; suppress it so the result stays 0
    On Error Resume Next
    nUpper = UBound(abBytes)
    If Err.Number = 0 Then BytesLength = nUpper - LBound(abBytes) + 1
End Function

Public Function XorBytes(ab1() As Byte, ab2() As Byte) As Variant
' Returns ab1 XOR ab2, byte by byte, as a byte array in a Variant.
' If the inputs differ in length, the result has the length of the shorter one; the extra
' bytes of the longer array are ignored. If either input is empty, a zero-length byte
' array is returned instead of raising a run-time error.
    Dim abResult() As Byte
    Dim nLen As Long
    Dim i1 As Long      ' lower bound of ab1
    Dim i2 As Long      ' lower bound of ab2
    Dim i As Long

    nLen = BytesLength(ab1)
    If BytesLength(ab2) < nLen Then nLen = BytesLength(ab2)

    If nLen <= 0 Then
        ' Assigning "" to a dynamic byte array gives a zero-length array
        abResult = ""
    Else
        ReDim abResult(nLen - 1)
        i1 = LBound(ab1)
        i2 = LBound(ab2)
        For i = 0 To nLen - 1
            abResult(i) = ab1(i1 + i) Xor ab2(i2 + i)
        Next
    End If

    XorBytes = abResult

End Function

Public Function BytesAreEqual(ab1() As Byte, ab2() As Byte) As Boolean
' Returns True if ab1 and ab2 have the same length and identical contents.
' Two empty (unallocated or zero-length) arrays are equal; an empty and a non-empty array
' are not.
' The bytes are compared directly. Converting with StrConv(..., vbUnicode) depends on the
' system ANSI code page and can map different byte sequences to the same string.
    Dim nLen As Long
    Dim i1 As Long
    Dim i2 As Long
    Dim i As Long

    BytesAreEqual = False

    nLen = BytesLength(ab1)
    If nLen <> BytesLength(ab2) Then Exit Function

    ' Both empty and therefore equal
    If nLen = 0 Then
        BytesAreEqual = True
        Exit Function
    End If

    i1 = LBound(ab1)
    i2 = LBound(ab2)
    For i = 0 To nLen - 1
        If ab1(i1 + i) <> ab2(i2 + i) Then Exit Function
    Next

    BytesAreEqual = True

End Function

Public Function I2OSP4(nVal As Long) As Variant
' Integer-to-octet-string primitive I2OSP(nVal, 4): returns the 4-byte big-endian
' representation of nVal as a byte array in a Variant.
' Negative values are encoded as their 32-bit two's-complement bit pattern
' (e.g. -1 -> FF FF FF FF).
    Dim abData(3) As Byte
    ' Mask each byte into place before dividing so that every division is exact.
    ' Dividing a negative value first would round toward zero and lose the high-order bits.
    ' (&HFF00& needs the & suffix: a plain &HFF00 is the Integer -256.)
    abData(0) = ((nVal And &HFF000000) \ &H1000000) And &HFF
    abData(1) = (nVal And &HFF0000) \ &H10000
    abData(2) = (nVal And &HFF00&) \ &H100
    abData(3) = nVal And &HFF
    I2OSP4 = abData

End Function