/* protocol.c       Copyright (c) 2000-2002 Nagy Daniel
 *
 * $Date: 2002/06/19 14:09:35 $
 * $Revision: 1.12 $
 *
 * This module deals with incoming and outgoing packets.
 *
 *
 * This program is free software; you can redistribute it and/or
 * modify it under the terms of the GNU General Public License
 * as published by the Free Software Foundation; either version 2
 * of the License, or (at your option) any later version, with the
 * additional permission that it may be linked against Erick Engelke's
 * WATTCP source code and Jerry Joplin's CVT100 source code.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU Library General Public
 * License along with this program; if not, write to the Free Software
 * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307, USA.
 */

#include <stdio.h>
#include <conio.h>
#include <mem.h>

#if defined (__DJGPP__)
 #include <errno.h>
 #include <string.h>
 #include "tcp_djgp.h"
#elif defined (__TURBOC__)
 #include "zlib.h"
 #include "tcp.h"
#endif

#include "ssh.h"
#include "protocol.h"
#include "cipher.h"
#include "crc.h"
#include "xmalloc.h"
#include "macros.h"
#include "config.h"

/*
 * Error exit used throughout this file.
 * NOTE(review): callers here do no cleanup after calling it, so it is
 * presumably non-returning — confirm against its definition.
 */
void fatal(const char *, ...);

extern Config GlobalConfig;		/* global configuration structure */
/*
 * Configuration bits. This file tests VERBOSE_MODE, CIPHER_ENABLED and
 * COMPRESSION_ENABLED, and sets COMPRESSION_ENABLED in Request_Compression().
 */
extern unsigned short Configuration;	/* Configuration bits */

/*
 * Current packets: ssh_gotdata() fills pktin; s_wrpkt_size() lays out
 * pktout, which is then filled by the caller and sent by s_wrpkt().
 */
struct Packet pktin;	/* incoming SSH packet */
struct Packet pktout;	/* outgoing SSH packet */

#ifdef __TURBOC__
/*
 * zlib streams: initialised in Request_Compression() and released in
 * Disable_Compression().
 */
z_stream comp;		/* compression stream */
z_stream decomp;	/* decompression stream */
#endif

/*
 * Cipher hooks (not assigned in this file). They are applied in place to
 * packet buffers whenever CIPHER_ENABLED is set.
 */
CryptPtr EncryptPacket;
CryptPtr DecryptPacket;

/*
 * Save packets in hexadecimal format for debugging purposes.
 *
 * Writes the length bytes at dbgbuf to GlobalConfig.debugfile twice:
 * first as "%02X " values, then as characters, where every byte outside
 * printable ASCII (' ' through '~') is shown as '.'. Both dumps use 16
 * bytes per row.
 */
static void fwritehex(unsigned char *dbgbuf, unsigned short length)
{
unsigned short i;
unsigned char c;

   /*
    * Index from 0 with "i < length": the old 1-based "i <= length" test
    * could never become false when length is the largest unsigned short.
    */
   for(i = 0; i < length; i++){ /* print hexa dump first */
	fprintf(GlobalConfig.debugfile, "%02X ", dbgbuf[i]);
	if(i % 16 == 15) /* 16 bytes per row */
	   fputs("\n", GlobalConfig.debugfile);
   }

   fputs("\n", GlobalConfig.debugfile);

   for(i = 0; i < length; i++){ /* now print raw data */
	c = dbgbuf[i];
	/* printable ASCII runs from ' ' (32) up to and including '~' (126) */
	if(c >= ' ' && c <= '~')
	   fputc(c, GlobalConfig.debugfile);
	else                    /* put '.' instead of non-readable bytes */
	   fputc('.', GlobalConfig.debugfile);
	if(i % 16 == 15)
	    fputs("\n", GlobalConfig.debugfile);
   }
}

#if defined (__TURBOC__)
/*
 * Compress a packet.
 *
 * Deflates sourceLen bytes at source through the shared "comp" stream.
 * On entry *destLen is the size of the output buffer to allocate; on
 * return *dest points at that xmalloc()ed buffer (the caller frees it)
 * and *destLen holds the number of compressed bytes. Z_SYNC_FLUSH is used
 * so the output for this packet is flushed completely. A zlib error, or
 * output that fills the whole buffer, is reported through fatal().
 * Returns the deflate() status.
 */
int sshcompress(const Bytef*source, uLong sourceLen,
		Bytef**dest, uLong *destLen)
{
int rc;
uLong room = *destLen;	/* size of the output buffer */

   *dest = xmalloc(room);

   comp.next_in   = (Bytef*)source;
   comp.avail_in  = (uInt)sourceLen;
   comp.next_out  = *dest;
   comp.avail_out = (uInt)room;

   rc = deflate(&comp, Z_SYNC_FLUSH);
   if(rc != Z_OK)
	fatal("Compression error: %d", rc);

   /* no room left means the flushed output may have been cut short */
   if(comp.avail_out == 0)
	fatal("Compression buffer is too small");

   *destLen = room - comp.avail_out;
   return rc;
}

/*
 * Uncompress a packet.
 *
 * Inflates sourceLen bytes at source through the shared "decomp" stream.
 * On entry *destLen is the size of the output buffer to allocate; on
 * return *dest points at that xmalloc()ed buffer (the caller frees it)
 * and *destLen holds the number of bytes produced. A zlib error, or
 * output that fills the whole buffer, is reported through fatal().
 * Returns the inflate() status.
 *
 * NOTE(review): current zlib documents Z_NO_FLUSH, Z_SYNC_FLUSH, Z_FINISH,
 * Z_BLOCK and Z_TREES as flush values for inflate(); Z_PARTIAL_FLUSH is
 * kept here exactly as before.
 */
int sshuncompress(const Bytef*source, uLong sourceLen,
		  Bytef **dest, uLongf *destLen)
{
int rc;
uLongf room = *destLen;	/* size of the output buffer */

   *dest = xmalloc(room);

   decomp.next_in   = (Bytef*)source;
   decomp.avail_in  = (uInt)sourceLen;
   decomp.next_out  = *dest;
   decomp.avail_out = (uInt)room;

   rc = inflate(&decomp, Z_PARTIAL_FLUSH);
   if(rc != Z_OK)
	fatal("Decompression error: %d", rc);

   /* a full buffer means more decompressed output may be pending */
   if(decomp.avail_out == 0)
	fatal("Decompression buffer is too small");

   *destLen = room - decomp.avail_out;
   return rc;
}

/*
 * Free compression structures.
 *
 * Releases the zlib state set up by deflateInit()/inflateInit() in
 * Request_Compression(). The zlib return codes are ignored.
 */
void Disable_Compression(void)
{
   (void)inflateEnd(&decomp);	/* receive direction */
   (void)deflateEnd(&comp);	/* send direction */
}

#elif defined (__DJGPP__)

/*
 * DJGPP build: the zlib glue is not defined in this file; only its
 * prototypes are declared here.
 * NOTE(review): unlike the Turbo C path, ssh_gotdata() and s_wrpkt() do
 * not preset *destLen before calling sshuncompress()/sshcompress() on
 * DJGPP, so these implementations presumably treat it as output-only —
 * confirm.
 */
void zlib_compress_init(void);
void zlib_decompress_init(void);
int sshcompress(const unsigned char*, unsigned long,
	     unsigned char**, unsigned long*);
int sshuncompress(const unsigned char*, unsigned long,
	     unsigned char**, unsigned long*);
#endif


/*
 * Request compression.
 *
 * Sends SSH_CMSG_REQUEST_COMPRESSION with the given zlib level as a
 * 4-byte MSB-first integer, then waits for the reply:
 *   SSH_SMSG_SUCCESS - set up both zlib directions and set
 *                      COMPRESSION_ENABLED in Configuration;
 *   SSH_SMSG_FAILURE - report it and carry on uncompressed;
 *   anything else    - fatal().
 */
void Request_Compression(int level)
{
   if(Configuration & VERBOSE_MODE)
	cputs("Requesting compression\r\n");

   s_wrpkt_start(SSH_CMSG_REQUEST_COMPRESSION, 4);
   /* the three high-order bytes are always sent as zero */
   pktout.body[0] = 0;
   pktout.body[1] = 0;
   pktout.body[2] = 0;
   pktout.body[3] = level;
   s_wrpkt();

   packet_read_type();	/* only the reply's type is needed */

   if(pktin.type == SSH_SMSG_SUCCESS){
#if defined (__TURBOC__)
	memset(&comp, 0, sizeof comp);
	memset(&decomp, 0, sizeof decomp);
	if(deflateInit(&comp, level) != Z_OK)
	     fatal("Cannot initialize compression");
	if(inflateInit(&decomp) != Z_OK)
	     fatal("Cannot initialize decompression");
#elif defined (__DJGPP__)
	zlib_compress_init();
	zlib_decompress_init();
#endif
	Configuration |= COMPRESSION_ENABLED;
   }
   else if(pktin.type == SSH_SMSG_FAILURE)
	cputs("Server refused to compress\r\n");
   else
	fatal("Received invalid packet");
}


/*
 * Read one SSH-1 packet from the connection and convert it to the pktin
 * structure. Uncrypt and uncompress if necessary.
 *
 * Layout after decryption, as handled below:
 *    length  - 4 bytes, MSB first; counts type + data + CRC (no padding)
 *    padding - 1 to 8 bytes, bringing type + data + CRC to a multiple of 8
 *    type    - 1 byte
 *    data    - message body (still compressed if compression is on)
 *    CRC     - 4 bytes, MSB first, CRC-32 over padding + type + data
 *
 * 1. Get four bytes from packet, which is the packet size without padding,
 *    and reject sizes that no valid packet can have.
 * 2. Calculate the padding length and full packet length.
 * 3. Allocate memory for the whole packet and read it.
 * 4. Uncrypt packet.
 * 5. Verify CRC (padding + type + body).
 * 6. Fill incoming packet structure (type + body).
 * 7. Uncompress packet (type + body).
 * 8. Get packet type.
 *
 * On return pktin.whole is a heap buffer (xmalloc/xrealloc) of
 * pktin.length bytes holding type + body, which the caller must xfree().
 * pktin.body points just past the type byte. Malformed packets are
 * reported through fatal().
 */
void ssh_gotdata(void)
{
unsigned long len;
unsigned long PktLength;/* full packet length with padding */
unsigned short pad;	/* number of padding bytes */
unsigned char PktInLength[4];	/* first four bytes of a packet (length) */
unsigned char *inbuf;		/* buffer for incoming packet */
unsigned char *decompblk;	/* buffer for decompressed data */
unsigned long crc;		/* CRC computed over the received bytes */
unsigned long gotcrc;		/* CRC carried in the packet */
short i;

   sock_read(&GlobalConfig.s, PktInLength, 4);
   for(i = len = 0; i < 4; i++)
       len = (len << 8) + PktInLength[i];

   /*
    * The length must cover at least the type byte and the 4-byte CRC;
    * otherwise "PktLength - 4 - pad" below underflows. 262144 is the
    * SSH-1 packet size limit; anything larger means a corrupt or hostile
    * stream and would otherwise go straight to the allocator.
    * NOTE(review): the size parameter types of xmalloc()/sock_read() are
    * not visible here; on 16-bit builds a tighter limit may be needed.
    */
   if(len < 5 || len > 262144UL)
	fatal("Invalid packet length: %lu", len);

   pad = 8 - (len%8);
   PktLength = len + pad;
   inbuf = xmalloc(PktLength);
   sock_read(&GlobalConfig.s, inbuf, PktLength); /* Read rest */

   if(Configuration & CIPHER_ENABLED) /* uncrypt */
	DecryptPacket(inbuf, PktLength);

   /* the last 4 bytes hold the CRC of everything before them */
   crc = sshcrc32(inbuf, PktLength - 4) & 0xFFFFFFFFUL;
   gotcrc = 0;
   for(i = 0; i < 4; i++)
       gotcrc = (gotcrc << 8) + inbuf[PktLength - 4 + i];
   if(crc != gotcrc)
	fatal("Incorrect CRC received on packet");

   pktin.length = PktLength - 4 - pad;	/* minus CRC, padding */
   pktin.whole = xmalloc(pktin.length);
   memcpy(pktin.whole, inbuf + pad, pktin.length);
   xfree(inbuf); /* it's now in pktin.whole, so free it */

   if(Configuration & COMPRESSION_ENABLED){
#ifdef __TURBOC__
	len = (10 * pktin.length < MAX_PACKET_SIZE * 2) ?
			 MAX_PACKET_SIZE * 2 : 10 * pktin.length;
#endif
	sshuncompress(pktin.whole, pktin.length, &decompblk, &len);
	if(len == 0) /* there must be at least a type byte to read below */
	    fatal("Empty packet after decompression");
	pktin.whole = xrealloc(pktin.whole, len);
	memcpy(pktin.whole, decompblk, len); /* copy uncompressed */
	xfree(decompblk);
	pktin.length = len;	/* fix length after uncompression */
   } /* Compression */

   pktin.body = pktin.whole + 1; /* skip type field */
   pktin.type = *pktin.whole;    /* first byte is the type */

   if(GlobalConfig.debugfile){
	fputs("\nRECEIVED packet:\n", GlobalConfig.debugfile);
	fwritehex(pktin.whole, pktin.length);
	fputc('\n', GlobalConfig.debugfile);
   } /* debug */
}

/*
 * Calculate full packet length from body length, and size the pktout
 * buffer for a body of len bytes.
 *
 * Buffer layout:
 *    length (4) | padding (1-8) | type (1) | body (len) | CRC (4)
 * The padding brings type + body + CRC up to a multiple of 8.
 * Afterwards pktout.length is type + body, and pktout.body points just
 * past the type byte. The buffer is allocated on first use and resized
 * on later calls.
 */
void s_wrpkt_size(unsigned long len)
{
unsigned long payload;	/* type + body + CRC */
unsigned short padding;	/* 1..8 padding bytes */
unsigned long total;	/* everything, including the length field */

   payload = len + 1 + 4;
   padding = 8 - (payload % 8);
   total = 4 + padding + payload;

   pktout.length = payload - 4;

   if(pktout.whole == NULL)
	pktout.whole = xmalloc(total);
   else
	pktout.whole = xrealloc(pktout.whole, total);

   pktout.body = pktout.whole + 4 + padding + 1;
}

/*
 * Create header for raw outgoing packet.
 *
 * Records the message type and sizes pktout for a body of len bytes.
 * The caller then fills pktout.body and sends it with s_wrpkt().
 */
void s_wrpkt_start(unsigned int type, unsigned int len)
{
   pktout.type = type;
   s_wrpkt_size(len);
}

/*
 * Create and send outgoing packet.
 *
 * Expects pktout to have been laid out by s_wrpkt_start()/s_wrpkt_size()
 * and pktout.body to have been filled by the caller. On the wire,
 * pktout.whole is:
 *    length (4) | padding (1-8) | type (1) | body | CRC (4)
 *
 * 1. Add type
 * 2. Compress (type + data)
 * 3. Calculate and add padding
 * 4. Calculate length (type + body + CRC)
 * 5. Calculate CRC (padding + type + body)
 * 6. Encrypt (padding + type + data + CRC)
 * 7. Send
 *
 * A short socket write is reported through fatal().
 */
void s_wrpkt(void)
{
/*
 * Lengths are unsigned long rather than int: the length field is 32 bits,
 * and int is only 16 bits on Turbo C.
 */
unsigned short pad, i;	/* padding byte count (1..8) and index */
unsigned long len;	/* type + body + CRC */
unsigned long PktLength;/* padding + type + body + CRC */
unsigned long crc;
unsigned char *compblk;
unsigned long complen;

   pktout.body[-1] = pktout.type;

   if(GlobalConfig.debugfile){
	fputs("\nSENT packet:\n", GlobalConfig.debugfile);
	fwritehex(pktout.body - 1, pktout.length);
	fputc('\n', GlobalConfig.debugfile);
   }

   if(Configuration & COMPRESSION_ENABLED){
#ifdef __TURBOC__
	complen = (pktout.length + 13) * 1.1;
#endif
	sshcompress(pktout.body - 1, pktout.length, &compblk, &complen);
	s_wrpkt_size(complen - 1); /* resize packet */
	memcpy(pktout.body - 1, compblk, complen);
	xfree(compblk);
   } /* Compression */

   len = (unsigned long)pktout.length + 4;	/* plus CRC length */
   /* must match the padding s_wrpkt_size() used to place pktout.body */
   pad = (unsigned short)(8 - len % 8);
   PktLength = len + pad;

   for(i = 0; i < pad; i++)		/* add padding */
	pktout.whole[i + 4] = rand() % 256;

   crc = sshcrc32(pktout.whole + 4, PktLength - 4);

   PUT_32BIT_MSB_FIRST(pktout.whole + PktLength, crc);
   PUT_32BIT_MSB_FIRST(pktout.whole, len);

   if(Configuration & CIPHER_ENABLED)
	EncryptPacket(pktout.whole + 4, PktLength);

   sock_flushnext(&GlobalConfig.s);
   if(sock_write(&GlobalConfig.s, pktout.whole, PktLength + 4) != PktLength + 4)
	fatal("Socket write: %s", strerror(errno));
}

/*
 * Length of the SSH string (4-byte MSB-first length followed by the
 * bytes) at the start of the current incoming packet's body, as carried
 * by DEBUG and DISCONNECT messages. The string bytes start at
 * pktin.body + 4. The result is clamped to the bytes actually received,
 * and to what fits a printf precision. It is 0 when the body is too
 * short to hold the length field.
 */
static int pktin_strlen(void)
{
unsigned long avail;	/* body bytes following the 4-byte length field */
unsigned long slen;	/* string length claimed by the sender */
short i;

   if(pktin.length < 5)	/* type byte + 4-byte length field */
	return 0;

   avail = pktin.length - 5;
   for(i = 0, slen = 0; i < 4; i++)
	slen = (slen << 8) + (unsigned char)pktin.body[i];

   if(slen > avail)	/* never trust the sender's length */
	slen = avail;
   if(slen > 0x7FFF)	/* precision is an int, which may be 16 bits */
	slen = 0x7FFF;
   return (int)slen;
}

/*
 * Get a packet with blocking.
 * Handle debug, ignore and disconnect packets.
 *
 * DEBUG packets (printed in verbose mode) and IGNORE packets are freed and
 * skipped. DISCONNECT is reported through fatal(). Any other packet is left in
 * pktin, and the caller owns pktin.whole.
 *
 * The message text is printed with "%.*s" instead of being NUL-terminated
 * in place. The old pktin.body[pktin.length - 1] store wrote one byte past
 * the pktin.length-byte buffer, because pktin.body is pktin.whole + 1.
 */
void packet_read_block(void)
{
int n;

   for(;;){
	ssh_gotdata();

	switch(pktin.type){
	   case SSH_MSG_DEBUG:
		if(Configuration & VERBOSE_MODE){
		   n = pktin_strlen();
		   /* conio's cprintf does not turn "\n" into CR/LF */
		   cprintf("DEBUG: %.*s\r\n", n,
			   n ? (char *)(pktin.body + 4) : "");
		} /* if */
		xfree(pktin.whole);
		break;		/* read the next packet */

	   case SSH_MSG_IGNORE:
		xfree(pktin.whole);
		break;		/* read the next packet */

	   case SSH_MSG_DISCONNECT:
		n = pktin_strlen();
		fatal("Disconnect in packet_read_expect: %.*s\n", n,
		      n ? (char *)(pktin.body + 4) : "");
		return;

	   default:
		return;		/* a packet for the caller */
	} /* switch */
   }
}

/*
 * Read the next packet (DEBUG/IGNORE/DISCONNECT are handled by
 * packet_read_block()) and call fatal() unless its type is `type`.
 * The packet is left in pktin; pktin.whole is not freed here.
 */
void packet_read_expect(unsigned char type)
{
   packet_read_block();
   if(pktin.type == type)
	return;
   fatal("Received invalid packet");
}

/*
 * Read the next packet when only its type matters. pktin.type stays valid.
 * The payload buffer is released, so pktin.whole and pktin.body must not
 * be used afterwards.
 */
void packet_read_type(void)
{
   packet_read_block();
   xfree(pktin.whole);	/* drop the payload; keep only pktin.type */
}

/*
 * Send n data bytes as an SSH packet.
 *
 * Wraps the len bytes at buff in an SSH_CMSG_STDIN_DATA packet and sends
 * it. The body is a 4-byte MSB-first length followed by the data.
 */
void SendSSHPacket(unsigned char *buff, unsigned short len)
{
   /*
    * Size the body in unsigned long. With a 16-bit int, "len + 4" is
    * computed as unsigned int and wraps for len >= 65532. The unsigned int
    * length taken by s_wrpkt_start() would then give a tiny body, and the
    * memcpy below would overrun it. The code below is what s_wrpkt_start()
    * does, minus the narrowing.
    */
   s_wrpkt_size((unsigned long)len + 4);
   pktout.type = SSH_CMSG_STDIN_DATA;

   pktout.body[0] = pktout.body[1] = 0; /* max length is only 16 bits */
   pktout.body[2] = len >> 8;
   pktout.body[3] = len & 0xff;
   memcpy(pktout.body + 4, buff, len);
   s_wrpkt();
}
