#include <stdio.h>
#include <mpi.h>

#include <stdlib.h>
#include <time.h>
#include "timer.h"
#include <cmath>

// How many MPI processes took part in the run; echoed in the result line
int number;

// Edge length of every square matrix below
#define size 1000

// Left operand, right operand, and product (c = a * b)
int a[size][size];
int b[size][size];
int c[size][size];

/*-----------------------------------------------------------------------------
 Function name: fill
 -------------------------------------------------------------------------------
 Preconditions:
	-None
 Postconditions:
	-Every slot of a and b holds rand() % 100, i.e. a value in [0, 99]
	-Every slot of c is 0
 Algorithm:
	-Reseed rand() from the current wall-clock time (on every call)
	-Walk the matrices row by row; for each slot draw a's value, then b's,
	 then clear c
 Exception/Error Handling:
	-None
 -----------------------------------------------------------------------------*/
void fill ( )
{
	srand (time (NULL));

	for (int row = 0; row < size; ++row)
	{
		// Row views keep the inner loop short; draw order (a then b) is fixed
		int *rowA = a[row];
		int *rowB = b[row];
		int *rowC = c[row];

		for (int col = 0; col < size; ++col)
		{
			rowA[col] = rand () % 100;
			rowB[col] = rand () % 100;
			rowC[col] = 0;
		}
	}
}

// Compute rows [start, start + interval) of the product c = a * b.
// Each output row needs only the matching row of a plus all of b, which is
// what lets main() hand disjoint row ranges to different MPI ranks.
// Nothing is written when interval <= 0.
void mult (int start, int interval)
{
	const int stop = start + interval;

	for (int row = start; row < stop; row++)
	{
		for (int col = 0; col < size; col++)
		{
			// Dot product of a's row with b's column
			int sum = 0;
			for (int inner = 0; inner < size; inner++)
			{
				sum += a[row][inner] * b[inner][col];
			}
			c[row][col] = sum;
		}
	}
}

/*-----------------------------------------------------------------------------
 Function name: main
 -------------------------------------------------------------------------------
 Preconditions:
	-None
 Postconditions:
	-All dynamic memory is cleaned up
	-Correct output isnatural point provided
 Algorithm:
	-Set all variables to a default
	-Start the stopwatch
	-Create the appropriate number of threads
	-Stop the stopwatch after all threads have completed
	-Do output
 Exception/Error Handlig:
	-None
 -----------------------------------------------------------------------------*/
int main (int argc, char **argv)
{
	int i, j, k;
	timer watch;
	double time;

	int interval, remainder;;
	int myID, numProcs;

	MPI_Status status;

	MPI_Init (&argc, &argv);

	// Get MPI info
	MPI_Comm_size (MPI_COMM_WORLD, &numProcs);
	MPI_Comm_rank (MPI_COMM_WORLD, &myID);

	// Set number
	number = numProcs;
	time = 0.0f;

	// Get interval, and any extra needed for the base thread
	interval = size/numProcs;
	remainder = size%numProcs;

	// Initialie the total time
	time = 0.0f;

	// 100 iterations of matrix mult
	for (j=0; j<100; j++)
	{
		// Only do this if you are root
		if (myID == 0)
		{
			// Fill the arrays
			fill ();

			// Broadcast all of array B to the other nodes
			MPI_Bcast (b, size*size, MPI_INT, 0, MPI_COMM_WORLD);

			// Send the individual parts out
			for (i=1; i<numProcs; i++)
			{
				MPI_Send (a+i*interval, interval*size, MPI_INT, i, i, MPI_COMM_WORLD);
			}
		}

		else
		{
			// Receive the broadcast
			MPI_Bcast (b, size*size, MPI_INT, 0, MPI_COMM_WORLD);

			// Receive your data
			MPI_Recv (a+(myID*interval), interval*size, MPI_INT, 0, myID, MPI_COMM_WORLD, &status);
		}

		// Barrier to make all nodes start at the same place
		MPI_Barrier (MPI_COMM_WORLD);

		// Root only
		if (myID == 0)
		{
			// Start the timer
			watch.start ();
            
			// Multiply the matrices for yourself
			mult (myID*interval, interval);
			mult (numProcs*interval, remainder);

			// Receive the results
			for (i=1; i<numProcs; i++)
			{
				MPI_Recv (c+(i*interval), interval*size, MPI_INT, i, i, MPI_COMM_WORLD, &status);
			}
		}

		// Child nodes
		else
		{
			// Multiply your part
			mult (myID*interval, interval);

			// Send your results out
			MPI_Send (c+(myID*interval), interval*size, MPI_INT, 0, myID, MPI_COMM_WORLD);
		}

		// Catch everyone at the same place
		MPI_Barrier (MPI_COMM_WORLD);

		// Root node only
		if (myID == 0)
		{
			// Stop the watch
			watch.stop ();

			// If you are part of the middle section keep the results
			if (j > 10 && j < 90)
			{
				time += watch.read ();
			}
		}
	}

	// Print the results
	if (myID == 0)
		printf("%i\t%.3f\n", number, time/80.0f);
    
	// Finalize the MPI stuff
	MPI_Finalize ();

	return 0;
}