Skip to content

Instantly share code, notes, and snippets.

@odzhan
Created February 15, 2025 21:59
Show Gist options
  • Save odzhan/e8404e0f7cbee3d5793a536a3cfe5fc7 to your computer and use it in GitHub Desktop.
Save odzhan/e8404e0f7cbee3d5793a536a3cfe5fc7 to your computer and use it in GitHub Desktop.
Simple Echo Servers Using Symmetric Encryption
/*
** Demo Public Key OpenSSL Echo Client
**
** Connects to a PK Echo Server. Generates RSA Key and
** sends encoded public key to server. Response is a
** session key encoded with the public key. Sends client
** request, encoded with session key, consisting of
** header and text to be echo'd between lines 'BEGIN'
** and 'END' (inclusive). Reads server response, also
** encoded with session key, consisting of header and
** displays echo'd text between lines 'BEGIN' and
** 'END' (exclusive).
**
** Arguments:
** -t <host> Target host name (default 'localhost')
** -p <port> Target port number (default 16903)
** -e AES Mode ECB.
** -v Verbose
** -h Help
*/
#include <stdio.h>
#include <stdlib.h>
#include <errno.h>
#include <sys/types.h>
#include <sys/socket.h>
#include <netinet/in.h>
#include <netinet/tcp.h>
#include <netdb.h>
#include <unistd.h>
#include <string.h>
#include <openssl/bn.h>
#include <openssl/rsa.h>
#include <openssl/evp.h>
/*
** CRYPT_IN_PLACE: when defined, encode()/decode() run the cipher
** directly over the caller's buffer instead of through a local copy.
*/
#define CRYPT_IN_PLACE
/* Defaults for the -t and -p options. */
#define DEFAULT_HOST "localhost"
#define DEFAULT_PORT 16903
/* Public exponent for the RSA key generated in initCrypt(). */
#define RSA_EXPONENT 65537
/* OpenSSL cipher names passed to EVP_get_cipherbyname(). */
#define AES128_ECB "aes-128-ecb"
#define AES128_CBC "aes-128-cbc"
#define AES256_ECB "aes-256-ecb"
#define AES256_CBC "aes-256-cbc"
/* Raw byte type used for protocol and key buffers. */
typedef unsigned char BYTE;
extern char *optarg;
/* -e: use ECB instead of CBC for the session cipher. */
static int useECB = 0;
/* -v: hex-dump protocol messages and key material. */
static int verbose = 0;
/*
** Plaintext request, formatted with the target host and port.
** The text between the BEGIN and END lines is what gets echoed.
*/
static char *REQUEST_TEMPLATE =
"PK Echo Client: openssl\r\n"
"BEGIN\r\n"
"PK Client/Server Echo Test\r\n"
"Host: %s:%d\r\n"
"END\r\n";
/* Client RSA key pair from initCrypt(); its public half is sent to the server. */
static RSA *rsaKey = NULL;
/*
** Non-zero to request a 256-bit AES session key. Cleared when AES-256
** cannot be loaded locally or when the server answers with a 128-bit key.
*/
static int aes256 = 1;
/* AES session key recovered from the server's reply in decodeSK(). */
static BYTE aesKey[ EVP_MAX_KEY_LENGTH ];
/* CBC initialization vector; decodeSK() fills it with zeros. */
static BYTE cbcIV[ EVP_MAX_IV_LENGTH ];
/* Cipher implementations looked up by initCrypt() (only the selected mode's pair). */
static const EVP_CIPHER *cipherAes128ECB = NULL;
static const EVP_CIPHER *cipherAes128CBC = NULL;
static const EVP_CIPHER *cipherAes256ECB = NULL;
static const EVP_CIPHER *cipherAes256CBC = NULL;
/* Session encryption/decryption contexts, initialized in decodeSK(). */
static EVP_CIPHER_CTX encCtx;
static EVP_CIPHER_CTX decCtx;
/* Forward declarations; definitions follow main(). */
static int errExit( char * );
static int tcpConnect( char *, int );
static int initSession( int );
static void clientRequest( int, char *, int );
static int initCrypt();
static void termCrypt();
static int encodePK( int, BYTE * );
static int decodeSK( int, BYTE * );
static int encode( int, BYTE *, int );
static int decode( int, BYTE *, int );
static int sendBytes( int, BYTE *, int );
static int receiveBytes( int, BYTE *, int, int * );
static void hexDump( BYTE *, int );
/*
** Entry point: parse options, connect to the echo server, negotiate an
** AES session key over RSA, send one echo request and print the reply.
**
** Options: -t <host>, -p <port>, -e (ECB mode), -v (verbose), -h (help).
** Exits 0 after a successful handshake and exchange, 1 on setup failure
** or bad usage.
*/
int
main( int argc, char **argv )
{
    int c, sock;
    int status = 1;
    char *host = DEFAULT_HOST;
    int port = DEFAULT_PORT;

    while( (c = getopt( argc, argv, "ehp:t:v" )) != -1 )
    {
        switch( c )
        {
        case 'h':
            printf( "-t <host>\tTarget host name (default 'localhost')\n" );
            printf( "-p <port>\tTarget port number (default 16903)\n" );
            printf( "-e\t\tAES Mode ECB\n" );
            printf( "-v\t\tVerbose\n" );
            exit(0);
        case 't':
            if ( ! (host = strdup( optarg )) )
                errExit( "Out of memory" );
            break;
        case 'p':
        {
            /* strtol() rather than atoi(): reject junk, negatives and > 65535. */
            char *endp;
            long val;

            errno = 0;
            val = strtol( optarg, &endp, 10 );
            if ( errno || endp == optarg || *endp != '\0' || val < 1 || val > 65535 )
                errExit( "Invalid port specified" );
            port = (int)val;
            break;
        }
        case 'e': useECB = 1; break;
        case 'v': verbose = 1; break;
        default:
            /* getopt() has already reported the unknown option or missing value. */
            fprintf( stderr, "Use -h for help\n" );
            exit(1);
        }
    }

    sock = tcpConnect( host, port );
    /* Release crypto state and the socket on every path, not only on success. */
    if ( initCrypt() && initSession( sock ) )
    {
        clientRequest( sock, host, port );
        status = 0;
    }
    termCrypt();
    shutdown( sock, 2 );
    close( sock );
    exit( status );
}
/*
** Print a fatal error message to stderr and terminate the process.
** Never returns. Exits with EXIT_FAILURE so that callers and scripts can
** tell the failure apart from a successful run.
*/
static int
errExit( char *string )
{
    fprintf( stderr, "%s\n", string );
    exit( EXIT_FAILURE );
}
/*
** Resolve 'host' and open a TCP connection to it on 'port'.
** Every failure is fatal (reported through errExit()).
** Returns the connected socket descriptor.
*/
static int
tcpConnect( char *host, int port )
{
    struct hostent *entry = gethostbyname( host );
    struct sockaddr_in target;
    int fd;

    if ( entry == NULL )
        errExit( "Couldn't resolve host" );

    memset( &target, 0, sizeof target );
    target.sin_family = AF_INET;
    target.sin_port = htons( port );
    target.sin_addr = *(struct in_addr *)entry->h_addr_list[0];

    fd = socket( AF_INET, SOCK_STREAM, IPPROTO_TCP );
    if ( fd < 0 )
        errExit( "Couldn't create socket" );

    if ( connect( fd, (struct sockaddr *)&target, sizeof target ) < 0 )
        errExit( "Couldn't connect socket" );

    return fd;
}
static int
initSession( int sock )
{
BYTE buffer[ 1024 ];
int length = 0;
buffer[ 0 ] = aes256 ? 1 : 0;
if ( ! (length = encodePK( sizeof( buffer ) - 1, &buffer[1] )) )
return( 0 );
if ( verbose )
{
printf( "Session Initiation: \n" );
hexDump( buffer, length + 1 );
}
if ( ! sendBytes( sock, buffer, length + 1 ) )
return( 0 );
if ( ! receiveBytes( sock, buffer, sizeof( buffer ), &length ) )
return( 0 );
if ( ! length )
{
printf( "No server initiation response\n" );
return( 0 );
}
if ( verbose )
{
printf( "Session Response: \n" );
hexDump( buffer, length );
}
switch( buffer[ 0 ] )
{
case 0 : aes256 = 0; break;
case 1 :
if ( aes256 ) break;
/*
** Fall through for unsupported AES key length.
*/
default:
printf( "Invalid AES key length returned: %d\n", buffer[ 0 ] );
return( 0 );
}
if ( ! decodeSK( length - 1, &buffer[1] ) )
return( 0 );
return( 1 );
}
/*
** Send the echo request and print the echoed text.
**
** The request is REQUEST_TEMPLATE filled in with host and port, encrypted
** with the session key. One encrypted reply is received and decrypted, and
** the lines strictly between "BEGIN" and "END" are written to stdout.
** Errors are reported on stdout and end the exchange early.
*/
static void
clientRequest( int sock, char *host, int port )
{
    BYTE buf[ 1024 ];
    BYTE *ptr, *end;
    int len;
    int echo = 0;

    len = snprintf( (char *)buf, sizeof( buf ), REQUEST_TEMPLATE, host, port );
    if ( len < 0 || len >= (int)sizeof( buf ) )
    {
        printf( "Request too long\n" );
        return;
    }
    if ( ! (len = encode( len, buf, sizeof( buf ) )) ) return;

    /* Send request to server */
    if ( verbose )
    {
        printf( "Sending %d bytes: \n", len );
        hexDump( buf, len );
    }
    if ( ! sendBytes( sock, buf, len ) ) return;

    if ( ! receiveBytes( sock, buf, sizeof( buf ), &len ) )
        return;
    if ( ! len )
    {
        printf( "No server response\n" );
        return;
    }
    if ( verbose )
    {
        printf( "Received %d bytes: \n", len );
        hexDump( buf, len );
    }
    if ( ! (len = decode( len, buf, sizeof( buf ) )) ) return;
    /* One spare byte is needed to NUL-terminate the plaintext. */
    if ( len >= (int)sizeof( buf ) )
    {
        printf( "Server response too long\n" );
        return;
    }

    ptr = buf;
    end = buf + len;
    *end = 0;
    while( ptr < end )
    {
        BYTE *eol = ptr;
        BYTE save;

        /*
        ** A line ends just past '\n', or at 'end' for an unterminated last
        ** line. The cursor must never step beyond 'end'.
        */
        while( eol < end && *eol != '\n' )
            eol++;
        if ( eol < end )
            eol++;
        save = *eol;
        *eol = 0;

        /*
        ** Display response between 'BEGIN' and 'END'
        */
        if ( ! echo )
        {
            /* Echo starts with 'BEGIN' */
            if ( ! strcmp( (char *)ptr, "BEGIN\r\n" ) || ! strcmp( (char *)ptr, "BEGIN\n" ) )
                echo = 1;
        }
        else
        {
            /* Echo finishes with 'END' */
            if ( ! strcmp( (char *)ptr, "END\r\n" ) || ! strcmp( (char *)ptr, "END\n" ) )
                break;
            printf( "%s", (char *)ptr );
        }
        *eol = save;
        ptr = eol;
    }
}
/*
** Load the AES cipher implementations for the selected mode (ECB or CBC)
** and generate the client's 1024-bit RSA key pair.
**
** AES-256 is preferred. If it cannot be loaded, aes256 is cleared and the
** 128-bit cipher is used instead. Returns 1 on success, 0 on failure.
*/
static int
initCrypt()
{
    const EVP_CIPHER *cipher;

    /*
    ** TODO: Initialize random number generator.
    ** This does get done automatically in some cases.
    */
    OpenSSL_add_all_ciphers();
    if ( useECB )
    {
        if ( ! (cipherAes128ECB = EVP_get_cipherbyname( AES128_ECB )) )
        {
            printf( "Can't load AES 128 ECB\n" );
            return( 0 );
        }
        if ( (cipherAes256ECB = EVP_get_cipherbyname( AES256_ECB )) )
            cipher = cipherAes256ECB;
        else
        {
            printf( "Can't load AES 256 ECB\n" );
            cipher = cipherAes128ECB;
            aes256 = 0;
        }
    }
    else
    {
        if ( ! (cipherAes128CBC = EVP_get_cipherbyname( AES128_CBC )) )
        {
            printf( "Can't load AES 128 CBC\n" );
            return( 0 );
        }
        if ( (cipherAes256CBC = EVP_get_cipherbyname( AES256_CBC )) )
            cipher = cipherAes256CBC;
        else
        {
            printf( "Can't load AES 256 CBC\n" );
            cipher = cipherAes128CBC;
            aes256 = 0;
        }
    }

    /* TODO: determine max RSA key length */
    if ( verbose )
    {
        printf( "AES Max Key Length: %d\n",
                EVP_CIPHER_key_length( cipher ) * 8 );
    }

    if ( ! (rsaKey = RSA_generate_key( 1024, RSA_EXPONENT, NULL, NULL )) )
    {
        printf( "RSA_generate_key failed\n" );
        return( 0 );
    }
    if ( verbose )
    {
        BYTE buff[ 1024 ];
        int len = BN_num_bytes( rsaKey->n );

        if ( len > (int)sizeof( buff ) )
        {
            printf( "RSA modulus requires %d bytes\n", len );
            return( 0 );
        }
        printf( "Xchg Exp: %d\n", (int)BN_get_word( rsaKey->e ) );
        printf( "Xchg Mod: \n" );
        len = BN_bn2bin( rsaKey->n, buff );
        hexDump( buff, len );
    }
    return( 1 );
}
static void
termCrypt()
{
if ( rsaKey ) RSA_free( rsaKey );
}
/*
** DER-encode the client's RSA public key (PKCS#1 RSAPublicKey) into
** 'buffer', whose capacity is 'size' bytes.
** Returns the encoded length, or 0 on failure. i2d_RSAPublicKey() reports
** errors as negative values, which must not reach the caller as "success".
*/
static int
encodePK( int size, BYTE *buffer )
{
    BYTE *ptr = buffer;
    int length;

    length = i2d_RSAPublicKey( rsaKey, NULL );
    if ( length <= 0 )
    {
        printf( "RSA public key encoding failed\n" );
        return( 0 );
    }
    if ( length >= size )
    {
        printf( "Encoded RSA key requires %d bytes\n", length );
        return( 0 );
    }
    /* i2d advances ptr; 'buffer' still points at the start of the encoding. */
    length = i2d_RSAPublicKey( rsaKey, &ptr );
    if ( length <= 0 )
    {
        printf( "RSA public key encoding failed\n" );
        return( 0 );
    }
    return( length );
}
/*
** Recover the AES session key sent by the server and set up the session
** cipher contexts.
**
** 'buffer' holds 'length' bytes: the session key encrypted under our RSA
** public key with PKCS#1 v1.5 padding. The key must be exactly 16 or 32
** bytes, depending on aes256. In CBC mode the IV is all zeros.
** Returns 1 on success, 0 on failure.
*/
static int
decodeSK( int length, BYTE *buffer )
{
    const EVP_CIPHER *cipher;
    int keylen = (aes256 ? 256 : 128) / 8;
    BYTE *iv = NULL;
    /*
    ** RSA_private_decrypt() may emit up to RSA_size() - 11 bytes, which is
    ** more than aesKey can hold. Decrypt into scratch space and copy only a
    ** key whose length has been validated.
    */
    BYTE plain[ 1024 ];

    if ( keylen > (int)sizeof( aesKey ) )
    {
        printf( "AES Key requires %d bytes (only %d provided)\n",
                keylen, (int)sizeof( aesKey ) );
        return( 0 );
    }
    if ( RSA_size( rsaKey ) > (int)sizeof( plain ) )
    {
        printf( "RSA decrypt requires %d bytes\n", RSA_size( rsaKey ) );
        return( 0 );
    }
    length = RSA_private_decrypt( length, buffer, plain,
                                  rsaKey, RSA_PKCS1_PADDING );
    if ( length <= 0 )
    {
        printf( "RSA decrypt failed\n" );
        return( 0 );
    }
    if ( length != keylen )
    {
        memset( plain, 0, sizeof( plain ) );
        printf( "Invalid length of AES key: %d (%d expected)\n",
                length, keylen );
        return( 0 );
    }
    memcpy( aesKey, plain, keylen );
    memset( plain, 0, sizeof( plain ) );

    if ( verbose )
    {
        printf( "Sess Key: \n" );
        hexDump( aesKey, keylen );
    }

    if ( useECB )
        cipher = aes256 ? cipherAes256ECB : cipherAes128ECB;
    else
    {
        int ivLen;

        cipher = aes256 ? cipherAes256CBC : cipherAes128CBC;
        ivLen = EVP_CIPHER_iv_length( cipher );
        if ( ivLen != EVP_CIPHER_block_size( cipher ) )
            printf( "AES CBC IV length (%d) different than block size (%d)\n",
                    ivLen, EVP_CIPHER_block_size( cipher ) );
        if ( ivLen > (int)sizeof( cbcIV ) )
        {
            printf( "AES CBC IV requires %d bytes (%d provided)\n",
                    ivLen, (int)sizeof( cbcIV ) );
            return( 0 );
        }
        memset( cbcIV, 0, ivLen );
        iv = cbcIV;
    }

    /*
    ** Initialize each context exactly once. EVP_EncryptInit()/EVP_DecryptInit()
    ** reset the context first, so a second call would discard (and leak) the
    ** state allocated by the first one.
    */
    EVP_CIPHER_CTX_init( &encCtx );
    EVP_CIPHER_CTX_init( &decCtx );
    if ( ! EVP_EncryptInit( &encCtx, cipher, aesKey, iv ) ||
         ! EVP_DecryptInit( &decCtx, cipher, aesKey, iv ) )
    {
        printf( "Cipher initialization failed\n" );
        return( 0 );
    }

    if ( verbose )
    {
        printf( "Key Block Size: %d\n", EVP_CIPHER_block_size( cipher ) );
        printf( "Key Mode: 0x%x\n", (int)EVP_CIPHER_CTX_mode( &encCtx ) );
        if ( ! useECB )
        {
            printf( "Key IV: \n" );
            hexDump( encCtx.iv, EVP_CIPHER_CTX_iv_length( &encCtx ) );
        }
    }
    return( 1 );
}
/*
** Encrypt 'length' bytes of 'buffer' with the session context encCtx.
** 'size' is the capacity of 'buffer' and must leave room for up to one extra
** cipher block of padding. The ciphertext replaces the plaintext in 'buffer'.
** Returns the ciphertext length, or 0 on failure.
*/
static int
encode( int length, BYTE *buffer, int size )
{
#ifndef CRYPT_IN_PLACE
    BYTE buff[ 1024 ];
#endif
    BYTE *ptr = buffer;
    int olen = 0;
    int chunk = 0;
    int block_size = EVP_CIPHER_CTX_block_size( &encCtx );

    if ( (length + block_size) > size )
    {
        printf( "Encryption requires %d bytes\n", length + block_size );
        return( 0 );
    }
#ifndef CRYPT_IN_PLACE
    ptr = buff;
    if ( (length + block_size) > (int)sizeof( buff ) )
    {
        printf( "Encryption requires %d bytes\n", length + block_size );
        return( 0 );
    }
#endif
    /* Check both steps: output of a failed update/final must not be sent. */
    if ( ! EVP_EncryptUpdate( &encCtx, ptr, &chunk, buffer, length ) )
    {
        printf( "Encryption failed\n" );
        return( 0 );
    }
    olen = chunk;
    if ( ! EVP_EncryptFinal( &encCtx, ptr + olen, &chunk ) )
    {
        printf( "Encryption failed\n" );
        return( 0 );
    }
    olen += chunk;
#ifndef CRYPT_IN_PLACE
    memcpy( buffer, buff, olen );
#endif
    return( olen );
}
/*
** Decrypt 'length' bytes of 'buffer' with the session context decCtx.
** The plaintext replaces the ciphertext in 'buffer' (capacity 'size').
** Returns the plaintext length, or 0 on failure. Failure includes a padding
** check failure, for example from a wrong key or corrupted data.
*/
static int
decode( int length, BYTE *buffer, int size )
{
#ifndef CRYPT_IN_PLACE
    BYTE buff[ 1024 ];
#endif
    BYTE *ptr = buffer;
    int olen = 0;
    int chunk = 0;

    if ( length < 0 || length > size )
    {
        printf( "Decryption requires %d bytes\n", length );
        return( 0 );
    }
#ifndef CRYPT_IN_PLACE
    ptr = buff;
    if ( length > (int)sizeof( buff ) )
    {
        printf( "Decryption requires %d bytes\n", length );
        return( 0 );
    }
#endif
    if ( ! EVP_DecryptUpdate( &decCtx, ptr, &chunk, buffer, length ) )
    {
        printf( "Decryption failed\n" );
        return( 0 );
    }
    olen = chunk;
    /* Final verifies the padding; on failure the plaintext is garbage. */
    if ( ! EVP_DecryptFinal( &decCtx, ptr + olen, &chunk ) )
    {
        printf( "Decryption failed\n" );
        return( 0 );
    }
    olen += chunk;
#ifndef CRYPT_IN_PLACE
    memcpy( buffer, buff, olen );
#endif
    return( olen );
}
/*
** Write all 'len' bytes of 'buff' to 'sock', looping over partial sends.
** Returns 1 when everything was sent, or 0 (after printing errno) on error.
*/
static int
sendBytes( int sock, BYTE *buff, int len )
{
    BYTE *cursor = buff;
    int remaining = len;

    for( ; remaining != 0; )
    {
        int sent = send( sock, cursor, remaining, 0 );

        if ( sent < 0 )
        {
            printf( "Error sending data: %d\n", errno );
            return( 0 );
        }
        cursor += sent;
        remaining -= sent;
    }
    return( 1 );
}
/*
** Perform a single recv() of at most 'size' bytes into 'buff'.
** On return *len holds the byte count, where 0 means the peer closed the
** connection. Returns 1 on success, or 0 (after printing errno) on error.
*/
static int
receiveBytes( int sock, BYTE *buff, int size, int *len )
{
    int got;

    *len = 0;
    got = recv( sock, buff, size, 0 );
    if ( got >= 0 )
    {
        *len = got;
        return( 1 );
    }
    printf( "Error receiving data: %d\n", errno );
    return( 0 );
}
/*
** Print 'length' bytes from 'buffer' as a hex dump, 16 bytes per line.
** Each line shows the hex offset, the bytes in hex (with ':' after the
** eighth), then their ASCII form with non-printable bytes shown as '.'.
*/
static void
hexDump( BYTE *buffer, int length )
{
    static const char digits[] = "0123456789ABCDEF";
    char line[ 100 ];
    int idx = 0;

    /* '> 0' rather than '!= 0': a negative length must not loop forever. */
    while( length > 0 )
    {
        int cnt = (length > 16) ? 16 : length;
        char *ptr;
        int i;

        /*
        ** Continue after the actual prefix. "%4.4x" grows past four digits
        ** for offsets above 0xFFFF, so a fixed column would overwrite part
        ** of it.
        */
        ptr = line + sprintf( line, "%4.4x: ", (unsigned)idx );
        for( i = 0; i < cnt; i++ )
        {
            *ptr++ = digits[ (buffer[i] >> 4) & 0x0f ];
            *ptr++ = digits[ buffer[i] & 0x0f ];
            *ptr++ = (i == 7) ? ':' : ' ';
        }
        for( ; i < 16; i++ )
        {
            *ptr++ = ' ';
            *ptr++ = ' ';
            *ptr++ = ' ';
        }
        *ptr++ = ' ';
        for( i = 0; i < cnt; i++ )
            *ptr++ = ( buffer[i] < 32 || buffer[i] > 126 ) ? '.' : (char)buffer[i];
        *ptr = 0;
        printf( "%s\n", line );

        length -= cnt;
        buffer += cnt;
        idx += cnt;
    }
}
/*
** Demo Public Key OpenSSL Echo Server
**
** Establishes a listen port. Accepts an encoded RSA Public
** Key, generates an AES session key and returns it encoded
** with the public key. Reads client text lines encoded with
** session key and responds with server header followed by
** client request between lines 'BEGIN' and 'END' (inclusive),
** all encoded with the session key.
**
** Arguments:
** -p <port> Listen port number (default 16903)
** -e AES Mode ECB
** -v Verbose
** -h Help
*/
#include <stdio.h>
#include <stdlib.h>
#include <errno.h>
#include <sys/types.h>
#include <sys/socket.h>
#include <sys/wait.h>
#include <netinet/in.h>
#include <netinet/tcp.h>
#include <netdb.h>
#include <unistd.h>
#include <string.h>
#include <openssl/bn.h>
#include <openssl/rand.h>
#include <openssl/rsa.h>
#include <openssl/evp.h>
/*
** CRYPT_IN_PLACE: when defined, encode()/decode() run the cipher
** directly over the caller's buffer instead of through a local copy.
*/
#define CRYPT_IN_PLACE
/* Default listen port for the -p option. */
#define SSL_DFLT_PORT 16903
/* OpenSSL cipher names passed to EVP_get_cipherbyname(). */
#define AES128_ECB "aes-128-ecb"
#define AES128_CBC "aes-128-cbc"
#define AES256_ECB "aes-256-ecb"
#define AES256_CBC "aes-256-cbc"
/* Raw byte type used for protocol and key buffers. */
typedef unsigned char BYTE;
/* Header line that service() puts before the echoed client lines. */
static char *RESPONSE_TEMPLATE =
"SSL Echo Server: openssl\r\n";
extern char *optarg;
/* -e: use ECB instead of CBC for the session cipher. */
static int useECB = 0;
/* -v: hex-dump protocol messages and key material. */
static int verbose = 0;
/* Client's RSA public key, parsed in decodePK(). */
static RSA *rsaKey = NULL;
/*
** Non-zero to use a 256-bit AES session key. Cleared when AES-256 cannot
** be loaded locally or when the client requests a 128-bit key.
*/
static int aes256 = 1;
/* AES session key, generated with RAND_bytes() in encodeSK(). */
static BYTE aesKey[ EVP_MAX_KEY_LENGTH ];
/* CBC initialization vector; encodeSK() fills it with zeros. */
static BYTE cbcIV[ EVP_MAX_IV_LENGTH ];
/* Cipher implementations looked up by initCrypt() (only the selected mode's pair). */
static const EVP_CIPHER *cipherAes128ECB = NULL;
static const EVP_CIPHER *cipherAes128CBC = NULL;
static const EVP_CIPHER *cipherAes256ECB = NULL;
static const EVP_CIPHER *cipherAes256CBC = NULL;
/* Session encryption/decryption contexts, initialized in encodeSK(). */
static EVP_CIPHER_CTX encCtx;
static EVP_CIPHER_CTX decCtx;
/* Forward declarations; definitions follow main(). */
static int errExit( char * );
static int tcpListen( int );
static int initSession( int );
static void service( int );
static int initCrypt();
static void termCrypt();
static int decodePK( int, BYTE * );
static int encodeSK( int, BYTE * );
static int encode( int, BYTE *, int );
static int decode( int, BYTE *, int );
static int sendBytes( int, BYTE *, int );
static int receiveBytes( int, BYTE *, int, int * );
static void hexDump( BYTE *, int );
/*
** Entry point: listen on the given port and fork one child per accepted
** connection. Each child performs the RSA/AES key exchange, serves one echo
** request and exits, while the parent keeps accepting.
**
** Options: -p <port>, -e (ECB mode), -v (verbose), -h (help).
*/
int
main( int argc, char **argv )
{
    int c, sock_s;
    int port = SSL_DFLT_PORT;

    while( (c = getopt( argc, argv, "ehp:v" )) != -1 )
    {
        switch( c )
        {
        case 'h':
            printf( "-p <port>\tListen port number (default %d)\n", SSL_DFLT_PORT );
            printf( "-e\t\tAES Mode ECB\n" );
            printf( "-v\t\tVerbose\n" );
            exit(0);
        case 'p': /* Port */
        {
            /* strtol() rather than atoi(): reject junk, negatives and > 65535. */
            char *endp;
            long val;

            errno = 0;
            val = strtol( optarg, &endp, 10 );
            if ( errno || endp == optarg || *endp != '\0' || val < 1 || val > 65535 )
                errExit( "Invalid port specified" );
            port = (int)val;
            break;
        }
        case 'e': useECB = 1; break;
        case 'v': verbose = 1; break;
        default:
            /* getopt() has already reported the unknown option or missing value. */
            fprintf( stderr, "Use -h for help\n" );
            exit(1);
        }
    }

    sock_s = tcpListen( port );
    while( 1 )
    {
        int sock_c;
        pid_t pid;

        /* Reap finished children so they do not accumulate as zombies. */
        while( waitpid( -1, NULL, WNOHANG ) > 0 )
            ;

        if ( (sock_c = accept( sock_s, 0, 0 )) < 0 )
            errExit( "Problem accepting" );

        pid = fork();
        if ( pid < 0 )
        {
            /* No handler process: drop this client but keep serving. */
            printf( "Problem forking: %d\n", errno );
            close( sock_c );
        }
        else if ( pid > 0 )
        {
            /* Parent: the child now handles this connection. */
            close( sock_c );
        }
        else
        {
            /* Child: it never accepts, so release the listening socket. */
            close( sock_s );
            if ( initCrypt() && initSession( sock_c ) )
                service( sock_c );
            termCrypt();
            shutdown( sock_c, 2 );
            close( sock_c );
            exit(0);
        }
    }
    exit(0);
}
/*
** Print a fatal error message to stderr and terminate the process.
** Never returns. Exits with EXIT_FAILURE so that callers and scripts can
** tell the failure apart from a successful run.
*/
static int
errExit( char *str )
{
    fprintf( stderr, "%s\n", str );
    exit( EXIT_FAILURE );
}
/*
** Create a TCP socket listening on 'port' on all local IPv4 addresses.
** Every failure is fatal (reported through errExit()).
** Returns the listening socket descriptor.
*/
static int
tcpListen( int port )
{
    struct sockaddr_in sin;
    int sock;
    int val = 1;

    if ( (sock = socket( AF_INET, SOCK_STREAM, 0 )) < 0 )
        errExit( "Couldn't create socket" );
    memset( &sin, 0, sizeof( sin ) );
    sin.sin_addr.s_addr = htonl( INADDR_ANY );
    sin.sin_family = AF_INET;
    sin.sin_port = htons( port );
    /* Best effort: allows a quick restart while old connections sit in TIME_WAIT. */
    setsockopt( sock, SOL_SOCKET, SO_REUSEADDR, &val, sizeof( val ) );
    if ( bind( sock, (struct sockaddr *)&sin, sizeof( sin ) ) < 0 )
        errExit( "Couldn't bind socket to port" );
    /* Without this check, a listen() failure would only show up as accept() errors. */
    if ( listen( sock, 5 ) < 0 )
        errExit( "Couldn't listen on socket" );
    return( sock );
}
static int
initSession( int sock )
{
BYTE buffer[ 1024 ];
int length;
if ( ! receiveBytes( sock, buffer, sizeof( buffer ), &length ) )
return( 0 );
if ( ! length )
{
printf( "No client initiation request\n" );
return( 0 );
}
if ( verbose )
{
printf( "Session Initiation: \n" );
hexDump( buffer, length );
}
switch( buffer[ 0 ] )
{
case 0 : aes256 = 0; break;
case 1 : break;
default:
printf( "Invalid AES key length requested: %d\n", buffer[0] );
return( 0 );
}
if ( ! decodePK( length - 1, &buffer[1] ) )
return( 0 );
buffer[ 0 ] = aes256 ? 1 : 0;
if ( ! (length = encodeSK( sizeof( buffer ) - 1, &buffer[1] )) )
return( 0 );
if ( verbose )
{
printf( "Session Response: \n" );
hexDump( buffer, length + 1 );
}
if ( ! sendBytes( sock, buffer, length + 1 ) )
return( 0 );
return( 1 );
}
static void
service( int sock )
{
BYTE ibuff[ 1024 ];
BYTE obuff[ 1024 ];
BYTE *ptr, *end;
int ilen, olen;
int echo = 0;
if ( ! receiveBytes( sock, ibuff, sizeof( ibuff ), &ilen ) )
return;
if ( ! ilen )
{
printf( "No client request\n" );
return;
}
if ( verbose )
{
printf( "Received %d bytes: \n", ilen );
hexDump( ibuff, ilen );
}
if ( ! (ilen = decode( ilen, ibuff, sizeof( ibuff ) )) ) return;
ptr = ibuff;
end = ptr + ilen;
*end = 0;
printf( "Request: %s\n", ibuff );
while( ptr < end )
{
BYTE *eol, save;
for( eol = ptr; eol < end && *eol != '\n'; eol++ ) /* DO NOTHING */;
eol++;
save = *eol;
*eol = 0;
if ( ! echo )
{
/* Echo client text starting with 'BEGIN' */
if ( ! strcmp( ptr, "BEGIN\r\n" ) || ! strcmp( ptr , "BEGIN\n" ) )
{
/* Initiate server response */
sprintf( obuff, RESPONSE_TEMPLATE );
echo = 1;
}
}
if ( echo )
{
/* Echo client text back to client */
strcat( obuff, ptr );
/* Echo ends when 'END' is received */
if ( ! strcmp( ptr, "END\r\n" ) || ! strcmp( ptr, "END\n" ) )
break;
}
*eol = save;
ptr = eol;
}
printf( "Response: %s\n", obuff );
olen = strlen( obuff );
if ( ! (olen = encode( olen, obuff, sizeof( obuff ) )) ) return;
if ( verbose )
{
printf( "Sending %d bytes: \n", olen );
hexDump( obuff, olen );
}
sendBytes( sock, obuff, olen );
}
/*
** Load the AES cipher implementations for the selected mode (ECB or CBC).
** AES-256 is preferred. If it cannot be loaded, aes256 is cleared and the
** 128-bit cipher is used instead. Returns 1 on success, 0 on failure.
*/
static int
initCrypt()
{
    const EVP_CIPHER *cipher;

    /*
    ** TODO: Initialize random number generator.
    ** This does get done automatically in some cases.
    */
    OpenSSL_add_all_ciphers();
    if ( useECB )
    {
        if ( ! (cipherAes128ECB = EVP_get_cipherbyname( AES128_ECB )) )
        {
            printf( "Can't load AES 128 ECB\n" );
            return( 0 );
        }
        if ( (cipherAes256ECB = EVP_get_cipherbyname( AES256_ECB )) )
            cipher = cipherAes256ECB;
        else
        {
            printf( "Can't load AES 256 ECB\n" );
            cipher = cipherAes128ECB;
            aes256 = 0;
        }
    }
    else
    {
        if ( ! (cipherAes128CBC = EVP_get_cipherbyname( AES128_CBC )) )
        {
            printf( "Can't load AES 128 CBC\n" );
            return( 0 );
        }
        if ( (cipherAes256CBC = EVP_get_cipherbyname( AES256_CBC )) )
            cipher = cipherAes256CBC;
        else
        {
            printf( "Can't load AES 256 CBC\n" );
            cipher = cipherAes128CBC;
            aes256 = 0;
        }
    }

    /* TODO: determine max RSA key length */
    if ( verbose )
    {
        printf( "AES Max Key Length: %d\n",
                EVP_CIPHER_key_length( cipher ) * 8 );
    }
    return( 1 );
}
static void
termCrypt()
{
if ( rsaKey ) RSA_free( rsaKey );
}
/*
** Parse the client's DER-encoded RSA public key ('length' bytes at
** 'buffer') into rsaKey. With -v, also prints its exponent and modulus.
** Returns 1 on success, 0 on failure.
*/
static int
decodePK( int length, BYTE *buffer )
{
    const BYTE *cursor = buffer;
    BYTE modulus[ 1024 ];
    int nbytes;

    rsaKey = d2i_RSAPublicKey( NULL, &cursor, length );
    if ( rsaKey == NULL )
    {
        printf( "RSA decode key failed\n" );
        return( 0 );
    }
    if ( ! verbose )
        return( 1 );

    nbytes = BN_num_bytes( rsaKey->n );
    if ( nbytes > sizeof( modulus ) )
    {
        printf( "RSA modulus requires %d bytes\n", nbytes );
        return( 0 );
    }
    printf( "Xchg Exp: %d\n", (int)BN_get_word( rsaKey->e ) );
    printf( "Xchg Mod: \n" );
    nbytes = BN_bn2bin( rsaKey->n, modulus );
    hexDump( modulus, nbytes );
    return( 1 );
}
/*
** Generate a random AES session key and initialize the session cipher
** contexts with it. Then write the key, encrypted under the client's RSA
** public key with PKCS#1 v1.5 padding, into 'buffer' (capacity 'size').
** In CBC mode the IV is all zeros.
** Returns the number of encrypted bytes written, or 0 on failure.
*/
static int
encodeSK( int size, BYTE *buffer )
{
    const EVP_CIPHER *cipher;
    int keylen = (aes256 ? 256 : 128) / 8;
    BYTE *iv = NULL;
    int len;

    /* RAND_bytes() returns 1 on success; both 0 and -1 mean failure. */
    if ( RAND_bytes( aesKey, keylen ) != 1 )
    {
        printf( "Random number generator not initialized\n" );
        return( 0 );
    }
    if ( verbose )
    {
        printf( "Sess Key: \n" );
        hexDump( aesKey, keylen );
        /* TODO: key info => Block size, Mode, Padding, IV */
    }

    if ( useECB )
        cipher = aes256 ? cipherAes256ECB : cipherAes128ECB;
    else
    {
        int ivLen;

        cipher = aes256 ? cipherAes256CBC : cipherAes128CBC;
        ivLen = EVP_CIPHER_iv_length( cipher );
        if ( ivLen != EVP_CIPHER_block_size( cipher ) )
            printf( "AES CBC IV length (%d) different than block size (%d)\n",
                    ivLen, EVP_CIPHER_block_size( cipher ) );
        if ( ivLen > (int)sizeof( cbcIV ) )
        {
            printf( "AES CBC IV requires %d bytes (%d provided)\n",
                    ivLen, (int)sizeof( cbcIV ) );
            return( 0 );
        }
        memset( cbcIV, 0, ivLen );
        iv = cbcIV;
    }

    EVP_CIPHER_CTX_init( &encCtx );
    EVP_CIPHER_CTX_init( &decCtx );
    if ( ! EVP_EncryptInit( &encCtx, cipher, aesKey, iv ) ||
         ! EVP_DecryptInit( &decCtx, cipher, aesKey, iv ) )
    {
        printf( "Cipher initialization failed\n" );
        return( 0 );
    }
    if ( verbose )
    {
        printf( "Key Block Size: %d\n", EVP_CIPHER_block_size( cipher ) );
        printf( "Key Mode: 0x%x\n", (int)EVP_CIPHER_CTX_mode( &encCtx ) );
        if ( ! useECB )
        {
            printf( "Key IV: \n" );
            hexDump( encCtx.iv, EVP_CIPHER_CTX_iv_length( &encCtx ) );
        }
    }

    /* RSA_public_encrypt() always writes RSA_size() bytes. */
    if ( RSA_size( rsaKey ) > size )
    {
        printf( "RSA encrypt requires %d bytes\n", RSA_size( rsaKey ) );
        return( 0 );
    }
    len = RSA_public_encrypt( keylen, aesKey, buffer,
                              rsaKey, RSA_PKCS1_PADDING );
    if ( len <= 0 )
    {
        printf( "RSA encryption error\n" );
        return( 0 );
    }
    return( len );
}
/*
** Encrypt 'length' bytes of 'buffer' with the session context encCtx.
** 'size' is the capacity of 'buffer' and must leave room for up to one extra
** cipher block of padding. The ciphertext replaces the plaintext in 'buffer'.
** Returns the ciphertext length, or 0 on failure.
*/
static int
encode( int length, BYTE *buffer, int size )
{
#ifndef CRYPT_IN_PLACE
    BYTE buff[ 1024 ];
#endif
    BYTE *ptr = buffer;
    int olen = 0;
    int chunk = 0;
    int block_size = EVP_CIPHER_CTX_block_size( &encCtx );

    if ( (length + block_size) > size )
    {
        printf( "Encryption requires %d bytes\n", length + block_size );
        return( 0 );
    }
#ifndef CRYPT_IN_PLACE
    ptr = buff;
    if ( (length + block_size) > (int)sizeof( buff ) )
    {
        printf( "Encryption requires %d bytes\n", length + block_size );
        return( 0 );
    }
#endif
    /* Check both steps: output of a failed update/final must not be sent. */
    if ( ! EVP_EncryptUpdate( &encCtx, ptr, &chunk, buffer, length ) )
    {
        printf( "Encryption failed\n" );
        return( 0 );
    }
    olen = chunk;
    if ( ! EVP_EncryptFinal( &encCtx, ptr + olen, &chunk ) )
    {
        printf( "Encryption failed\n" );
        return( 0 );
    }
    olen += chunk;
#ifndef CRYPT_IN_PLACE
    memcpy( buffer, buff, olen );
#endif
    return( olen );
}
/*
** Decrypt 'length' bytes of 'buffer' with the session context decCtx.
** The plaintext replaces the ciphertext in 'buffer' (capacity 'size').
** Returns the plaintext length, or 0 on failure. Failure includes a padding
** check failure, for example from a wrong key or corrupted data.
*/
static int
decode( int length, BYTE *buffer, int size )
{
#ifndef CRYPT_IN_PLACE
    BYTE buff[ 1024 ];
#endif
    BYTE *ptr = buffer;
    int olen = 0;
    int chunk = 0;

    if ( length < 0 || length > size )
    {
        printf( "Decryption requires %d bytes\n", length );
        return( 0 );
    }
#ifndef CRYPT_IN_PLACE
    ptr = buff;
    if ( length > (int)sizeof( buff ) )
    {
        printf( "Decryption requires %d bytes\n", length );
        return( 0 );
    }
#endif
    if ( ! EVP_DecryptUpdate( &decCtx, ptr, &chunk, buffer, length ) )
    {
        printf( "Decryption failed\n" );
        return( 0 );
    }
    olen = chunk;
    /* Final verifies the padding; on failure the plaintext is garbage. */
    if ( ! EVP_DecryptFinal( &decCtx, ptr + olen, &chunk ) )
    {
        printf( "Decryption failed\n" );
        return( 0 );
    }
    olen += chunk;
#ifndef CRYPT_IN_PLACE
    memcpy( buffer, buff, olen );
#endif
    return( olen );
}
/*
** Write all 'len' bytes of 'buff' to 'sock', looping over partial sends.
** Returns 1 when everything was sent, or 0 (after printing errno) on error.
*/
static int
sendBytes( int sock, BYTE *buff, int len )
{
    BYTE *cursor = buff;
    int remaining = len;

    for( ; remaining != 0; )
    {
        int sent = send( sock, cursor, remaining, 0 );

        if ( sent < 0 )
        {
            printf( "Error sending data: %d\n", errno );
            return( 0 );
        }
        cursor += sent;
        remaining -= sent;
    }
    return( 1 );
}
/*
** Perform a single recv() of at most 'size' bytes into 'buff'.
** On return *len holds the byte count, where 0 means the peer closed the
** connection. Returns 1 on success, or 0 (after printing errno) on error.
*/
static int
receiveBytes( int sock, BYTE *buff, int size, int *len )
{
    int got;

    *len = 0;
    got = recv( sock, buff, size, 0 );
    if ( got >= 0 )
    {
        *len = got;
        return( 1 );
    }
    printf( "Error receiving data: %d\n", errno );
    return( 0 );
}
/*
** Print 'length' bytes from 'buffer' as a hex dump, 16 bytes per line.
** Each line shows the hex offset, the bytes in hex (with ':' after the
** eighth), then their ASCII form with non-printable bytes shown as '.'.
*/
static void
hexDump( BYTE *buffer, int length )
{
    static const char digits[] = "0123456789ABCDEF";
    char line[ 100 ];
    int idx = 0;

    /* '> 0' rather than '!= 0': a negative length must not loop forever. */
    while( length > 0 )
    {
        int cnt = (length > 16) ? 16 : length;
        char *ptr;
        int i;

        /*
        ** Continue after the actual prefix. "%4.4x" grows past four digits
        ** for offsets above 0xFFFF, so a fixed column would overwrite part
        ** of it.
        */
        ptr = line + sprintf( line, "%4.4x: ", (unsigned)idx );
        for( i = 0; i < cnt; i++ )
        {
            *ptr++ = digits[ (buffer[i] >> 4) & 0x0f ];
            *ptr++ = digits[ buffer[i] & 0x0f ];
            *ptr++ = (i == 7) ? ':' : ' ';
        }
        for( ; i < 16; i++ )
        {
            *ptr++ = ' ';
            *ptr++ = ' ';
            *ptr++ = ' ';
        }
        *ptr++ = ' ';
        for( i = 0; i < cnt; i++ )
            *ptr++ = ( buffer[i] < 32 || buffer[i] > 126 ) ? '.' : (char)buffer[i];
        *ptr = 0;
        printf( "%s\n", line );

        length -= cnt;
        buffer += cnt;
        idx += cnt;
    }
}
/*
** Demo Public Key Windows Echo Client
**
** Connects to a PK Echo Server. Generates RSA Key and
** sends encoded public key to server. Response is a
** session key encoded with the public key. Sends client
** request, encoded with session key, consisting of
** header and text to be echo'd between lines 'BEGIN'
** and 'END' (inclusive). Reads server response, also
** encoded with session key, consisting of header and
** displays echo'd text between lines 'BEGIN' and 'END'
** (exclusive).
**
** Arguments:
** -t <host> Target host name (default 'localhost')
** -p <port> Target port number (default 16903)
** -e AES Mode ECB
** -v Verbose
** -h Help
**
*/
#include <windows.h>
#include <winsock.h>
#include <stdio.h>
#include <stdlib.h>
#include <wincrypt.h>
#pragma comment(lib, "ws2_32")
#pragma comment(lib, "advapi32")
#pragma comment(lib, "crypt32")
/* Defaults for the -t and -p options. */
#define DEFAULT_HOST "localhost"
#define DEFAULT_PORT 16903
/* CryptoAPI provider type passed to CryptAcquireContext() in initCrypt(). */
static DWORD ProvType = PROV_RSA_AES;
/* Algorithm and key-size flags for the key-exchange key made by CryptGenKey(). */
static DWORD XchgKeyType = CALG_RSA_KEYX;
static DWORD XchgKeyFlags = RSA1024BIT_KEY;
/* Session key algorithm ids used when importing the server's key blob. */
static DWORD SessKey128 = CALG_AES_128;
static DWORD SessKey256 = CALG_AES_256;
/* Session cipher mode (switched to ECB by -e) and padding, set via KP_MODE/KP_PADDING. */
static DWORD sessKeyMode = CRYPT_MODE_CBC;
static DWORD sessKeyPadding = PKCS5_PADDING;
/*
** Plaintext request, formatted with the target host and port.
** The text between the BEGIN and END lines is what gets echoed.
*/
static char *REQUEST_TEMPLATE =
"PK Echo Client: Windows CryptoAPI\r\n"
"BEGIN\r\n"
"PK Client/Server Echo Test\r\n"
"Host: %s:%d\r\n"
"END\r\n";
/* -v: hex-dump protocol messages and key material. */
static BOOL verbose = FALSE;
/* Provider context and keys; released by termCrypt(). */
static HCRYPTPROV provider = 0;
static HCRYPTKEY xchgKey = 0;
static HCRYPTKEY sessKey = 0;
/* Largest RSA key-exchange and AES key lengths reported by the provider (PP_ENUMALGS_EX). */
static DWORD rsaMaxKey = 0;
static DWORD aesMaxKey = 0;
/* Forward declarations; definitions follow main(). */
static SOCKET tcpConnect( char *, u_short );
static BOOL initSession( SOCKET );
static BOOL clientRequest( SOCKET, char *, u_short );
static BOOL initCrypt();
static void termCrypt();
static DWORD encodePK( DWORD, BYTE * );
static BOOL decodeSK( BOOL, DWORD, BYTE * );
static DWORD encode( DWORD, BYTE *, DWORD );
static DWORD decode( DWORD, BYTE *, DWORD );
static void keyInfo( HCRYPTKEY );
static BOOL sendBytes( SOCKET, BYTE *, DWORD );
static BOOL receiveBytes( SOCKET, BYTE *, DWORD, DWORD * );
static void hexDump( DWORD, BYTE * );
void
main( int argc, char **argv )
{
SOCKET sock;
int arg;
char *host = DEFAULT_HOST;
u_short port = (u_short)DEFAULT_PORT;
for( arg = 1; arg < argc; arg++ )
{
if ( argv[ arg ][0] != '-' )
{
fprintf( stderr, "Invalid command line argument: %s\n", argv[ arg ] );
exit(1);
}
switch( argv[ arg ][1] )
{
case 'h':
printf( "-t <host>\tTarget host name (default 'localhost')\n" );
printf( "-p <port>\tTarget port number (default 16903)\n" );
printf( "-e\t\tAES Mode ECB\n" );
printf( "-v\t\tVerbose\n" );
exit(0);
case 't' : host = argv[ ++arg ]; break;
case 'p' : port = atoi( argv[ ++arg ] ); break;
case 'e' : sessKeyMode = CRYPT_MODE_ECB; break;
case 'v' : verbose = TRUE; break;
default:
fprintf( stderr, "Invalid command line flag: %s\n", argv[ arg ] );
exit(1);
}
}
if ( ! initCrypt() ) exit(1);
sock = tcpConnect( host, port );
if ( ! initSession( sock ) ) exit(1);
if ( ! clientRequest( sock, host, port ) ) exit(1);
shutdown( sock, 2 );
closesocket( sock );
if ( WSACleanup() == SOCKET_ERROR )
fprintf( stderr, "Problem with socket cleanup\n" );
termCrypt();
exit(0);
}
/*
** Start Winsock, resolve 'host' (dotted-quad literal or host name) and
** connect a TCP socket to it on 'port'. Every failure is fatal (exit(1)).
** Returns the connected socket.
*/
static SOCKET
tcpConnect( char *host, u_short port )
{
    WSADATA wsa;
    SOCKADDR_IN peer;
    SOCKET s;
    unsigned long ip;

    if ( WSAStartup( 0x0101, &wsa ) != 0 )
    {
        fprintf( stderr, "Could not initialize winsock\n" );
        exit(1);
    }

    /* Try the string as a numeric address first, then fall back to DNS. */
    ip = inet_addr( host );
    if ( ip == INADDR_NONE )
    {
        struct hostent *he = gethostbyname( host );

        if ( he == NULL )
        {
            fprintf( stderr, "Unable to resolve host name: '%s'\n", host );
            exit(1);
        }
        memcpy( (char FAR *)&ip, he->h_addr, he->h_length );
    }

    s = socket( PF_INET, SOCK_STREAM, 0 );
    if ( s == INVALID_SOCKET )
    {
        fprintf( stderr, "Unable to create socket\n" );
        exit(1);
    }

    peer.sin_family = AF_INET;
    peer.sin_addr.s_addr = ip;
    peer.sin_port = htons( port );
    if ( connect( s, (LPSOCKADDR)&peer, sizeof( peer ) ) != 0 )
    {
        closesocket( s );
        fprintf( stderr, "Connect failed\n" );
        exit(1);
    }
    return( s );
}
/*
** Client side of the key-exchange handshake.
**
** Sends one byte requesting a 256-bit (1) or 128-bit (0) AES key, followed
** by our encoded RSA public key. 256 bits is requested only when the
** provider reports AES keys of at least 256 bits. The reply is a one-byte
** key-size flag plus the encrypted session key, which decodeSK() imports.
** Returns TRUE on success, FALSE otherwise.
*/
static BOOL
initSession( SOCKET sock )
{
    BYTE msg[ 1024 ];
    DWORD n;
    BOOL want256 = (aesMaxKey >= 256);

    msg[ 0 ] = (BYTE)( want256 ? 1 : 0 );
    n = encodePK( sizeof( msg ) - 1, &msg[ 1 ] );
    if ( n == 0 )
        return( FALSE );

    if ( verbose )
    {
        printf( "Session Initiation: \n" );
        hexDump( n + 1, msg );
    }

    if ( ! sendBytes( sock, msg, n + 1 ) )
        return( FALSE );
    if ( ! receiveBytes( sock, msg, sizeof( msg ), &n ) )
        return( FALSE );
    if ( n == 0 )
    {
        fprintf( stderr, "No session response from server\n" );
        return( FALSE );
    }

    if ( verbose )
    {
        printf( "Session Response: \n" );
        hexDump( n, msg );
    }

    /* The server may downgrade us to 128 bits but never upgrade past our request. */
    if ( msg[ 0 ] == 0 )
        want256 = FALSE;
    else if ( msg[ 0 ] != 1 || ! want256 )
    {
        fprintf( stderr, "Invalid AES key length returned: %d\n", msg[ 0 ] );
        return( FALSE );
    }

    return( decodeSK( want256, n - 1, &msg[ 1 ] ) ? TRUE : FALSE );
}
/*
** Send the echo request and print the echoed text.
**
** The request is REQUEST_TEMPLATE filled in with host and port, encrypted
** with the session key. One encrypted reply is received and decrypted, and
** the lines strictly between "BEGIN" and "END" are written to stdout.
** Returns FALSE on a send/receive/crypto failure. An empty reply is
** reported on stderr but still returns TRUE.
*/
static BOOL
clientRequest( SOCKET sock, char *host, u_short port )
{
    BYTE buffer[ 1024 ];
    BYTE *ptr, *end;
    DWORD length;
    BOOL echo = FALSE;

    sprintf_s( (char *)buffer, sizeof( buffer ), REQUEST_TEMPLATE, host, port );
    length = (DWORD)strlen( (char *)buffer );
    if ( ! (length = encode( length, buffer, sizeof( buffer ) )) )
        return( FALSE );
    if ( verbose )
    {
        printf( "Client Request: \n" );
        hexDump( length, buffer );
    }
    if ( ! sendBytes( sock, buffer, length ) )
        return( FALSE );

    if ( ! receiveBytes( sock, buffer, sizeof( buffer ), &length ) )
        return( FALSE );
    if ( ! length )
    {
        fprintf( stderr, "No response from server\n" );
        return( TRUE );
    }
    if ( verbose )
    {
        printf( "Server Response: \n" );
        hexDump( length, buffer );
    }
    if ( ! (length = decode( length, buffer, sizeof( buffer ) )) )
        return( FALSE );
    /* One spare byte is needed to NUL-terminate the plaintext. */
    if ( length >= sizeof( buffer ) )
    {
        fprintf( stderr, "Server response too long\n" );
        return( FALSE );
    }

    ptr = buffer;
    end = ptr + length;
    *end = 0;
    while( ptr < end )
    {
        BYTE *eol = ptr;
        BYTE save;

        /*
        ** A line ends just past '\n', or at 'end' for an unterminated last
        ** line. The cursor must never step beyond 'end'.
        */
        while( eol < end && *eol != '\n' )
            eol++;
        if ( eol < end )
            eol++;
        save = *eol;
        *eol = 0;

        /*
        ** Display response between 'BEGIN' and 'END'
        */
        if ( ! echo )
        {
            /* Echo starts with 'BEGIN' */
            if ( ! strcmp( (char *)ptr, "BEGIN\r\n" ) || ! strcmp( (char *)ptr, "BEGIN\n" ) )
                echo = TRUE;
        }
        else
        {
            /* Echo finishes with 'END' */
            if ( ! strcmp( (char *)ptr, "END\r\n" ) || ! strcmp( (char *)ptr, "END\n" ) )
                break;
            printf( "%s", (char *)ptr );
        }
        *eol = save;
        ptr = eol;
    }
    return( TRUE );
}
static BOOL
initCrypt()
{
BYTE buffer[ 1024 ];
DWORD size, flags;
if ( ! CryptAcquireContext( &provider, NULL, NULL, ProvType, CRYPT_VERIFYCONTEXT ) )
{
fprintf( stderr, "CryptAcquireContext() failed: 0x%x\n", GetLastError() );
return( FALSE );
}
if ( ! CryptGenKey( provider, XchgKeyType, XchgKeyFlags, &xchgKey ) )
{
fprintf( stderr, "CryptGenKey() failed: 0x%x\n", GetLastError() );
return( FALSE );
}
for( flags = CRYPT_FIRST; ; flags = CRYPT_NEXT )
{
PROV_ENUMALGS_EX *info = (PROV_ENUMALGS_EX *)buffer;
size = sizeof( buffer );
if ( ! CryptGetProvParam( provider, PP_ENUMALGS_EX, buffer, &size, flags ) )
{
DWORD status = GetLastError();
if ( status == ERROR_NO_MORE_ITEMS ) break;
fprintf( stderr, "CryptGetProvParam() failed: 0x%x\n", GetLastError() );
return( FALSE );
}
switch( info->aiAlgid )
{
case CALG_RSA_KEYX : rsaMaxKey = max( rsaMaxKey, info->dwMaxLen ); break;
case CALG_AES_128 :
case CALG_AES_192 :
case CALG_AES_256 : aesMaxKey = max( aesMaxKey, info->dwMaxLen ); break;
}
}
if ( verbose )
{
struct pkb
{
PUBLICKEYSTRUC hdr;
RSAPUBKEY rsa;
BYTE mod[1];
} *rsaKey= (struct pkb *)buffer;
printf( "RSA Max Key Length: %d\n", rsaMaxKey );
printf( "AES Max Key Length: %d\n", aesMaxKey );
if ( ! CryptExportKey( xchgKey, 0, PUBLICKEYBLOB, 0, NULL, &size ) )
{
fprintf( stderr, "CryptExportKey() [1] failed: 0x%x\n", GetLastError() );
return( FALSE );
}
if ( size > sizeof( buffer ) )
{
fprintf( stderr, "CryptExportKey() requires %d bytes\n", size );
return( FALSE );
}
size = sizeof( buffer );
if ( ! CryptExportKey( xchgKey, 0, PUBLICKEYBLOB, 0, (BYTE *)rsaKey, &size ) )
{
fprintf( stderr, "CryptExportKey() [2] failed: 0x%x\n", GetLastError() );
return( FALSE );
}
printf( "Xchg Exp: %d\n", rsaKey->rsa.pubexp );
printf( "Xchg Mod: \n" );
hexDump( rsaKey->rsa.bitlen / 8, rsaKey->mod );
}
if ( rsaMaxKey < 1024 || aesMaxKey < 128 )
{
fprintf( stderr, "Invalid minimum key length\n" );
return( FALSE );
}
return( TRUE );
}
/*
** Release CryptoAPI resources: the exchange key, the session key, and then
** the provider context, each only if it was obtained. Failures are reported
** to stderr but otherwise ignored.
*/
static void
termCrypt()
{
    if ( xchgKey != 0 && ! CryptDestroyKey( xchgKey ) )
        fprintf( stderr, "CryptDestroyKey() failed: 0x%x\n", GetLastError() );

    if ( sessKey != 0 && ! CryptDestroyKey( sessKey ) )
        fprintf( stderr, "CryptDestroyKey() failed: 0x%x\n", GetLastError() );

    if ( provider != 0 && ! CryptReleaseContext( provider, 0 ) )
        fprintf( stderr, "CryptReleaseContext() failed: 0x%x\n", GetLastError() );
}
/*
** Export our key-exchange public key as X.509 public key info. The encoded
** key bytes (PublicKey.pbData/cbData) are then moved to the start of
** 'buffer', whose capacity is 'size'.
** Returns the number of encoded key bytes, or 0 on failure.
*/
static DWORD
encodePK( DWORD size, BYTE *buffer )
{
    CERT_PUBLIC_KEY_INFO *pubInfo = (CERT_PUBLIC_KEY_INFO *)buffer;
    DWORD needed;
    DWORD keyBytes;

    /* Pass 1: ask how many bytes the export needs. */
    if ( ! CryptExportPublicKeyInfo( provider, AT_KEYEXCHANGE,
                                     X509_ASN_ENCODING, NULL, &needed ) )
    {
        fprintf( stderr, "CryptExportPublicKeyInfo() [1] failed: 0x%x\n", GetLastError() );
        return( 0 );
    }
    if ( needed > size )
    {
        fprintf( stderr, "CryptExportPublicKeyInfo() requires %d bytes\n", needed );
        return( 0 );
    }

    /* Pass 2: export into the caller's buffer. */
    needed = size;
    if ( ! CryptExportPublicKeyInfo( provider, AT_KEYEXCHANGE,
                                     X509_ASN_ENCODING, pubInfo, &needed ) )
    {
        fprintf( stderr, "CryptExportPublicKeyInfo() [2] failed: 0x%x\n", GetLastError() );
        return( 0 );
    }

    /*
    ** Keep only the encoded key bytes. pbData presumably points inside
    ** 'buffer' itself, hence memmove (overlapping copy) -- confirm.
    */
    keyBytes = pubInfo->PublicKey.cbData;
    memmove( buffer, pubInfo->PublicKey.pbData, keyBytes );
    return( keyBytes );
}
/*
** Import the AES session key sent by the server.
**
** keybuf/keylen hold the session key as encrypted by the server with this
** client's RSA exchange key. The bytes are copied, reversed, into a
** SIMPLEBLOB (see the NOTE below) whose algorithm is SessKey256 or
** SessKey128 depending on aes256. The blob is imported against xchgKey
** into sessKey. The key is then configured with sessKeyMode and
** sessKeyPadding, plus an all-zero IV in CBC mode. In verbose mode the
** plaintext key and its parameters are dumped.
**
** Returns TRUE on success, FALSE on any failure (after printing the error).
*/
static BOOL
decodeSK( BOOL aes256, DWORD keylen, BYTE *keybuf )
{
    BYTE buffer[ 1024 ];
    DWORD i, size = 0, length, header;
    struct skb
    {
        PUBLICKEYSTRUC hdr;
        ALG_ID algId;
        BYTE key[1];
    } *expKey = (struct skb *)buffer;
    struct ptkb
    {
        PUBLICKEYSTRUC hdr;
        DWORD keysize;
        BYTE key[1];
    } *txtKey = (struct ptkb *)buffer;

    /* Offset of the key bytes inside a SIMPLEBLOB. Compare keylen with
    ** the space left after the header instead of adding the two, which
    ** could wrap around for a huge keylen and defeat the check. */
    header = (DWORD)( expKey->key - buffer );
    if ( keylen > sizeof( buffer ) - header )
    {
        fprintf( stderr, "CryptImportKey() key too large: %lu bytes\n", (unsigned long)keylen );
        return( FALSE );
    }
    length = header + keylen;

    expKey->hdr.bType = SIMPLEBLOB;
    expKey->hdr.bVersion = CUR_BLOB_VERSION;
    expKey->hdr.reserved = 0;
    expKey->hdr.aiKeyAlg = aes256 ? SessKey256 : SessKey128;
    expKey->algId = XchgKeyType;
    /*
    ** NOTE: it appears that the encoded key is byte swapped compared
    ** to external standards. There is a cryptic reference to a
    ** ReverseMemCopy() function in the RSA/SChannel server master
    ** key creation example. Also, the internal RSA modulus is
    ** byte swapped compared to the X509 encoding. This swap is
    ** required to interoperate with JavaSSE.
    */
    for( i = 0; i < keylen; i++ )
        buffer[ header + i ] = keybuf[ keylen - i - 1 ];
    if ( ! CryptImportKey( provider, (BYTE *)expKey, length, xchgKey, CRYPT_EXPORTABLE, &sessKey ) )
    {
        fprintf( stderr, "CryptImportKey() failed: 0x%x\n", GetLastError() );
        return( FALSE );
    }
    if ( ! CryptSetKeyParam( sessKey, KP_MODE, (BYTE *)&sessKeyMode, 0 ) )
    {
        fprintf( stderr, "CryptSetKeyParam() MODE failed: 0x%x\n", GetLastError() );
        return( FALSE );
    }
    if ( ! CryptSetKeyParam( sessKey, KP_PADDING, (BYTE *)&sessKeyPadding, 0 ) )
    {
        fprintf( stderr, "CryptSetKeyParam() PADDING failed: 0x%x\n", GetLastError() );
        return( FALSE );
    }
    if ( sessKeyMode == CRYPT_MODE_CBC )
    {
        size = sizeof( length );
        if ( ! CryptGetKeyParam( sessKey, KP_BLOCKLEN, (BYTE *)&length, &size, 0 ) )
        {
            /* Report failure like every other step rather than exit()ing. */
            fprintf( stderr, "CryptGetKeyParam() BLOCKLEN failed: 0x%x\n", GetLastError() );
            return( FALSE );
        }
        length /= 8; /* Bits -> bytes */
        for( i = 0; i < length; i++ ) buffer[i] = 0;
        if ( ! CryptSetKeyParam( sessKey, KP_IV, buffer, 0 ) )
        {
            fprintf( stderr, "CryptSetKeyParam() IV failed: 0x%x\n", GetLastError() );
            return( FALSE );
        }
    }
    if ( verbose )
    {
        size = 0;
        if ( ! CryptExportKey( sessKey, 0, PLAINTEXTKEYBLOB, 0, NULL, &size ) )
        {
            fprintf( stderr, "CryptExportKey() [1] failed: 0x%x\n", GetLastError() );
            return( FALSE );
        }
        if ( size > sizeof( buffer ) )
        {
            fprintf( stderr, "CryptExportKey() requires %d bytes\n", size );
            return( FALSE );
        }
        length = sizeof( buffer );
        if ( ! CryptExportKey( sessKey, 0, PLAINTEXTKEYBLOB, 0, (BYTE *)txtKey, &length ) )
        {
            fprintf( stderr, "CryptExportKey() [2] failed: 0x%x\n", GetLastError() );
            return( FALSE );
        }
        printf( "Sess Key: \n" );
        hexDump( txtKey->keysize, txtKey->key );
        keyInfo( sessKey );
    }
    return( TRUE );
}
/*
** Encrypt the first length bytes of buffer in place with sessKey.
**
** CryptEncrypt() is first asked (NULL data pointer) how large the padded
** ciphertext will be; that size must fit in bufsiz. Returns the ciphertext
** length, or 0 after printing an error.
*/
static DWORD
encode( DWORD length, BYTE *buffer, DWORD bufsiz )
{
    DWORD needed = length;

    if ( CryptEncrypt( sessKey, 0, TRUE, 0, NULL, &needed, bufsiz ) == FALSE )
    {
        fprintf( stderr, "CryptEncrypt() [1] failed: 0x%x\n", GetLastError() );
        return( 0 );
    }
    if ( bufsiz < needed )
    {
        fprintf( stderr, "CryptEncrypt() requires %d bytes\n", needed );
        return( 0 );
    }
    if ( CryptEncrypt( sessKey, 0, TRUE, 0, buffer, &length, bufsiz ) == FALSE )
    {
        fprintf( stderr, "CryptEncrypt() [2] failed: 0x%x\n", GetLastError() );
        return( 0 );
    }
    return( length );
}
/*
** Decrypt length bytes of buffer in place with sessKey (final block).
** bufsiz is accepted for symmetry with encode() but not consulted.
** Returns the plaintext length, or 0 after printing an error.
*/
static DWORD
decode( DWORD length, BYTE *buffer, DWORD bufsiz )
{
    DWORD plainLen = length;

    if ( CryptDecrypt( sessKey, 0, TRUE, 0, buffer, &plainLen ) )
        return( plainLen );

    fprintf( stderr, "CryptDecrypt() failed: 0x%x\n", GetLastError() );
    return( 0 );
}
/*
** Print the block size (bytes), cipher mode, padding type and IV of key
** to stdout for verbose diagnostics. A failed query is fatal: the error
** is printed and the process exits.
*/
static void
keyInfo( HCRYPTKEY key )
{
    BYTE buffer[ 1024 ];
    DWORD value;
    DWORD length;

    /* DWORD-valued parameters are read into a real DWORD rather than
    ** through a cast byte array, avoiding misaligned/aliased access. */
    length = sizeof( value );
    if ( ! CryptGetKeyParam( key, KP_BLOCKLEN, (BYTE *)&value, &length, 0 ) )
    {
        fprintf( stderr, "CryptGetKeyParam BLOCKLEN failed: 0x%x\n", GetLastError() );
        exit( 1 );
    }
    printf( "Key Block Size: %lu\n", (unsigned long)( value / 8 ) );
    length = sizeof( value );
    if ( ! CryptGetKeyParam( key, KP_MODE, (BYTE *)&value, &length, 0 ) )
    {
        fprintf( stderr, "CryptGetKeyParam MODE failed: 0x%x\n", GetLastError() );
        exit( 1 );
    }
    printf( "Key Mode: %lu\n", (unsigned long)value );
    length = sizeof( value );
    if ( ! CryptGetKeyParam( key, KP_PADDING, (BYTE *)&value, &length, 0 ) )
    {
        fprintf( stderr, "CryptGetKeyParam PADDING failed: 0x%x\n", GetLastError() );
        exit( 1 );
    }
    printf( "Key Padding Type: %lu\n", (unsigned long)value );
    length = sizeof( buffer );
    if ( ! CryptGetKeyParam( key, KP_IV, buffer, &length, 0 ) )
    {
        /* This is the IV query; it was previously misreported as PADDING. */
        fprintf( stderr, "CryptGetKeyParam IV failed: 0x%x\n", GetLastError() );
        exit( 1 );
    }
    printf( "Key IV: \n" );
    hexDump( length, buffer );
}
/*
** Write all len bytes of buff to sock, looping over partial send()s.
** Returns TRUE once everything is sent, FALSE (after printing the error)
** if send() reports SOCKET_ERROR.
*/
static BOOL
sendBytes( SOCKET sock, BYTE *buff, DWORD len )
{
    BYTE *cursor = buff;
    DWORD remaining = len;

    while( remaining > 0 )
    {
        int sent = send( sock, (const char *)cursor, (int)remaining, 0 );

        if ( sent == SOCKET_ERROR )
        {
            fprintf( stderr, "send failed: %u\n", GetLastError() );
            return( FALSE );
        }
        cursor += sent;
        remaining -= (DWORD)sent;
    }
    return( TRUE );
}
/*
** Perform one recv() of at most size bytes into buff and store the byte
** count in *len (0 on failure). Returns FALSE on a socket error or when
** the peer has closed the connection.
** NOTE(review): a single recv() is treated as one whole protocol message;
** TCP may split a message, so this relies on small messages arriving intact.
*/
static BOOL
receiveBytes( SOCKET sock, BYTE *buff, DWORD size, DWORD *len )
{
    int got;

    *len = 0;
    got = recv( sock, (char *)buff, size, 0 );
    switch( got )
    {
    case SOCKET_ERROR:
        fprintf( stderr, "recv failed: %u\n", GetLastError() );
        return( FALSE );
    case 0:
        fprintf( stderr, "recv returned no data\n" );
        return( FALSE );
    default:
        *len = (DWORD)got;
        return( TRUE );
    }
}
/*
** Print length bytes from buffer as a hex dump, 16 bytes per line: the
** offset, the bytes in hex (with ':' after the eighth byte), then the
** printable ASCII rendering ('.' for bytes outside 32..126).
*/
static void
hexDump( DWORD length, BYTE *buffer )
{
    static const CHAR digits[] = "0123456789ABCDEF";
    DWORD i, count, index;
    CHAR line[ 100 ];
    int cnt;

    for( index = 0; length; length -= count, buffer += count, index += count )
    {
        count = (length > 16) ? 16 : length;
        /* Offset followed by two blanks ("0000  "). Continue from the
        ** number of characters actually written. The old fixed start of 6
        ** after a 5-character "%4.4x " left the NUL at line[5], so only
        ** the offset was ever printed. */
        cnt = sprintf_s( line, sizeof( line ), "%4.4x  ", (unsigned int)index );
        if ( cnt < 0 )
            return;
        for( i = 0; i < count; i++ )
        {
            line[cnt++] = digits[buffer[i] >> 4];
            line[cnt++] = digits[buffer[i] & 0x0f];
            line[cnt++] = (i == 7) ? ':' : ' ';
        }
        /* Pad a short final row so the ASCII column lines up. */
        for( ; i < 16; i++ )
        {
            line[cnt++] = ' ';
            line[cnt++] = ' ';
            line[cnt++] = ' ';
        }
        line[cnt++] = ' ';
        for( i = 0; i < count; i++ )
            line[cnt++] = ( buffer[i] < 32 || buffer[i] > 126 ) ? '.' : (CHAR)buffer[i];
        line[cnt] = 0;
        printf( "%s\n", line );
    }
}
/*
** Demo Public Key Windows Echo Server
**
** Establishes a listen port. Accepts an encoded RSA Public Key,
** generates an AES session key and returns it encoded with the
** RSA Public Key. Reads client text lines encoded with the
** session key and responds with server header followed by the
** client request between lines 'BEGIN' and 'END' (inclusive) -
** all encoded with the session key.
**
** Arguments:
** -p <port> Listen port number (default 16903)
** -e AES Mode ECB
** -v Verbose
** -h Help
*/
#include <windows.h>
#include <winsock.h>
#include <stdio.h>
#include <stdlib.h>
#include <wincrypt.h>
#pragma comment(lib, "ws2_32")
#pragma comment(lib, "advapi32")
#pragma comment(lib, "crypt32")
/* Default TCP listen port; overridden with -p. */
#define DEFAULT_PORT 16903
/* CryptoAPI provider type acquired by initCrypt(). */
static DWORD ProvType = PROV_RSA_AES;
/* RSA key-exchange algorithm id. NOTE(review): not referenced by the server
** functions shown here; presumably kept for parity with the client. */
static DWORD XchgKeyType = CALG_RSA_KEYX;
/* Session key algorithms passed to CryptGenKey() in encodeSK(); the one used
** depends on the key size negotiated in initSession(). */
static DWORD SessKey128 = CALG_AES_128;
static DWORD SessKey256 = CALG_AES_256;
/* Session cipher mode (CBC by default, ECB with -e) and padding, applied to
** the session key with CryptSetKeyParam(). */
static DWORD sessKeyMode = CRYPT_MODE_CBC;
static DWORD sessKeyPadding = PKCS5_PADDING;
/* Header line placed before the echoed client text in each response. */
static char *RESPONSE_TEMPLATE =
"PK Echo Server: Windows CryptoAPI\r\n";
/* Set by -v: print keys and hex dumps of the exchanged messages. */
static BOOL verbose = FALSE;
/* Per-connection CryptoAPI state: provider context (initCrypt), the client's
** RSA public key (decodePK) and the AES session key (encodeSK). main() calls
** termCrypt() to release them after each served connection. */
static HCRYPTPROV provider = 0;
static HCRYPTKEY xchgKey = 0;
static HCRYPTKEY sessKey = 0;
/* Largest RSA key-exchange and AES key lengths reported by the provider
** (PP_ENUMALGS_EX dwMaxLen), filled in by initCrypt(). */
static DWORD rsaMaxKey = 0;
static DWORD aesMaxKey = 0;
/* Forward declarations of the file-local helpers defined below. */
static SOCKET tcpListen( u_short );
static BOOL initSession( SOCKET );
static BOOL service( SOCKET );
static BOOL initCrypt();
static void termCrypt();
static BOOL decodePK( DWORD, BYTE * );
static DWORD encodeSK( BOOL, DWORD, BYTE * );
static DWORD encode( DWORD, BYTE *, DWORD );
static DWORD decode( DWORD, BYTE *, DWORD );
static void keyInfo( HCRYPTKEY );
static BOOL sendBytes( SOCKET, BYTE *, DWORD );
static BOOL receiveBytes( SOCKET, BYTE *, DWORD, DWORD * );
static void hexDump( DWORD, BYTE * );
/*
** Echo server entry point.
**
** Parses -h / -p <port> / -e / -v, then serves clients one at a time.
** For each accepted connection a fresh crypto context is acquired, the
** session key is negotiated, one request is echoed, and the connection
** and crypto state are torn down. Any per-connection failure terminates
** the process with exit status 1.
*/
int
main( int argc, char **argv )
{
    SOCKET sock_s;
    int arg;
    u_short port = (u_short)DEFAULT_PORT;

    for( arg = 1; arg < argc; arg++ )
    {
        if ( argv[ arg ][0] != '-' )
        {
            fprintf( stderr, "Invalid command line argument: %s\n", argv[ arg ] );
            exit(1);
        }
        switch( argv[ arg ][1] )
        {
        case 'h':
            printf( "-p <port>\tListen port number (default %d)\n", DEFAULT_PORT );
            printf( "-e\t\tAES Mode ECB\n" );
            printf( "-v\t\tVerbose\n" );
            exit(0);
        case 'p':
        {
            /* Validate instead of trusting atoi(). A trailing -p used to
            ** read argv[argc] (NULL), and non-numeric or out-of-range values
            ** were silently truncated into a u_short. */
            char *endp;
            long value;

            if ( arg + 1 >= argc )
            {
                fprintf( stderr, "Missing port number after %s\n", argv[ arg ] );
                exit(1);
            }
            value = strtol( argv[ ++arg ], &endp, 10 );
            if ( endp == argv[ arg ] || *endp != '\0' || value < 1 || value > 65535 )
            {
                fprintf( stderr, "Invalid port number: %s\n", argv[ arg ] );
                exit(1);
            }
            port = (u_short)value;
            break;
        }
        case 'e': sessKeyMode = CRYPT_MODE_ECB; break;
        case 'v': verbose = TRUE; break;
        default:
            fprintf( stderr, "Invalid command line flag: %s\n", argv[ arg ] );
            exit(1);
        }
    }
    sock_s = tcpListen( port );
    while( TRUE )
    {
        SOCKET sock_c;
        if ( (sock_c = accept( sock_s, NULL, NULL )) == INVALID_SOCKET )
        {
            fprintf( stderr, "accept failed: %u\n", GetLastError() );
            exit(1);
        }
        if ( ! initCrypt() ) exit(1);
        if ( ! initSession( sock_c ) ) exit(1);
        if ( ! service( sock_c ) ) exit(1);
        shutdown( sock_c, 2 );
        closesocket( sock_c );
        termCrypt();
    }
    WSACleanup();
    return( 0 );
}
/*
** Initialise Winsock (version 1.1) and return a TCP socket bound to every
** local IPv4 address on port and listening with a backlog of 1. Any
** failure is fatal: the error is printed and the process exits.
*/
static SOCKET
tcpListen( u_short port )
{
    WSADATA wsaData;
    SOCKET sock;
    SOCKADDR_IN sockIn = { 0 };   /* zeroes sin_zero as Winsock expects */
    int status;

    /* WSAStartup() returns its error code directly; WSAGetLastError()
    ** is not valid until it has succeeded. */
    status = WSAStartup( 0x0101, &wsaData );
    if ( status )
    {
        fprintf ( stderr, "Could not initialize winsock: %d\n", status );
        exit(1);
    }
    sock = socket( PF_INET, SOCK_STREAM, 0 );
    if ( sock == INVALID_SOCKET )
    {
        fprintf( stderr, "Failed to create socket: %d\n", WSAGetLastError() );
        exit(1);
    }
    sockIn.sin_family = AF_INET;
    sockIn.sin_addr.s_addr = htonl( INADDR_ANY );
    sockIn.sin_port = htons( port );
    if ( bind( sock, (LPSOCKADDR)&sockIn, sizeof( sockIn ) ) == SOCKET_ERROR )
    {
        fprintf( stderr, "bind failed: %d\n", WSAGetLastError() );
        exit(1);
    }
    if ( listen( sock, 1 ) == SOCKET_ERROR )
    {
        fprintf( stderr, "Listen failed: %d\n", WSAGetLastError() );
        exit(1);
    }
    return( sock );
}
/*
** Negotiate the session key with a newly connected client.
**
** The client's first message is one key-size byte (0 = AES-128,
** 1 = AES-256) followed by its RSA public key. The server answers with the
** key size it actually chose (256 only when the provider supports it),
** followed by the new AES session key encrypted with the client's key.
** Returns FALSE on any receive, key or send failure.
*/
static BOOL
initSession( SOCKET sock )
{
    BYTE msg[ 1024 ];
    DWORD msgLen;
    DWORD skLen;
    BOOL useAes256;

    if ( ! receiveBytes( sock, msg, sizeof( msg ), &msgLen ) )
        return( FALSE );
    if ( msgLen == 0 )
    {
        fprintf( stderr, "No session init from client\n" );
        return( FALSE );
    }
    if ( verbose )
    {
        printf( "Session Initiation: \n" );
        hexDump( msgLen, msg );
    }

    if ( msg[0] == 0 )
        useAes256 = FALSE;
    else if ( msg[0] == 1 )
        useAes256 = ( aesMaxKey >= 256 );
    else
    {
        fprintf( stderr, "Invalid AES key length requested: %d\n", msg[0] );
        return( FALSE );
    }

    /* Byte 0 of the reply echoes the granted key size; the rest of the
    ** buffer is reused for the encrypted session key. */
    msg[0] = (BYTE)( useAes256 ? 1 : 0 );
    if ( ! decodePK( msgLen - 1, msg + 1 ) )
        return( FALSE );
    skLen = encodeSK( useAes256, sizeof( msg ) - 1, msg + 1 );
    if ( skLen == 0 )
        return( FALSE );

    if ( verbose )
    {
        printf( "Session Response: \n" );
        hexDump( skLen + 1, msg );
    }
    return( sendBytes( sock, msg, skLen + 1 ) );
}
/*
** Serve one echo request on an established session.
**
** Receives a single message (one recv() call), decrypts it with sessKey,
** and scans it line by line. From the first 'BEGIN' line on, the server
** header (RESPONSE_TEMPLATE) followed by each client line up to and
** including 'END' is collected in obuff. obuff is then encrypted and sent
** back. A request without a 'BEGIN' line yields an empty response.
**
** Returns FALSE on receive, decrypt, encrypt or send failure, or when the
** echoed text would not fit in the response buffer; TRUE otherwise.
*/
static BOOL
service( SOCKET sock )
{
    DWORD ilen, olen;
    BYTE *ptr, *end;
    BYTE ibuff[ 1024 ];
    BYTE obuff[ 1024 ];
    BOOL echo = FALSE;

    /* Response stays empty until 'BEGIN' is seen. Previously obuff was
    ** read uninitialized when the request contained no 'BEGIN' line. */
    obuff[0] = 0;
    olen = 0;

    if ( ! receiveBytes( sock, ibuff, sizeof( ibuff ), &ilen ) )
        return( FALSE );
    if ( ! ilen )
    {
        fprintf( stderr, "No request from client\n" );
        return( FALSE );
    }
    if ( verbose )
    {
        printf( "Client Request: \n" );
        hexDump( ilen, ibuff );
    }
    if ( ! (ilen = decode( ilen, ibuff, sizeof( ibuff ) )) )
        return( FALSE );
    /* The plaintext is NUL terminated in place, so one spare byte is
    ** needed. Padding normally makes the plaintext shorter than the
    ** ciphertext, but do not rely on it. */
    if ( ilen >= sizeof( ibuff ) )
    {
        fprintf( stderr, "Request too large: %lu bytes\n", (unsigned long)ilen );
        return( FALSE );
    }
    ptr = ibuff;
    end = ptr + ilen;
    *end = 0;
    printf( "Request: %s\n", ibuff );
    while( ptr < end )
    {
        BYTE *eol, save;

        /* Find the end of this line: a '\n', an embedded NUL, or end. */
        for( eol = ptr; *eol && *eol != '\n'; eol++ ) /* DO NOTHING */;
        /* Step past the terminator only when it lies inside the data.
        ** Stepping unconditionally could touch the byte after *end. */
        if ( eol < end )
            eol++;
        save = *eol;
        *eol = 0;
        if ( ! echo )
        {
            /* Echo client text starting with 'BEGIN' */
            if ( ! strcmp( (char *)ptr, "BEGIN\r\n" ) || ! strcmp( (char *)ptr, "BEGIN\n" ) )
            {
                /* Initiate server response; use "%s" so the template is
                ** never interpreted as a format string. */
                sprintf_s( (char *)obuff, sizeof( obuff ), "%s", RESPONSE_TEMPLATE );
                olen = (DWORD)strlen( (char *)obuff );
                echo = TRUE;
            }
        }
        if ( echo )
        {
            DWORD n = (DWORD)strlen( (char *)ptr );

            /* strcat_s() aborts the process on overflow; refuse instead. */
            if ( olen + n >= sizeof( obuff ) )
            {
                fprintf( stderr, "Response exceeds %u bytes\n", (unsigned int)sizeof( obuff ) );
                return( FALSE );
            }
            /* Echo client text back to client */
            strcat_s( (char *)obuff, sizeof( obuff ), (char *)ptr );
            olen += n;
            /* Echo ends when 'END' is received */
            if ( ! strcmp( (char *)ptr, "END\r\n" ) || ! strcmp( (char *)ptr, "END\n" ) )
                break;
        }
        *eol = save;
        ptr = eol;
    }
    printf( "Response: %s\n", obuff );
    if ( ! (olen = encode( olen, obuff, sizeof( obuff ) )) )
        return( FALSE );
    if ( verbose )
    {
        printf( "Server Response: \n" );
        hexDump( olen, obuff );
    }
    if ( ! sendBytes( sock, obuff, olen ) )
        return( FALSE );
    return( TRUE );
}
/*
** Acquire a CRYPT_VERIFYCONTEXT provider of type ProvType and record the
** largest RSA key-exchange and AES key lengths it reports (PP_ENUMALGS_EX
** dwMaxLen) in rsaMaxKey and aesMaxKey. Returns FALSE when the context
** cannot be acquired, enumeration fails, or the provider cannot support
** at least RSA 1024 and AES 128.
*/
static BOOL
initCrypt()
{
    PROV_ENUMALGS_EX info;
    DWORD size, flags;

    if ( ! CryptAcquireContext( &provider, NULL, NULL, ProvType, CRYPT_VERIFYCONTEXT ) )
    {
        fprintf( stderr, "CryptAcquireContext() failed: 0x%x\n", GetLastError() );
        return( FALSE );
    }
    for( flags = CRYPT_FIRST; ; flags = CRYPT_NEXT )
    {
        /* Enumerate straight into a correctly typed (and therefore
        ** correctly aligned) structure instead of a cast byte array. */
        size = sizeof( info );
        if ( ! CryptGetProvParam( provider, PP_ENUMALGS_EX, (BYTE *)&info, &size, flags ) )
        {
            DWORD status = GetLastError();
            if ( status == ERROR_NO_MORE_ITEMS ) break;
            fprintf( stderr, "CryptGetProvParam() failed: 0x%x\n", status );
            return( FALSE );
        }
        switch( info.aiAlgid )
        {
        case CALG_RSA_KEYX : rsaMaxKey = max( rsaMaxKey, info.dwMaxLen ); break;
        case CALG_AES_128 :
        case CALG_AES_192 :
        case CALG_AES_256 : aesMaxKey = max( aesMaxKey, info.dwMaxLen ); break;
        default: break;
        }
    }
    if ( verbose )
    {
        printf( "RSA Max Key Length: %d\n", rsaMaxKey );
        printf( "AES Max Key Length: %d\n", aesMaxKey );
    }
    if ( rsaMaxKey < 1024 || aesMaxKey < 128 )
    {
        fprintf( stderr, "Invalid minimum key length\n" );
        return( FALSE );
    }
    return( TRUE );
}
/*
** Release the per-connection CryptoAPI objects.
**
** Keys are destroyed before their provider context is released. Each
** handle is reset to 0 once handled, so a later connection can never
** destroy or release a stale handle a second time. Failures are reported
** but otherwise ignored: this is best-effort cleanup.
*/
static void
termCrypt()
{
    if ( xchgKey )
    {
        if ( ! CryptDestroyKey( xchgKey ) )
            fprintf( stderr, "CryptDestroyKey() failed: 0x%x\n", GetLastError() );
        xchgKey = 0;
    }
    if ( sessKey )
    {
        if ( ! CryptDestroyKey( sessKey ) )
            fprintf( stderr, "CryptDestroyKey() failed: 0x%x\n", GetLastError() );
        sessKey = 0;
    }
    if ( provider )
    {
        if ( ! CryptReleaseContext( provider, 0 ) )
            fprintf( stderr, "CryptReleaseContext() failed: 0x%x\n", GetLastError() );
        provider = 0;
    }
}
/*
** Import the client's RSA public key into xchgKey.
**
** keybuf/keylen are the key bytes sent by the client. They are wrapped in
** a CERT_PUBLIC_KEY_INFO tagged szOID_RSA_RSA and passed to
** CryptImportPublicKeyInfo(). In verbose mode the key is exported again
** as a PUBLICKEYBLOB so its exponent and modulus can be printed.
** Returns FALSE (after printing the error) on any failure.
*/
static BOOL
decodePK( DWORD keylen, BYTE *keybuf )
{
    CERT_PUBLIC_KEY_INFO pubInfo;
    BYTE blob[ 1024 ];
    DWORD blobLen;
    struct pkb
    {
        PUBLICKEYSTRUC hdr;
        RSAPUBKEY rsa;
        BYTE mod[1];
    } *pub;

    pubInfo.Algorithm.pszObjId = szOID_RSA_RSA;
    pubInfo.Algorithm.Parameters.cbData = 0;
    pubInfo.Algorithm.Parameters.pbData = NULL;
    pubInfo.PublicKey.cbData = keylen;
    pubInfo.PublicKey.pbData = keybuf;
    pubInfo.PublicKey.cUnusedBits = 0;

    if ( ! CryptImportPublicKeyInfo( provider, X509_ASN_ENCODING, &pubInfo, &xchgKey ) )
    {
        fprintf( stderr, "CryptImportPublicKeyInfo() failed: 0x%x\n", GetLastError() );
        return( FALSE );
    }
    if ( ! verbose )
        return( TRUE );

    /* Verbose only: dump the imported key. */
    pub = (struct pkb *)blob;
    blobLen = 0;
    if ( ! CryptExportKey( xchgKey, 0, PUBLICKEYBLOB, 0, NULL, &blobLen ) )
    {
        fprintf( stderr, "CryptExportKey() [1] failed: 0x%x\n", GetLastError() );
        return( FALSE );
    }
    if ( blobLen > sizeof( blob ) )
    {
        fprintf( stderr, "CryptExportKey() requires %d bytes\n", blobLen );
        return( FALSE );
    }
    blobLen = sizeof( blob );
    if ( ! CryptExportKey( xchgKey, 0, PUBLICKEYBLOB, 0, (BYTE *)pub, &blobLen ) )
    {
        fprintf( stderr, "CryptExportKey() [2] failed: 0x%x\n", GetLastError() );
        return( FALSE );
    }
    printf( "Xchg Exp: %d\n", pub->rsa.pubexp );
    printf( "Xchg Mod: \n" );
    hexDump( pub->rsa.bitlen / 8, pub->mod );
    return( TRUE );
}
/*
** Generate a new AES session key and export it for the client.
**
** A key of type SessKey256 or SessKey128 (selected by aes256) is created
** in sessKey. It is configured with sessKeyMode and sessKeyPadding, plus
** an all-zero IV in CBC mode. It is then exported as a SIMPLEBLOB
** encrypted with the client's RSA key (xchgKey). The blob header is
** dropped and the encrypted key bytes are copied, reversed (see NOTE
** below), into keybuf.
**
** Returns the number of bytes written to keybuf (at most bufsiz), or 0 on
** failure after printing the error.
*/
static DWORD
encodeSK( BOOL aes256, DWORD bufsiz, BYTE *keybuf )
{
    BYTE buffer[ 1024 ];
    DWORD i, size = 0, length = 0, header;
    struct ptkb
    {
        PUBLICKEYSTRUC hdr;
        DWORD keysize;
        BYTE key[1];
    } *txtKey = (struct ptkb *)buffer;
    struct skb
    {
        PUBLICKEYSTRUC hdr;
        ALG_ID algId;
        BYTE key[1];
    } *expKey = (struct skb *)buffer;

    /* Offset of the encrypted key bytes inside a SIMPLEBLOB. */
    header = (DWORD)( expKey->key - buffer );

    if ( ! CryptGenKey( provider, aes256 ? SessKey256 : SessKey128, CRYPT_EXPORTABLE, &sessKey ) )
    {
        fprintf( stderr, "CryptGenKey() failed: 0x%x\n", GetLastError() );
        return( 0 );
    }
    if ( ! CryptSetKeyParam( sessKey, KP_MODE, (BYTE *)&sessKeyMode, 0 ) )
    {
        fprintf( stderr, "CryptSetKeyParam() MODE failed: 0x%x\n", GetLastError() );
        return( 0 );
    }
    if ( ! CryptSetKeyParam( sessKey, KP_PADDING, (BYTE *)&sessKeyPadding, 0 ) )
    {
        fprintf( stderr, "CryptSetKeyParam() PADDING failed: 0x%x\n", GetLastError() );
        return( 0 );
    }
    if ( sessKeyMode == CRYPT_MODE_CBC )
    {
        size = sizeof( length );
        if ( ! CryptGetKeyParam( sessKey, KP_BLOCKLEN, (BYTE *)&length, &size, 0 ) )
        {
            /* Report failure through the return value like every other
            ** step rather than exit()ing from inside a helper. */
            fprintf( stderr, "CryptGetKeyParam() BLOCKLEN failed: 0x%x\n", GetLastError() );
            return( 0 );
        }
        length /= 8; /* Bits -> bytes */
        for( i = 0; i < length; i++ ) buffer[i] = 0;
        if ( ! CryptSetKeyParam( sessKey, KP_IV, buffer, 0 ) )
        {
            fprintf( stderr, "CryptSetKeyParam() IV failed: 0x%x\n", GetLastError() );
            return( 0 );
        }
    }
    if ( verbose )
    {
        size = 0;
        if ( ! CryptExportKey( sessKey, 0, PLAINTEXTKEYBLOB, 0, NULL, &size ) )
        {
            fprintf( stderr, "CryptExportKey() [1] failed: 0x%x\n", GetLastError() );
            return( 0 );
        }
        if ( size > sizeof( buffer ) )
        {
            fprintf( stderr, "CryptExportKey() requires %d bytes\n", size );
            return( 0 );
        }
        length = sizeof( buffer );
        if ( ! CryptExportKey( sessKey, 0, PLAINTEXTKEYBLOB, 0, (BYTE *)txtKey, &length ) )
        {
            fprintf( stderr, "CryptExportKey() [2] failed: 0x%x\n", GetLastError() );
            return( 0 );
        }
        printf( "Sess Key: \n" );
        hexDump( txtKey->keysize, txtKey->key );
        keyInfo( sessKey );
    }
    size = 0;
    if ( ! CryptExportKey( sessKey, xchgKey, SIMPLEBLOB, 0, NULL, &size ) )
    {
        fprintf( stderr, "CryptExportKey() [1] failed: 0x%x\n", GetLastError() );
        return( 0 );
    }
    if ( size > sizeof( buffer ) )
    {
        fprintf( stderr, "CryptExportKey() requires %d bytes\n", size );
        return( 0 );
    }
    length = sizeof( buffer );
    if ( ! CryptExportKey( sessKey, xchgKey, SIMPLEBLOB, 0, buffer, &length ) )
    {
        fprintf( stderr, "CryptExportKey() [2] failed: 0x%x\n", GetLastError() );
        return( 0 );
    }
    /* Guard the unsigned subtraction below. */
    if ( length < header )
    {
        fprintf( stderr, "Exported key blob too short: %lu bytes\n", (unsigned long)length );
        return( 0 );
    }
    length -= header;
    if ( length > bufsiz )
    {
        fprintf( stderr, "Exported key requires %d bytes\n", length );
        return( 0 );
    }
    /*
    ** NOTE: it appears that the encoded key is byte swapped compared
    ** to external standards. There is a cryptic reference to a
    ** ReverseMemCopy() function in the RSA/SChannel server master
    ** key creation example. Also, the internal RSA modulus is
    ** byte swapped compared to the X509 encoding. This swap is
    ** required to interoperate with JavaSSE.
    */
    for( i = 0; i < length; i++ )
        keybuf[i] = buffer[ header + length - i - 1 ];
    return( length );
}
/*
** Encrypt length bytes of buffer in place with sessKey (final block, so
** padding is added). The padded size is queried first with a NULL data
** pointer and must fit in bufsiz. Returns the ciphertext length, or 0
** after printing an error.
*/
static DWORD
encode( DWORD length, BYTE *buffer, DWORD bufsiz )
{
    DWORD outLen = length;
    BOOL ok;

    ok = CryptEncrypt( sessKey, 0, TRUE, 0, NULL, &outLen, bufsiz );
    if ( ! ok )
    {
        fprintf( stderr, "CryptEncrypt() [1] failed: 0x%x\n", GetLastError() );
        return( 0 );
    }
    if ( outLen > bufsiz )
    {
        fprintf( stderr, "CryptEncrypt() requires %d bytes\n", outLen );
        return( 0 );
    }
    ok = CryptEncrypt( sessKey, 0, TRUE, 0, buffer, &length, bufsiz );
    if ( ! ok )
    {
        fprintf( stderr, "CryptEncrypt() [2] failed: 0x%x\n", GetLastError() );
        return( 0 );
    }
    return( length );
}
/*
** Decrypt length bytes of buffer in place with sessKey (final block, so
** padding is stripped). bufsiz is not consulted. Returns the plaintext
** length, or 0 after printing an error.
*/
static DWORD
decode( DWORD length, BYTE *buffer, DWORD bufsiz )
{
    DWORD n = length;

    if ( CryptDecrypt( sessKey, 0, TRUE, 0, buffer, &n ) == FALSE )
    {
        fprintf( stderr, "CryptDecrypt() failed: 0x%x\n", GetLastError() );
        return( 0 );
    }
    return( n );
}
/*
** Print the block size (bytes), cipher mode, padding type and IV of key
** to stdout for verbose diagnostics. A failed query is fatal: the error
** is printed and the process exits.
*/
static void
keyInfo( HCRYPTKEY key )
{
    BYTE buffer[ 1024 ];
    DWORD value;
    DWORD length;

    /* DWORD-valued parameters are read into a real DWORD rather than
    ** through a cast byte array, avoiding misaligned/aliased access. */
    length = sizeof( value );
    if ( ! CryptGetKeyParam( key, KP_BLOCKLEN, (BYTE *)&value, &length, 0 ) )
    {
        fprintf( stderr, "CryptGetKeyParam() BLOCKLEN failed: 0x%x\n", GetLastError() );
        exit( 1 );
    }
    printf( "Key Block Size: %lu\n", (unsigned long)( value / 8 ) );
    length = sizeof( value );
    if ( ! CryptGetKeyParam( key, KP_MODE, (BYTE *)&value, &length, 0 ) )
    {
        fprintf( stderr, "CryptGetKeyParam() MODE failed: 0x%x\n", GetLastError() );
        exit( 1 );
    }
    printf( "Key Mode: %lu\n", (unsigned long)value );
    length = sizeof( value );
    if ( ! CryptGetKeyParam( key, KP_PADDING, (BYTE *)&value, &length, 0 ) )
    {
        fprintf( stderr, "CryptGetKeyParam() PADDING failed: 0x%x\n", GetLastError() );
        exit( 1 );
    }
    printf( "Key Padding Type: %lu\n", (unsigned long)value );
    length = sizeof( buffer );
    if ( ! CryptGetKeyParam( key, KP_IV, buffer, &length, 0 ) )
    {
        fprintf( stderr, "CryptGetKeyParam() IV failed: 0x%x\n", GetLastError() );
        exit( 1 );
    }
    printf( "Key IV: \n" );
    hexDump( length, buffer );
}
/*
** Send every one of the len bytes in buff over sock, retrying after
** partial writes. Returns FALSE (after printing the error) on SOCKET_ERROR.
*/
static BOOL
sendBytes( SOCKET sock, BYTE *buff, DWORD len )
{
    DWORD done = 0;

    while( done < len )
    {
        int chunk = send( sock, (const char *)( buff + done ), (int)( len - done ), 0 );

        if ( chunk == SOCKET_ERROR )
        {
            fprintf( stderr, "send failed: %u\n", GetLastError() );
            return( FALSE );
        }
        done += (DWORD)chunk;
    }
    return( TRUE );
}
/*
** Read up to size bytes from sock with a single recv() and report the
** count through *len (left at 0 on failure). Returns FALSE on a socket
** error or an orderly close by the peer.
** NOTE(review): each recv() is assumed to deliver one complete protocol
** message; TCP does not guarantee that.
*/
static BOOL
receiveBytes( SOCKET sock, BYTE *buff, DWORD size, DWORD *len )
{
    int got;

    *len = 0;
    got = recv( sock, (char *)buff, (int)size, 0 );
    if ( got > 0 )
    {
        *len = (DWORD)got;
        return( TRUE );
    }
    if ( got == SOCKET_ERROR )
        fprintf( stderr, "recv failed: %u\n", GetLastError() );
    else
        fprintf( stderr, "recv returned no data\n" );
    return( FALSE );
}
/*
** Print length bytes from buffer as a hex dump, 16 bytes per line: the
** offset, the bytes in hex (with ':' after the eighth byte), then the
** printable ASCII rendering ('.' for bytes outside 32..126).
*/
static void
hexDump( DWORD length, BYTE *buffer )
{
    static const CHAR digits[] = "0123456789ABCDEF";
    DWORD i, count, index;
    CHAR line[ 100 ];
    int cnt;

    for( index = 0; length; length -= count, buffer += count, index += count )
    {
        count = (length > 16) ? 16 : length;
        /* Offset followed by two blanks ("0000  "). Continue from the
        ** number of characters actually written. The old fixed start of 6
        ** after a 5-character "%4.4x " left the NUL at line[5], so only
        ** the offset was ever printed. */
        cnt = sprintf_s( line, sizeof( line ), "%4.4x  ", (unsigned int)index );
        if ( cnt < 0 )
            return;
        for( i = 0; i < count; i++ )
        {
            line[cnt++] = digits[buffer[i] >> 4];
            line[cnt++] = digits[buffer[i] & 0x0f];
            line[cnt++] = (i == 7) ? ':' : ' ';
        }
        /* Pad a short final row so the ASCII column lines up. */
        for( ; i < 16; i++ )
        {
            line[cnt++] = ' ';
            line[cnt++] = ' ';
            line[cnt++] = ' ';
        }
        line[cnt++] = ' ';
        for( i = 0; i < count; i++ )
            line[cnt++] = ( buffer[i] < 32 || buffer[i] > 126 ) ? '.' : (CHAR)buffer[i];
        line[cnt] = 0;
        printf( "%s\n", line );
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment