#include <stdio.h>
#include <stdlib.h>
#include <math.h>
#include <malloc.h>

#define DMATRIX
#include <mysetting.h>

/*
 * Print a row x colum matrix to standard output.
 *
 * Every element is written with the format "%f " (note the trailing
 * space), and each row is terminated by a newline.
 */
void DMatPrint( double *a[], int row, int colum )	/* Print matrix */
{
	int r = 0;

	while( r < row )
	{
		int c;

		for( c = 0; c < colum; ++c )
			printf( "%f ", a[r][c] );
		printf( "\n" );
		++r;
	}
}

/*
 * Transpose a matrix: result[c][r] = a[r][c].
 *
 * a      : input, row x colum
 * result : output, colum x row
 *
 * Entries are written as they are read. Calling this in place
 * (result == a) therefore does not produce a transpose, because some
 * entries are overwritten before they are read.
 */
void DMatTrans( double *a[], double *result[], int row, int colum )	/* Transpose matrix */
{
	int r, c;

	for( r = 0; r < row; ++r )
		for( c = 0; c < colum; ++c )
			result[c][r] = a[r][c];
}

/*
 * Invert a matrix by Gauss-Jordan elimination with partial pivoting.
 *
 * a      : input matrix, row x colum; not modified.
 * result : receives the inverse, row x colum; written only on success.
 *          It may be the same matrix as a, because a is first copied
 *          into a private work area.
 * row, colum : dimensions. An inverse exists only for a square matrix,
 *          so row must equal colum.
 *
 * Returns 1 on success. Returns 0 and leaves result untouched if any of
 * these hold:
 *   - the matrix is not square;
 *   - the matrix is singular (some column has no non-zero pivot);
 *   - the work area cannot be allocated.
 * The work area is always released before returning.
 */
int DMatInv( double *a[], double *result[], int row, int colum )	/* Inverse matrix */
{
	double **temp;		/* augmented matrix [ a | I ], n x 2n */
	double *swap;
	double pivot;
	double factor;
	double mag, best;
	int n = row;
	int width;
	int ok = 0;
	int i, j, l, p;

	/* The identity half is placed at temp[i][n+i]. With row > colum
	   that index would run past the row, and with row < colum there
	   is no inverse at all. */
	if( row != colum || row < 0 )
		return( 0 );
	if( n == 0 )
		return( 1 );	/* empty matrix: nothing to invert */
	width = n * 2;

	temp = malloc( (size_t)n * sizeof *temp );
	if( temp == NULL )
		return( 0 );
	for( i = 0; i < n; i++ )
		temp[i] = NULL;	/* lets the cleanup path free every slot safely */

	for( i = 0; i < n; i++ )
	{
		temp[i] = malloc( (size_t)width * sizeof *temp[i] );
		if( temp[i] == NULL )
			goto cleanup;
		for( j = 0; j < n; j++ )
		{
			temp[i][j] = a[i][j];
			temp[i][n+j] = ( i == j ) ? 1.0 : 0.0;
		}
	}

	for( i = 0; i < n; i++ )
	{
		/* Partial pivoting: among rows i..n-1, use the one whose entry
		   in column i has the largest magnitude. This allows matrices
		   with a zero on the diagonal (e.g. permutations) to be
		   inverted and limits the growth of rounding error. */
		p = i;
		best = temp[i][i] < 0.0 ? -temp[i][i] : temp[i][i];
		for( l = i + 1; l < n; l++ )
		{
			mag = temp[l][i] < 0.0 ? -temp[l][i] : temp[l][i];
			if( mag > best )
			{
				best = mag;
				p = l;
			}
		}
		if( best == 0.0 )
			goto cleanup;	/* column i is all zero below the diagonal: singular */
		if( p != i )
		{
			/* Swapping row pointers is enough; no data is copied. */
			swap = temp[i];
			temp[i] = temp[p];
			temp[p] = swap;
		}

		/* Normalize the pivot row so that temp[i][i] becomes 1. */
		pivot = temp[i][i];
		for( j = 0; j < width; j++ )
			temp[i][j] /= pivot;

		/* Clear column i in every other row. */
		for( l = 0; l < n; l++ )
		{
			if( l == i )
				continue;
			factor = temp[l][i];
			for( j = 0; j < width; j++ )
				temp[l][j] -= factor * temp[i][j];
		}
	}

	/* The right half of the augmented matrix now holds the inverse. */
	for( i = 0; i < n; i++ )
		for( j = 0; j < n; j++ )
			result[i][j] = temp[i][n+j];
	ok = 1;

cleanup:
	for( i = 0; i < n; i++ )
		free( temp[i] );
	free( temp );
	return( ok );
}

/*
 * Element-wise matrix sum: result[r][c] = a[r][c] + b[r][c].
 *
 * All three matrices are row x colum.
 */
void DMatAdd( double *a[], double *b[], double *result[], int row,
	int colum )							/* Matrix addition */
{
	int r;

	for( r = 0; r < row; r++ )
	{
		int c = 0;

		while( c < colum )
		{
			result[r][c] = a[r][c] + b[r][c];
			c++;
		}
	}
}

/*
 * Element-wise matrix difference: result[r][c] = a[r][c] - b[r][c].
 *
 * All three matrices are row x colum.
 */
void DMatSub( double *a[], double *b[], double *result[], int row,
	int colum )							/* Matrix subtraction */
{
	int r;

	for( r = 0; r < row; r++ )
	{
		int c = 0;

		while( c < colum )
		{
			result[r][c] = a[r][c] - b[r][c];
			c++;
		}
	}
}

/*
 * Matrix product: result = a * b.
 *
 * a      : row1 x colum1
 * b      : row2 x colum2
 * result : row3 x colum3
 *
 * Sizes are checked first. colum1 must equal row2, and result must be
 * row1 x colum2. On a mismatch, "Size error" is printed to stdout,
 * nothing is written, and 0 is returned. Otherwise the product is
 * stored in result and 1 is returned.
 *
 * result must not share storage with a or b. Each element is stored as
 * soon as it is computed, so an aliased input would later be read after
 * being overwritten.
 */
int DMatMultiply( double *a[], double *b[], double *result[],
	int row1, int colum1, int row2, int colum2, int row3, int colum3 )	/* Matrix multiplication */
{
	int i, j, k;
	double sum;

	if( colum1 != row2 || row1 != row3 || colum2 != colum3 )
	{
		printf( "Size error\n" );
		return( 0 );
	}

	for( i = 0; i < row3; i++ )
	{
		for( j = 0; j < colum3; j++ )
		{
			/* Dot product of row i of a with column j of b. */
			sum = 0.0;
			for( k = 0; k < colum1; k++ )
				sum += a[i][k] * b[k][j];
			result[i][j] = sum;
		}
	}
	return( 1 );
}

