#include <stdio.h>
#include <pthread.h>

#include <stdlib.h>
#include <time.h>
#include "timer.h"
#include <cmath>

// Number of worker threads used for the current run. main sets it before
// creating the workers; matMultTH reads it to work out its share of the rows.
int number;

// Dimension of the square matrices (each one is size x size)
#define size 1000

// The matrices, stored as arrays of row pointers: c receives the product a x b
int **a, **b, **c;

/*-----------------------------------------------------------------------------
 Function name: fill
 -------------------------------------------------------------------------------
 Preconditions:
	-a, b and c each point to size rows of size ints
 Postconditions:
	-Every slot of a and b holds a pseudo-random value in the range 0..99
	-Every slot of c is 0
 Algorithm:
	-Seed the random number generator the first time through only
	-Put a value into every slot in all 3 arrays
 Exception/Error Handling:
	-None
 -----------------------------------------------------------------------------*/
void fill ( )
{

	// Seed once. Reseeding from time() on every call restarts the sequence,
	// so two calls made within the same second would build identical matrices.
	static bool seeded = false;

	if (!seeded)
	{
		srand ((unsigned int) time (NULL));
		seeded = true;
	}

	for (int i=0; i<size; i++)
	{
		for (int j=0; j<size; j++)
		{
			a[i][j] = rand ()%100;
			b[i][j] = rand ()%100;

			// c is an accumulator for the multiply, so it must start at zero
			c[i][j] = 0;
		}
	}
}

/*-----------------------------------------------------------------------------
 Function name: matMultTH
 -------------------------------------------------------------------------------
 Preconditions:
	-a, b and c are of the correct dimensions
	-number holds the count of worker threads taking part
	-args carries this worker's zero-based index in the pointer value itself
 Postconditions:
	-The rows of c owned by this worker have the matching rows of a x b
		added to them
	-Always returns NULL
 Algorithm:
	-Determine which band of rows the particular thread operates on
	-In row major order solve each slot of that band
 Exception/Error Handling:
	-Returns without touching c when number is not positive or the index is
		not below number
	-No locking is done: every worker writes only its own rows of c, and a
		and b are only read
 -----------------------------------------------------------------------------*/
void *matMultTH (void* args)
{

	// Get your ID
	const size_t trueid = reinterpret_cast <size_t> (args);

	// Without a positive worker count the partition below would divide by
	// zero; an index past the last worker has no rows assigned to it.
	if (number <= 0 || trueid >= (size_t) number)
	{
		return NULL;
	}

	// Get your partition: rows per worker, rounded up, in exact integer
	// arithmetic (no float rounding, no <cmath> needed)
	const int l_per_thread = (size + number - 1) / number;

	// Rounding up can push the last bands past the matrix, so clamp both ends
	const int rows = size;
	int start = l_per_thread * (int) trueid;
	int end;

	if (start > rows)
	{
		start = rows;
	}

	end = start + l_per_thread;

	if (end > rows)
	{
		end = rows;
	}

	// Crunch some numbers
	for (int i=start; i<end; i++)
	{
		const int *rowA = a[i];
		int *rowC = c[i];

		for (int j=0; j<size; j++)
		{
			// Sum locally and store once instead of updating c on every step
			int sum = 0;

			for (int k=0; k<size; k++)
			{
				sum += rowA[k] * b[k][j];
			}

			rowC[j] += sum;
		}
	}

	// Returning from a thread start routine ends the thread just as
	// pthread_exit would, and keeps the function safe to call directly
	return NULL;
}

/*-----------------------------------------------------------------------------
 Function name: matMult
 -------------------------------------------------------------------------------
 Preconditions:
	-a, b and c are of the correct dimensions
 Postconditions:
	-Every slot of c has the matching slot of a x b added to it (c is not
		cleared here)
 Algorithm:
	-Walk the rows of a; for each column of b accumulate the dot product
		straight into the matching slot of c
 Exception/Error Handling:
	-None
 -----------------------------------------------------------------------------*/
void matMult ()
{

	for (int row=0; row<size; row++)
	{
		// Look the current rows up once instead of on every inner step
		const int *lhs = a[row];
		int *out = c[row];

		for (int col=0; col<size; col++)
		{
			for (int inner=0; inner<size; inner++)
			{
				out[col] += lhs[inner] * b[inner][col];
			}
		}
	}
}

/*-----------------------------------------------------------------------------
 Function name: main
 -------------------------------------------------------------------------------
 Preconditions:
	-None
 Postconditions:
	-All dynamic memory is cleaned up
	-One line per configuration is printed: the thread count (0 for the
		sequential run) and the average time of the timed runs
 Algorithm:
	-Set all variables to a default
	-Allocate the three matrices
	-Time the sequential multiply, then the threaded multiply for 1 to 16
		threads
	-Each configuration is run 100 times; the first 10 and last 10 runs are
		discarded and the remaining 80 are averaged
	-Do output
 Exception/Error Handling:
	-If a worker thread cannot be created the workers already started are
		joined, a message is written to stderr, memory is released and the
		program exits with EXIT_FAILURE
 -----------------------------------------------------------------------------*/

// Free every row of a, b and c and then the row-pointer arrays themselves
static void releaseMatrices ()
{

	for (int row=0; row<size; row++)
	{
		delete [] a[row];
		delete [] b[row];
		delete [] c[row];
	}

	delete [] a;
	delete [] b;
	delete [] c;
}

// Start `number` workers on matMultTH and wait for them all to finish.
// Returns false (after joining whatever did start) if a create call fails.
static bool runWorkers (pthread_t *workers)
{

	int started = 0;
	int rc = 0;

	for (; started<number; started++)
	{
		// The worker index rides in the pointer value; go through size_t so
		// the integer has pointer-compatible width before the cast
		void *id = reinterpret_cast <void *> (static_cast <size_t> (started));

		rc = pthread_create (&workers[started], NULL, matMultTH, id);

		if (rc != 0)
		{
			break;
		}
	}

	// Only join handles that pthread_create actually filled in
	for (int k=0; k<started; k++)
	{
		pthread_join (workers[k], NULL);
	}

	if (rc != 0)
	{
		fprintf (stderr, "pthread_create failed for worker %i (error %i)\n", started, rc);
		return false;
	}

	return true;
}

int main (int argc, char **argv)
{

	// Command line arguments are not used
	(void) argc;
	(void) argv;

	const int maxThreads = 16;	// Largest worker count benchmarked
	const int runs = 100;		// Repetitions per configuration
	const int firstTimed = 10;	// Runs before this index are warm-up
	const int endTimed = 90;	// Runs from this index on are discarded
	const double samples = endTimed - firstTimed;

	pthread_t workers[maxThreads];

	timer watch;
	double elapsed = 0.0;

	// The sequential pass is reported as "0 threads"
	number = 0;

	// Allocate memory
	a = new int * [size];
	b = new int * [size];
	c = new int * [size];

	for (int row=0; row<size; row++)
	{
		a[row] = new int [size];
		b[row] = new int [size];
		c[row] = new int [size];
	}

	// Run and time sequential
	for (int run=0; run<runs; run++)
	{
		fill ();

		watch.start ( );

		matMult ( );

		watch.stop ( );

		// Keep exactly endTimed - firstTimed samples so the divisor below
		// matches the number of readings that were added up
		if (run >= firstTimed && run < endTimed)
		{
			elapsed += watch.read ();
		}
	}

	printf ("%i\t%.3f\n", number, elapsed/samples);

	// Run and time parallel
	for (int threads=1; threads<=maxThreads; threads++)
	{
		number = threads;
		elapsed = 0.0;

		for (int run=0; run<runs; run++)
		{
			fill ();

			watch.start ();

			if (!runWorkers (workers))
			{
				releaseMatrices ();
				return EXIT_FAILURE;
			}

			watch.stop ();

			if (run >= firstTimed && run < endTimed)
			{
				elapsed += watch.read ();
			}
		}

		printf("%i\t%.3f\n", number, elapsed/samples);
	}

	// Cleanup
	releaseMatrices ();

	return 0;
}