#include <stdio.h>
#include <stdlib.h>

#include <time.h>
#include <cmath>
#include <mpi.h>
#include "timer.h"

// Total number of elements in the data set (100 million ints; roughly 400 MB
// assuming a 4-byte int).
#define SIZE 100000000

// The data set. Rank 0 fills all of it with fill(); every other rank only
// receives and reads its own partition of it.
int array[SIZE];

// Number of MPI processes in the run, recorded for the final report line.
int number;
// Per-rank running count of elements equal to 3. count() adds to it, and
// main() resets it each iteration and sums it across ranks with MPI_Reduce.
int counter;

// Overwrite every element of the global data array with a pseudo-random digit
// in [0, 10). The generator is reseeded from the wall clock on each call, so
// calls made within the same second produce identical data.
void fill ( )
{
	srand (time (NULL));

	int *slot = array;
	int *const limit = array + SIZE;

	while (slot != limit)
	{
		*slot++ = rand () % 10;
	}
}

// Add to the global `counter` the number of elements equal to 3 in the
// half-open index range [start, start + length) of the global `array`.
//
// The second argument is a LENGTH, not an end index: every caller passes a
// partition size (`size` or the remainder `rem`). `counter` is NOT reset
// here, so repeated calls accumulate.
void count (int start, int length)
{
	const int stop = start + length;
	int found = 0;

	for (int i = start; i < stop; i++)
	{
		if (array[i] == 3)
		{
			found++;
		}
	}

	// Single update of the shared global instead of one per match.
	counter += found;
}

int main (int argc, char** argv)
{
    	int i, j, k;
    	int overall;

    	int myID, numProcs, size, rem;
    	MPI_Status status;

	timer watch;

    	float total=0.0f;

	// Set up MPI and discover a few things about the work group

	MPI_Init (&argc, &argv);
    	MPI_Comm_size (MPI_COMM_WORLD, &numProcs);

    	MPI_Comm_rank (MPI_COMM_WORLD, &myID);

    	number = numProcs;

	// Set up the partitions, and find any remainder
	size = SIZE/numProcs;
    	rem = SIZE%numProcs;

	// Do this stuff 100 times
	for (j=0; j<100; j++)
    	{
        	overall = 0;

		// Root node
		if (myID == 0)
		{
			// Fill the array
            		fill ( );

			// Broadcast only part of it
			for (i=1; i < numProcs; i++)
			{
				MPI_Send (array+(i*size), size, MPI_INT, i, i, MPI_COMM_WORLD);
            		}
        	}

        	else
        	{
			// Index into the array and store your part
       	     		MPI_Recv (array+(myID*size), size, MPI_INT, 0, myID, MPI_COMM_WORLD, &status);
        	}

		// Everyone MUST be here before computations are started
        	MPI_Barrier (MPI_COMM_WORLD);

		counter = 0;

		// Root node
		if (myID == 0)
		{
			// Only guy to do the timing
			watch.start ();

			// Takes the first partition, and anything left over
			count (0, size);
			count (numProcs*size, rem);
		}

		// Everyone else
		else
		{
			// Just compute your part
			count (myID*size, size);
		}

		// Combine everything together
		MPI_Reduce (&counter, &counter, 1, MPI_INT, MPI_SUM, 0, MPI_COMM_WORLD);

		// Root Node
		if (myID == 0)
		{
			// Stop the watch
			watch.stop ();

			// Keep only the 80 center sets
			if (j>10 && j<90)
			{
				total += watch.read ();
			}
		}
	}

	// Print the average
	if (myID == 0)
	{
        	printf ("%i\t%.3f\n", number, total/80.0f);
	}

	// Done!
	MPI_Finalize ();

	return 0;
}