/*
 * Copyright (C) 2009, 2010 Arnaldo Carvalho de Melo <acme@redhat.com>
 * Licensed under the GPLv2
 */
#include <errno.h>
#include <stdlib.h>
#include <syscall.h>
#include <stdio.h>
#include <sys/socket.h>
#include <unistd.h>
#include <netinet/in.h>
#include <netinet/tcp.h>
#include <arpa/inet.h>
#include <netdb.h>
#include <poll.h>
#include <string.h>
#include <time.h>

/*
 * When non-zero, main() prints per-batch counts and per-peer byte/datagram
 * stats instead of only announcing the first datagram.
 * NOTE(review): no assignment to this variable is visible in this file, so
 * it stays 0 and the verbose-only output is never printed.
 */
static int verbose;

/*
 * One entry of the message vector passed to the recvmmsg system call: a
 * regular struct msghdr describing where to store the datagram and its
 * source address, plus msg_len, the number of bytes received into it.
 * Declared locally, presumably because the C library headers in use did not
 * provide it (TODO confirm). The array is handed unmodified to the raw
 * syscall, so this layout must match the kernel's definition.
 */
struct mmsghdr {
	struct msghdr	msg_hdr;
	unsigned	msg_len;
};

/*
 * __NR_recvmmsg normally comes from <syscall.h> (which pulls in the kernel's
 * asm/unistd.h). recvmmsg first appeared in Linux 2.6.33, so for older
 * headers fall back to the known native x86 syscall numbers. Any other ABI
 * (including x32, whose numbers differ) gets a clear build error here rather
 * than an undeclared-identifier error at the call site.
 */
#ifndef __NR_recvmmsg
#if defined(__x86_64__) && !defined(__ILP32__)
#define __NR_recvmmsg 299
#elif defined(__i386__)
#define __NR_recvmmsg 337
#else
#error "__NR_recvmmsg not defined: kernel headers predate recvmmsg (Linux 2.6.33)"
#endif
#endif

/* Nanoseconds per millisecond, for converting millisecond timeouts. */
#ifndef NSEC_PER_MSEC
#define NSEC_PER_MSEC	1000000UL
#endif

/*
 * Invoke the recvmmsg system call directly through syscall(2), presumably
 * because the C library in use does not export a recvmmsg() wrapper.
 * Passes every argument through unchanged and returns the syscall result
 * narrowed to int: the number of messages received, or -1 with errno set.
 */
static inline int recvmmsg(int fd, struct mmsghdr *mmsg,
			   unsigned vlen, unsigned flags,
			   struct timespec *timeout)
{
	long rc = syscall(__NR_recvmmsg, fd, mmsg, vlen, flags, timeout);

	return (int)rc;
}

static void print_stats_peer(struct mmsghdr *datagram, int count, int bytes)
{
	char peer[1024];
	int err = getnameinfo(datagram->msg_hdr.msg_name,
			      datagram->msg_hdr.msg_namelen,
			      peer, sizeof(peer), NULL, 0, 0);
	if (err != 0) {
		fprintf(stderr, "error using getnameinfo: %s\n",
			gai_strerror(err));
			return;
		}
	printf("    %d bytes received from %s in %d datagrams\n",
	       bytes, peer, count);
}

enum {
	/* Per-datagram receive buffer size; longer datagrams are truncated. */
	DATAGRAM_BUF_SIZE = 256,
	/* Upper bound on batch_size so the VLAs in receive_loop() stay small. */
	MAX_BATCH_SIZE	  = 1024,
};

/*
 * Parse @s as a non-negative base-10 integer into *@out.
 * Returns 0 on success. Returns -1, leaving *@out untouched, if @s is empty,
 * contains a '-', has trailing characters, or overflows.
 */
static int parse_ull_arg(const char *s, unsigned long long *out)
{
	unsigned long long val;
	char *end;

	/* strtoull() would quietly accept and negate a leading '-'. */
	if (strchr(s, '-') != NULL)
		return -1;

	errno = 0;
	val = strtoull(s, &end, 10);
	if (end == s || *end != '\0' || errno == ERANGE)
		return -1;

	*out = val;
	return 0;
}

/*
 * Print one stats line per run of consecutive entries in
 * @datagrams[0 .. nr_datagrams) that share the same source address.
 */
static void report_peers(struct mmsghdr *datagrams, int nr_datagrams)
{
	int peer_count, peer_bytes;

	if (nr_datagrams <= 0)
		return;

	peer_count = 1;
	peer_bytes = (int)datagrams[0].msg_len;
	for (int i = 1; i < nr_datagrams; ++i) {
		const struct msghdr *prev = &datagrams[i - 1].msg_hdr;
		const struct msghdr *cur  = &datagrams[i].msg_hdr;

		if (prev->msg_namelen == cur->msg_namelen &&
		    memcmp(prev->msg_name, cur->msg_name,
			   cur->msg_namelen) == 0) {
			++peer_count;
			peer_bytes += (int)datagrams[i].msg_len;
			continue;
		}

		print_stats_peer(&datagrams[i - 1], peer_count, peer_bytes);
		peer_count = 1;
		peer_bytes = (int)datagrams[i].msg_len;
	}
	print_stats_peer(&datagrams[nr_datagrams - 1], peer_count, peer_bytes);
}

/*
 * Receive datagrams on @fd until stdin becomes readable, @max_datagrams
 * datagrams have arrived (0 = no limit), or an error occurs.
 * Each wakeup issues one recvmmsg() for up to @batch_size (>= 1) datagrams,
 * using a fresh copy of @recv_timeout. When @use_recvmsg is set, it issues
 * a single recvmsg() instead.
 * *@nr_syscalls is incremented once per successful receive call.
 * Returns EXIT_SUCCESS on a normal stop, EXIT_FAILURE on poll/receive errors.
 */
static int receive_loop(int fd, int batch_size,
			const struct timespec *recv_timeout,
			unsigned long long max_datagrams, int use_recvmsg,
			unsigned long long *nr_syscalls)
{
	char buf[batch_size][DATAGRAM_BUF_SIZE];
	struct iovec iov[batch_size];
	struct sockaddr_storage addr[batch_size];
	struct mmsghdr datagrams[batch_size];
	unsigned long long total_datagrams_received = 0;

	/*
	 * VLAs cannot take initializers. Zero the headers so msg_control,
	 * msg_controllen, msg_flags and msg_len never hold stack garbage.
	 */
	memset(datagrams, 0, sizeof(datagrams));

	for (int i = 0; i < batch_size; ++i) {
		iov[i].iov_base = buf[i];
		iov[i].iov_len	= sizeof(buf[i]);
		datagrams[i].msg_hdr.msg_iov	= &iov[i];
		datagrams[i].msg_hdr.msg_iovlen	= 1;
		datagrams[i].msg_hdr.msg_name	= &addr[i];
	}

	struct pollfd pfds[2] = {
		[0] = {
			.fd = STDIN_FILENO,
			.events = POLLIN,
		},
		[1] = {
			.fd = fd,
			.events = POLLIN,
		},
	};

	for (;;) {
		if (poll(pfds, 2, -1) < 0) {
			if (errno == EINTR)
				continue;
			perror("poll");
			return EXIT_FAILURE;
		}

		/* Any activity on stdin (ENTER, EOF, hangup) ends the run. */
		if (pfds[0].revents)
			break;

		/*
		 * The kernel writes back the real source-address length,
		 * and recvmmsg writes the time left into the timespec.
		 * Reset both before every call.
		 */
		for (int i = 0; i < batch_size; ++i)
			datagrams[i].msg_hdr.msg_namelen = sizeof(addr[i]);
		struct timespec timeout = *recv_timeout;

		int nr_datagrams;

		if (use_recvmsg) {
			ssize_t len = recvmsg(fd, &datagrams[0].msg_hdr, 0);

			if (len < 0) {
				nr_datagrams = -1;
			} else {
				/* recvmsg() has no msg_len; record it like recvmmsg(). */
				datagrams[0].msg_len = (unsigned)len;
				nr_datagrams = 1;
			}
		} else {
			nr_datagrams = recvmmsg(fd, datagrams, batch_size, 0,
						&timeout);
		}

		if (nr_datagrams < 0) {
			perror(use_recvmsg ? "recvmsg" : "recvmmsg");
			return EXIT_FAILURE;
		}

		(*nr_syscalls)++;

		if (verbose)
			printf("nr_datagrams received: %d, remaining: %lldns\n",
			       nr_datagrams,
			       (long long)timeout.tv_sec * 1000000000LL +
			       timeout.tv_nsec);
		else if (*nr_syscalls == 1)
			printf("received first datagram, suppressing the others\n");

		if (verbose)
			report_peers(datagrams, nr_datagrams);

		if (max_datagrams) {
			total_datagrams_received += (unsigned long long)nr_datagrams;
			if (total_datagrams_received >= max_datagrams)
				break;
		}
	}

	return EXIT_SUCCESS;
}

/*
 * UDP receive benchmark comparing recvmmsg(2) with recvmsg(2).
 *
 * Arguments (all optional, positional):
 *   port          UDP port to bind on all IPv4 addresses (default 5001)
 *   batch_size    datagrams per recvmmsg() call, 1..MAX_BATCH_SIZE (default 8)
 *   timeout       recvmmsg() timeout in milliseconds (default 10)
 *   max_datagrams stop after this many datagrams, 0 = no limit (default 0)
 *   "recvmsg"     use one recvmsg() per datagram instead of recvmmsg()
 *
 * Runs until ENTER is pressed or max_datagrams is reached, then prints the
 * number of receive system calls made.
 * Returns EXIT_SUCCESS on a normal stop, EXIT_FAILURE on invalid arguments
 * or any resolver/socket/receive error.
 */
int main(int argc, char *argv[])
{
	struct addrinfo *host;
	struct addrinfo hints = {
		.ai_family   = AF_INET,
		.ai_socktype = SOCK_DGRAM,
		.ai_protocol = IPPROTO_UDP,
		.ai_flags    = AI_PASSIVE,
	};
	const char *port = "5001";
	unsigned long long batch_size = 8;
	unsigned long long timeout_ms = 10;
	unsigned long long max_datagrams = 0;
	unsigned long long nr_syscalls = 0;
	int use_recvmsg = 0;
	int ret = EXIT_FAILURE;
	int err, fd;

	if (argc > 1)
		port = argv[1];

	if (argc > 2 &&
	    (parse_ull_arg(argv[2], &batch_size) < 0 ||
	     batch_size == 0 || batch_size > MAX_BATCH_SIZE)) {
		fprintf(stderr, "invalid batch_size '%s' (expected 1..%d)\n",
			argv[2], MAX_BATCH_SIZE);
		return EXIT_FAILURE;
	}

	if (argc > 3 && parse_ull_arg(argv[3], &timeout_ms) < 0) {
		fprintf(stderr, "invalid timeout '%s' (expected milliseconds)\n",
			argv[3]);
		return EXIT_FAILURE;
	}

	if (argc > 4 && parse_ull_arg(argv[4], &max_datagrams) < 0) {
		fprintf(stderr, "invalid max_datagrams '%s'\n", argv[4]);
		return EXIT_FAILURE;
	}

	if (argc > 5)
		use_recvmsg = strcmp(argv[5], "recvmsg") == 0;

	/* Split into seconds + nanoseconds: tv_nsec must stay below 1s. */
	const struct timespec recv_timeout = {
		.tv_sec	 = (time_t)(timeout_ms / 1000),
		.tv_nsec = (long)((timeout_ms % 1000) * NSEC_PER_MSEC),
	};

	printf("usage: recvmmsg <port(def 5001)> <batch_size(def 8)> "
	       "<timeout(def 10ms)> <max_datagrams(def no max)> [<recvmsg>]"
	       "\n\nPress ENTER to exit\n\nWaiting for datagrams...\n");

	err = getaddrinfo(NULL, port, &hints, &host);
	if (err != 0) {
		fprintf(stderr, "error using getaddrinfo: %s\n",
			gai_strerror(err));
		return EXIT_FAILURE;
	}

	fd = socket(host->ai_family, host->ai_socktype, host->ai_protocol);
	if (fd < 0) {
		perror("socket");
		goto out_freeaddrinfo;
	}

	if (bind(fd, host->ai_addr, host->ai_addrlen) < 0) {
		perror("bind");
		goto out_close_server;
	}

	ret = receive_loop(fd, (int)batch_size, &recv_timeout,
			   max_datagrams, use_recvmsg, &nr_syscalls);

	printf("# %s calls: %llu\n", use_recvmsg ? "recvmsg" : "recvmmsg",
	       nr_syscalls);
out_close_server:
	close(fd);
out_freeaddrinfo:
	freeaddrinfo(host);
	return ret;
}
