/* 
  Copyright (C) 2008 Kai Hertel, André Gaul

	This file is part of mmpong.

	mmpong is free software: you can redistribute it and/or modify
	it under the terms of the GNU General Public License as published by
	the Free Software Foundation, either version 3 of the License, or
	(at your option) any later version.

	mmpong is distributed in the hope that it will be useful,
	but WITHOUT ANY WARRANTY; without even the implied warranty of
	MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
	GNU General Public License for more details.

	You should have received a copy of the GNU General Public License
	along with mmpong.  If not, see <http://www.gnu.org/licenses/>.
*/

#include <stdlib.h>
#include <stdio.h>
#include <string.h>
#include <limits.h>
#include <errno.h>
#include "message.h"

#ifdef WIN32
# include <winsock2.h>
//typedef int (*iofunc)(SOCKET, char *, int, int);
#else
# include <arpa/inet.h>
# include <sys/socket.h>
//typedef int (*iofunc)(int, void *, int, int);
#endif

typedef typeof(recv) *iofunc;

#define OFFSET(base, member) ((int)( ((void *)&(member)) - ((void *)&(base)) ))
#define INSTANCE(var, offset) ( ((void *)&(var)) + (offset) )
#define REGISTER(mode, member) { mode, OFFSET(nmsgbuf, nmsgbuf.member), sizeof(nmsgbuf.member) }
#define ENDMARKER(mode) { mode, (-1), 0 }
// the following one is to address a nasty compiler bug in 4.2 series GCCs
#define REGISTER_ARR(mode, member, idx) { mode, OFFSET(nmsgbuf, nmsgbuf.member[0]) + idx*sizeof(nmsgbuf.member[idx]), sizeof(nmsgbuf.member[idx]) }

// this one mainly exists for making the constant table below work elegantly
// (in order to allow for the access macros above to have a reference to base calculations on)
// the sendmessage() function could be easily changed to use a local instance instead of this TLS one, ... and it has
static struct netmessage nmsgbuf;
static struct timeval scrambler_salt= { .tv_sec= 0 };

/*
 * Field registries, one entry per field:
 *   byte_order[]      fields converted between host and network byte order
 *                     by apply_byte_order();
 *   scramble_fields[] time stamps that scramble_stamps() offsets by the salt.
 *
 * Entry columns:
 *   id     - message type the field belongs to; 0 means "every message"
 *            (the header fields), (-1) terminates the table.
 *   offset - byte offset into struct netmessage; (-1) for ENDMARKER entries.
 *   size   - width of the field in bytes.
 *
 * apply_byte_order() only considers entries whose id is 0 or the message's
 * type, and it stops at the first ENDMARKER among those. Every generic field
 * and every field of a type must therefore come before that type's ENDMARKER.
 * NETMSG_STAT has no ENDMARKER of its own; ENDMARKER(-1) ends its list.
 */
static const struct {
	const enum netmessage_type id; /* 0 == generic, (-1) == terminal */
	const int offset;	/* byte offset into struct netmessage, (-1) == end of a type's list */
	const int size;	/* field width in bytes */
} byte_order[]= { 	// seems to be getting a little out of hand...
	// header: converted for every message type
	REGISTER (0, hdr.id),
	REGISTER (0, hdr.len),
	REGISTER (0, hdr.stamp.tv_sec),
	REGISTER (0, hdr.stamp.tv_usec),
	// NETMSG_POS payload
	REGISTER (NETMSG_POS, payload.position),
	ENDMARKER (NETMSG_POS),
	// NETMSG_UPDT payload
	REGISTER (NETMSG_UPDT, payload.part.stamp.tv_sec),
	REGISTER (NETMSG_UPDT, payload.part.stamp.tv_usec),
	REGISTER_ARR (NETMSG_UPDT, payload.part.ball.pos, 0),
	REGISTER_ARR (NETMSG_UPDT, payload.part.ball.pos, 1),
	REGISTER_ARR (NETMSG_UPDT, payload.part.ball.dir, 0),
	REGISTER_ARR (NETMSG_UPDT, payload.part.ball.dir, 1),
	REGISTER (NETMSG_UPDT, payload.part.pad[0].mean),
	REGISTER (NETMSG_UPDT, payload.part.pad[0].var),
	REGISTER (NETMSG_UPDT, payload.part.pad[0].size),
	// NOTE(review): the pad .dir fields are deliberately not registered
	// (presumably single-byte or not meaningful on the wire) - confirm
//	REGISTER (NETMSG_UPDT, payload.part.pad[0].dir),
	REGISTER (NETMSG_UPDT, payload.part.pad[1].mean),
	REGISTER (NETMSG_UPDT, payload.part.pad[1].var),
	REGISTER (NETMSG_UPDT, payload.part.pad[1].size),
//	REGISTER (NETMSG_UPDT, payload.part.pad[1].dir),
	ENDMARKER (NETMSG_UPDT),
	// NETMSG_STAT payload (terminated by the table's final ENDMARKER)
	REGISTER (NETMSG_STAT, payload.full.team),
	REGISTER (NETMSG_STAT, payload.full.game.version),
	REGISTER (NETMSG_STAT, payload.full.game.stamp.tv_sec),
	REGISTER (NETMSG_STAT, payload.full.game.stamp.tv_usec),
	REGISTER (NETMSG_STAT, payload.full.game.mode),
	REGISTER (NETMSG_STAT, payload.full.game.status),
	REGISTER (NETMSG_STAT, payload.full.game.pad_attr[0].score),
	REGISTER (NETMSG_STAT, payload.full.game.pad_attr[0].peers),
	REGISTER (NETMSG_STAT, payload.full.game.pad_attr[0].profile),
	REGISTER (NETMSG_STAT, payload.full.game.pad_attr[1].score),
	REGISTER (NETMSG_STAT, payload.full.game.pad_attr[1].peers),
	REGISTER (NETMSG_STAT, payload.full.game.pad_attr[1].profile),
	REGISTER_ARR (NETMSG_STAT, payload.full.game.ball.pos, 0),
	REGISTER_ARR (NETMSG_STAT, payload.full.game.ball.pos, 1),
	REGISTER_ARR (NETMSG_STAT, payload.full.game.ball.dir, 0),
	REGISTER_ARR (NETMSG_STAT, payload.full.game.ball.dir, 1),
	REGISTER (NETMSG_STAT, payload.full.game.pad[0].mean),
	REGISTER (NETMSG_STAT, payload.full.game.pad[0].var),
	REGISTER (NETMSG_STAT, payload.full.game.pad[0].size),
//	REGISTER (NETMSG_STAT, payload.full.game.pad[0].dir),
	REGISTER (NETMSG_STAT, payload.full.game.pad[1].mean),
	REGISTER (NETMSG_STAT, payload.full.game.pad[1].var),
	REGISTER (NETMSG_STAT, payload.full.game.pad[1].size),
//	REGISTER (NETMSG_STAT, payload.full.game.pad[1].dir),
	ENDMARKER(-1)
}, scramble_fields[]= {
	// Each entry points at a whole time stamp that scramble_stamps() treats as a
	// struct gametime_public; the size column is not read there. There are no
	// per-type ENDMARKERs: scramble_stamps() scans every entry up to ENDMARKER(-1).
	REGISTER (0, hdr.stamp),
	REGISTER (NETMSG_UPDT, payload.part.stamp),
	REGISTER (NETMSG_STAT, payload.full.game.stamp),
	ENDMARKER(-1)
};


/* Move bytes [*pos, len) of buf through send()/recv() on sock, advancing *pos. */
static inline int io_raw(const int sock, void *buf, uint16_t *pos, const uint16_t len, iofunc iofct);
/* Run convert over every byte_order[] field registered for message type msgid. */
static inline int apply_byte_order(uint16_t msgid, struct netmessage *pmsg, int (*convert)(void *data, const int size));
/* In-place host -> network byte order conversion of one field. */
static inline int convert_to_network(void *data, const int size);
/* In-place network -> host byte order conversion of one field. */
static inline int convert_to_host(void *data, const int size);
/* Subtract salt from the time stamps registered in scramble_fields[]. */
static inline int scramble_stamps(const struct timeval *salt, struct netmessage *msg);



EXPORT int netmessage_send(sock, msgid, data, datalen, peerbuf)
const int sock;
const enum netmessage_type msgid;
const void *data;
const uint16_t datalen;
struct netmessage_buffer *peerbuf;
{
	struct netmessage stage;

	int retcode= netmessage_buffer_flush(sock, peerbuf);
	if ((retcode == NETMSG_FAIL_SOCKET) || (retcode == NETMSG_END_SOCKET))
		return retcode; 	// these conditions are considered critical and should take precedence over other issues
	if ((datalen > sizeof(stage.payload)) || (datalen && !data))
		return NETMSG_ARGINVALID;
	if ((retcode != NETMSG_SUCCESS) && (retcode != NETMSG_ARGINVALID)) {
		if (retcode == NETMSG_PARTIAL) return NETMSG_FAIL_DELIVER;
		return retcode;
	}
	// else: flushed or nothing to flush

	// assemble header and fill in payload
	stage.hdr.id= msgid;
	uint16_t reqlen= stage.hdr.len= sizeof(stage.hdr) + datalen;

	struct timeval stamp;
	if (gettimeofday(&stamp, NULL)) {
		perror("gettimeofday()");
		return NETMSG_FAIL_SYSCRITICAL;
	}
	stage.hdr.stamp.tv_sec = (uint32_t)stamp.tv_sec;
	stage.hdr.stamp.tv_usec = (uint32_t)stamp.tv_usec;
	if (datalen)
		memcpy(&stage.payload, data, reqlen - sizeof(stage.hdr));

	// prepare and try sending out data
	if (scrambler_salt.tv_sec >0)
		if (scramble_stamps(&scrambler_salt, &stage) != NETMSG_SUCCESS) {
			fprintf(stderr, "Error: Cannot scramble time stamps.\n");
			return NETMSG_FAIL_CONVERT;
		}
	if (apply_byte_order(msgid, &stage, convert_to_network) != NETMSG_SUCCESS) {
		fprintf(stderr, "Error: Cannot convert data to network byte order.\n");
		return NETMSG_FAIL_CONVERT;
	}
	uint16_t len= 0;
	retcode= io_raw(sock, &stage, &len, reqlen, (iofunc)send);
	if ((retcode == NETMSG_FAIL_SOCKET) || (retcode == NETMSG_END_SOCKET))
		return retcode;

	// make sure messages won't get cut off when the send buffer fills up
	if ((len != reqlen) && (peerbuf) && (peerbuf->sz >= reqlen)) {
		memcpy( ((char *)(&peerbuf->msg)) + len, ((char *)(&stage)) + len, reqlen - len );
		peerbuf->pos= len;
		peerbuf->len= reqlen;
		return NETMSG_PARTIAL;
	}
	return retcode;
}



/*
 * Receive one complete message from `sock` into `rmsg`, whose usable size is
 * `rmsglen` bytes. On success, the registered fields are converted to host
 * byte order.
 *
 * Without `peerbuf`, the message is read straight into `rmsg`, and any partial
 * progress is lost when the call returns early. With `peerbuf`, bytes
 * accumulate in peerbuf->msg across calls. They are copied into `rmsg` once
 * the message is complete, and peerbuf->pos is then reset to 0.
 *
 * Returns:
 *   - NETMSG_SUCCESS when a message was delivered;
 *   - NETMSG_PARTIAL / NETMSG_FAIL_DELIVER when the socket would block;
 *   - NETMSG_END_SOCKET / NETMSG_FAIL_SOCKET on socket trouble;
 *   - NETMSG_FAIL_CHECKSUM if the header announces an impossible length;
 *   - NETMSG_ARGINVALID for bad arguments, or for a message that fits either
 *     the staging buffer or `rmsg` but not both;
 *   - NETMSG_FAIL_CONVERT / NETMSG_FAIL_SYSCRITICAL on internal errors.
 */
EXPORT int netmessage_recv(const int sock, struct netmessage *rmsg, const uint16_t rmsglen, struct netmessage_buffer *peerbuf)
{
	// a whole message must be representable as a single ssize_t I/O count
	if (sizeof(struct netmessage) > SSIZE_MAX)
		return NETMSG_FAIL_SYSCRITICAL;
	if (!rmsg || (sizeof(rmsg->hdr) > rmsglen))
		return NETMSG_ARGINVALID;

	// staging area: rmsg itself, or the caller's persistent peer buffer
	uint16_t pos= 0, *ppos= &pos;
	const uint16_t *plen= &rmsglen;
	struct netmessage *pmsg= rmsg;
	if (peerbuf) {
		if (sizeof(rmsg->hdr) > peerbuf->sz)
			return NETMSG_ARGINVALID;
		ppos= &peerbuf->pos;
		plen= &peerbuf->sz;
		pmsg= &peerbuf->msg;
	}

	// process the fixed-size header first
	int retcode;
	if ( (*ppos < sizeof(pmsg->hdr)) &&
		 ((retcode= io_raw(sock, pmsg, ppos, sizeof(pmsg->hdr), (iofunc)recv)) != NETMSG_SUCCESS) )
			return retcode;

	// validate and process the variable-size payload
	// (the length field is still in network byte order at this point)
	uint16_t reqlen= ntohs(pmsg->hdr.len);
	if ( (reqlen < sizeof(pmsg->hdr)) || (reqlen > sizeof(struct netmessage)) )
		return NETMSG_FAIL_CHECKSUM; 	// packet bounds checking
	// The message must fit both the staging area and rmsg. When a peer buffer
	// is used, the message is copied into rmsg below; checking only
	// peerbuf->sz would let that memcpy overrun a smaller rmsg.
	if ((reqlen > *plen) || (reqlen > rmsglen))
		return NETMSG_ARGINVALID;

	retcode= io_raw(sock, pmsg, ppos, reqlen, (iofunc)recv);
	// the header is already in, so the message counts as partially received
	if (retcode == NETMSG_FAIL_DELIVER) return NETMSG_PARTIAL;
	if (retcode != NETMSG_SUCCESS) return retcode;
	if (*ppos != reqlen)
		return NETMSG_FAIL_CHECKSUM;

	if (peerbuf) {
		memcpy(rmsg, pmsg, *ppos);
		*ppos= 0; 	// peer buffer is ready for the next message
		pmsg= rmsg;
	}
	if (apply_byte_order(ntohs(pmsg->hdr.id), pmsg, convert_to_host) != NETMSG_SUCCESS) {
		fprintf(stderr, "Error: Cannot convert data to host byte order.\n");
		return NETMSG_FAIL_CONVERT;
	}
	return NETMSG_SUCCESS;
}



// take care of bunched-up messages
EXPORT int netmessage_buffer_flush(sock, peerbuf)
const int sock;
struct netmessage_buffer *peerbuf;
{
	if (!peerbuf) return NETMSG_ARGINVALID;
	if (peerbuf->pos >= peerbuf->len) return NETMSG_SUCCESS;
	int retcode= io_raw(sock, &peerbuf->msg, &peerbuf->pos, peerbuf->len, (iofunc)send);
	if (peerbuf->pos == peerbuf->len) peerbuf->pos= peerbuf->len= 0; 	// clean up, optional at this point
	return retcode;
}



/*
 * Transfer bytes [*pos, len) of `buf` through `iofct` (recv, or send cast to
 * iofunc) on `sock`. *pos is advanced by whatever was moved, so a later call
 * can resume where this one stopped. Interrupted calls are retried.
 *
 * Returns:
 *   - NETMSG_SUCCESS once *pos == len;
 *   - NETMSG_PARTIAL / NETMSG_FAIL_DELIVER when the socket would block after
 *     some / no progress in this call;
 *   - NETMSG_END_SOCKET when the call returns 0 (usually the peer closed);
 *   - NETMSG_FAIL_SOCKET on any other error;
 *   - NETMSG_ARGINVALID for bad arguments.
 *
 * The error indicator (errno, or WSAGetLastError() on WIN32) is consulted only
 * after a call has returned a negative value. Both may hold stale values after
 * a successful call. Winsock does not report through errno at all.
 */
static inline int io_raw(const int sock, void *buf, uint16_t *pos, const uint16_t len, iofunc iofct)
{
	if (!buf || !pos || *pos > len || !iofct)
		return NETMSG_ARGINVALID;

	const uint16_t org= *pos;
	while (*pos < len) {
		const int loc= iofct(sock, ((char *)buf) + *pos, len - *pos, 0);
		if (loc > 0) {
			*pos += loc;
			continue;
		}
		if (loc == 0) 	// usually EOF
			return NETMSG_END_SOCKET;

		// loc < 0: only now is the error indicator meaningful
#ifdef WIN32
		const int err= WSAGetLastError();
		if (err == WSAEINTR)
			continue;
		if (err == WSAEWOULDBLOCK) 	// message has not transferred completely
			return (org != *pos)? NETMSG_PARTIAL : NETMSG_FAIL_DELIVER;
		fprintf(stderr, "send()/recv(): WinSock error %d\n", err);
		return NETMSG_FAIL_SOCKET;
#else
		if (errno == EINTR)
			continue;
		if ((errno == EAGAIN) || (errno == EWOULDBLOCK)) 	// message has not transferred completely
			return (org != *pos)? NETMSG_PARTIAL : NETMSG_FAIL_DELIVER;
		perror("send()/recv()");
		return NETMSG_FAIL_SOCKET;
#endif
	}

	return NETMSG_SUCCESS;
}



/*
 * Apply `convert` in place to every field of *pmsg registered in byte_order[]
 * for message type `msgid`, plus the generic (id 0) header fields.
 * Works in both directions, depending on the converter passed in.
 *
 * The scan stops at the first ENDMARKER belonging to `msgid`, or at the
 * table's terminal entry. Unknown ids therefore get only the header converted.
 *
 * Returns NETMSG_SUCCESS, or NETMSG_FAIL_CONVERT if `convert` rejects a field.
 */
static inline int apply_byte_order(uint16_t msgid, struct netmessage *pmsg, int (*convert)(void *, const int))
{
	for (int idx= 0; byte_order[idx].id != ((enum netmessage_type)(-1)); idx++) {
		if ((msgid != byte_order[idx].id) && (0 != byte_order[idx].id))
			continue; 	// field belongs to another message type
		if (byte_order[idx].offset == (-1))
			break; 	// end of this type's fields: shortcut to reduce average loop times
		if (convert( INSTANCE(*pmsg, byte_order[idx].offset), byte_order[idx].size ) != NETMSG_SUCCESS) {
			// used for both directions, so the message must not name one of them
			fprintf(stderr, "Error: Cannot convert byte order of field (offset %d, size %d).\n",
				byte_order[idx].offset, byte_order[idx].size);
			return NETMSG_FAIL_CONVERT;
		}
	}
	return NETMSG_SUCCESS;
}



/*
 * Convert one `size`-byte field at `data` from host to network byte order, in place.
 * 1-byte fields are left untouched; 2- and 4-byte fields are swapped as needed.
 *
 * The value goes through memcpy instead of a uint16_t * / uint32_t * cast.
 * The fields' declared types are not visible here (NOTE(review): some may not
 * be integers, e.g. floats - confirm against message.h). Accessing them
 * through an integer pointer would then break strict aliasing, and memcpy also
 * tolerates unaligned fields should the struct be packed.
 *
 * Returns NETMSG_SUCCESS, or NETMSG_FAIL_CONVERT for unsupported sizes.
 */
static inline int convert_to_network(void *data, const int size)
{
	switch (size) {
	case 1: 	// just in case: single bytes have no byte order
		break;
	case sizeof(uint16_t): {
		uint16_t v;
		memcpy(&v, data, sizeof v);
		v= htons(v);
		memcpy(data, &v, sizeof v);
		break;
	}
	case sizeof(uint32_t): {
		uint32_t v;
		memcpy(&v, data, sizeof v);
		v= htonl(v);
		memcpy(data, &v, sizeof v);
		break;
	}
	default:
		return NETMSG_FAIL_CONVERT;
	}
	return NETMSG_SUCCESS;
}



/*
 * Convert one `size`-byte field at `data` from network to host byte order, in place.
 * 1-byte fields are left untouched; 2- and 4-byte fields are swapped as needed.
 *
 * The value goes through memcpy instead of a uint16_t * / uint32_t * cast.
 * The fields' declared types are not visible here (NOTE(review): some may not
 * be integers, e.g. floats - confirm against message.h). Accessing them
 * through an integer pointer would then break strict aliasing, and memcpy also
 * tolerates unaligned fields should the struct be packed.
 *
 * Returns NETMSG_SUCCESS, or NETMSG_FAIL_CONVERT for unsupported sizes.
 */
static inline int convert_to_host(void *data, const int size)
{
	switch (size) {
	case 1: 	// just in case: single bytes have no byte order
		break;
	case sizeof(uint16_t): {
		uint16_t v;
		memcpy(&v, data, sizeof v);
		v= ntohs(v);
		memcpy(data, &v, sizeof v);
		break;
	}
	case sizeof(uint32_t): {
		uint32_t v;
		memcpy(&v, data, sizeof v);
		v= ntohl(v);
		memcpy(data, &v, sizeof v);
		break;
	}
	default:
		return NETMSG_FAIL_CONVERT;
	}
	return NETMSG_SUCCESS;
}



/*
 * Allocate a zero-filled peer buffer into *bufptr, sized for one whole
 * struct netmessage.
 *
 * *bufptr must be NULL on entry, so an existing buffer is never overwritten;
 * otherwise NETMSG_ARGINVALID is returned. If calloc() fails, *bufptr is left
 * NULL and NETMSG_FAIL_SYSCRITICAL is returned.
 * NOTE(review): presumably the caller owns the buffer and releases it with free().
 */
EXPORT int netmessage_buffer_init(struct netmessage_buffer **bufptr)
{
	if (bufptr == NULL || *bufptr != NULL)
		return NETMSG_ARGINVALID;

	struct netmessage_buffer *fresh= calloc(1, sizeof *fresh);
	*bufptr= fresh;
	if (fresh == NULL)
		return NETMSG_FAIL_SYSCRITICAL;

	fresh->sz= sizeof(struct netmessage);
	return NETMSG_SUCCESS;
}



/*
 * Enable time stamp scrambling by choosing a random salt strictly below the
 * current time.
 *
 * Stamps taken from now on are therefore never earlier than the salt, which
 * scramble_stamps() would reject.
 * NOTE(review): rand() is presumably seeded elsewhere - confirm.
 *
 * Returns:
 *   - NETMSG_FAIL_SYSCRITICAL if gettimeofday() fails;
 *   - NETMSG_FAIL_CHECKSUM if the clock reads 0 seconds (the modulo below
 *     would divide by zero);
 *   - NETMSG_SUCCESS otherwise.
 * A resulting tv_sec of 0 leaves scrambling disabled in netmessage_send().
 */
EXPORT int netmessage_scrambler_init(void)
{
	struct timeval *salt= &scrambler_salt;

	if (gettimeofday(salt, NULL) != 0) {
		perror("gettimeofday()");
		return NETMSG_FAIL_SYSCRITICAL;
	}
	if (salt->tv_sec == 0)
		return NETMSG_FAIL_CHECKSUM;

	salt->tv_sec= rand() % salt->tv_sec;
	salt->tv_usec= rand() % (1000L * 1000L);
	return NETMSG_SUCCESS;
}



/*
 * Subtract `salt` in place from every time stamp registered in
 * scramble_fields[] for the message's type, plus the generic header stamp.
 *
 * Reads msg->hdr.id and the stamps directly, so it has to run while the
 * message is still in host byte order (netmessage_send() calls it before
 * apply_byte_order()).
 *
 * Returns NETMSG_SUCCESS, or NETMSG_FAIL_CONVERT if a stamp is earlier than
 * the salt. In that case the stamps may be left partially modified.
 */
static inline int scramble_stamps(const struct timeval *salt, struct netmessage *msg)
{
	const enum netmessage_type msgid= msg->hdr.id;
	for (int idx= 0; scramble_fields[idx].id != ((enum netmessage_type)(-1)); idx++) {
		if ((msgid != scramble_fields[idx].id) && (0 != scramble_fields[idx].id))
			continue; 	// stamp belongs to another message type

		struct gametime_public *val= INSTANCE(*msg, scramble_fields[idx].offset);
		if (val->tv_sec < salt->tv_sec)
			return NETMSG_FAIL_CONVERT;
		val->tv_sec-= salt->tv_sec;
		// Decide on the borrow *before* subtracting. Testing "tv_usec < 0"
		// afterwards never fires if the field is unsigned, and the value wraps.
		if (val->tv_usec < salt->tv_usec) {
			if (val->tv_sec < 1) {
				val->tv_usec= 0;
				return NETMSG_FAIL_CONVERT;
			}
			val->tv_sec--;
			val->tv_usec+= 1000L * 1000L;
		}
		val->tv_usec-= salt->tv_usec;
	}
	return NETMSG_SUCCESS;
}



EXPORT int netmessage_get_hdr_stamp(msg, stamp)
const struct netmessage *msg;
struct timeval *stamp;
{
	if ((!msg) || (!stamp))
		return NETMSG_ARGINVALID;
	stamp->tv_sec= msg->hdr.stamp.tv_sec;
	stamp->tv_usec= msg->hdr.stamp.tv_usec;
	return NETMSG_SUCCESS;
}

