#include "StdAfx.h"
#include "SSLProtocolSocket.h"
#include "SSLOverlappedSocketBIO.h"


/*!
	Construct a socket with no SSL session attached yet.

	All SSL-related members start in a known "empty" state so that Close()
	(also run from the destructor) can safely inspect them even if
	SSLAccept() is never called.
*/
CSSLProtocolSocket::CSSLProtocolSocket(void)
{
	m_sslInit = 0;		//	no SSL session established
	m_ssl = NULL;		//	SSL object, created by SSLAccept()
	m_bio = NULL;		//	BIO, created by SSLAccept(); previously left uninitialized
}

/*!
	Destroy the socket, releasing the SSL session (if any) and running the
	base-class close logic via this class's Close().
*/
CSSLProtocolSocket::~CSSLProtocolSocket(void)
{
	//	Inside a destructor a virtual call already resolves to this class's
	//	override, so naming it explicitly is equivalent to calling Close().
	CSSLProtocolSocket::Close();
}


//////////////////////////////////////////////////////////////////////////////////
//	Ȃ
//////////////////////////////////////////////////////////////////////////////////
/*!
	
*/
void CSSLProtocolSocket::Close()
{
	//	eNX
	CProtocolSocket::Close();

	//	
	if(m_sslInit)
	{
		SSL_shutdown(m_ssl);
		SSL_free(m_ssl);
	}
	m_sslInit = 0;

	//	̃Xbh̃G[j
	ERR_remove_state(0);
}

//////////////////////////////////////////////////////////////////////////////////
//	Communication
//////////////////////////////////////////////////////////////////////////////////
/*!
	Start SSL communication: create an SSL session on top of m_socket and
	run the server-side handshake (SSL_accept).

	\param sslContext	SSL context the new SSL object is created from.
	\return	0 on success; -1 on failure, after reporting the reason through
			CloseError().
*/
int CSSLProtocolSocket::SSLAccept(SSL_CTX *sslContext)
{
	//	Called twice?
	ASSERT(m_sslInit == 0);

	//	Create a BIO that performs its I/O through the overlapped socket
	m_bio = BIO_new_ovrs(&m_socket);
	if(m_bio == NULL)
	{
		CloseError(CSPS_ERROR_CREATE_SSL_CONNECTION, _T("SSLAccept"));
		return(-1);
	}

	//	Create the SSL object
	m_ssl = SSL_new(sslContext);
	if(m_ssl == NULL)
	{
		//	The BIO is not owned by any SSL object yet, so release it here
		BIO_free(m_bio);
		m_bio = NULL;
		CloseError(CSPS_ERROR_CREATE_SSL_CONNECTION, _T("SSLAccept"));
		return(-1);
	}

	//	Attach the BIO for both reading and writing; from now on
	//	SSL_free(m_ssl) releases it as well
	SSL_set_bio(m_ssl, m_bio, m_bio);

	//	Mark the session as initialized only once m_ssl is valid: Close() and
	//	the Blocking* methods use m_sslInit to decide whether m_ssl may be used.
	m_sslInit = 1;

	//	Handshake
	int	sslErr = SSL_get_error(m_ssl, SSL_accept(m_ssl));
	if(sslErr == SSL_ERROR_ZERO_RETURN)
	{
		//	Disconnect detected at the SSL protocol level
		CloseError(CBS_ERROR_CLOSE, _T("SSLAccept"));
		return(-1);
	}
	else if(sslErr == SSL_ERROR_SSL)
	{
		//	SSL protocol error
		CloseError(CSPS_ERROR_SSL_ACCEPT, _T("SSLAccept"));
		return(-1);
	}
	else if(sslErr != SSL_ERROR_NONE)
	{
		//	Map the underlying overlapped-socket error to a socket error code
		switch(m_socket.GetLastError())
		{
		case COverlappedSocket::OVRS_ERROR_TIMEOUT:
			CloseError(CBS_ERROR_TIMEOUT, _T("SSLAccept"));
			break;

		case COverlappedSocket::OVRS_ERROR_CLOSE:
			CloseError(CBS_ERROR_CLOSE, _T("SSLAccept"));
			break;

		case COverlappedSocket::OVRS_ERROR_BREAK:
			CloseError(CBS_ERROR_BREAK, _T("SSLAccept"));
			break;

		case COverlappedSocket::OVRS_ERROR_OTHER:
			CloseError(CBS_ERROR_OTHER, _T("SSLAccept"));
			break;

		case COverlappedSocket::OVRS_ERROR_NO_ERROR:
		default:
			CloseError(CSPS_ERROR_SSL_OTHER, _T("SSLAccept"));
			break;
		}
		return(-1);
	}
	return(0);
}

//////////////////////////////////////////////////////////////////////////////////
//	Communication
//////////////////////////////////////////////////////////////////////////////////
/*!
	Send (override of CBufferdSocket).

	Writes the chunk returned by m_sendBuf.GetReadBuffer() through the SSL
	session with a single SSL_write() call.

	\param timeOut	Timeout passed to m_socket.SetTimeOut() before writing.
	\return	0 if there was nothing to send; otherwise m_sendBuf.GetInBuf()
			after the sent bytes have been consumed; -1 on error (reported
			through CloseError()).
*/
int CSSLProtocolSocket::BlockingSend(int timeOut)
{
	//	State check: without an SSL session the connection counts as closed
	if(!m_sslInit)
	{
		CloseError(CBS_ERROR_CLOSE, _T("BlockingSend"));
		return(-1);
	}

	//	No data to send?
	if(m_sendBuf.GetInBuf() == 0)
		return(0);

	//	Set the timeout
	m_socket.SetTimeOut(timeOut);

	//	Get the send buffer
	int		bufLen;
	char	*sendBuf = m_sendBuf.GetReadBuffer(&bufLen);

	//	Send
	int sended = SSL_write(m_ssl, sendBuf, bufLen);

	//	Error check
	int	sslErr = SSL_get_error(m_ssl, sended);
	if(sslErr == SSL_ERROR_ZERO_RETURN)
	{
		//	Disconnect detected at the SSL protocol level
		CloseError(CBS_ERROR_CLOSE, _T("BlockingSend"));
		return(-1);
	}
	else if(sslErr != SSL_ERROR_NONE)
	{
		//	Map the underlying overlapped-socket error to a socket error code
		switch(m_socket.GetLastError())
		{
		case COverlappedSocket::OVRS_ERROR_TIMEOUT:
			CloseError(CBS_ERROR_TIMEOUT, _T("BlockingSend"));
			break;

		case COverlappedSocket::OVRS_ERROR_CLOSE:
			CloseError(CBS_ERROR_CLOSE, _T("BlockingSend"));
			break;

		case COverlappedSocket::OVRS_ERROR_BREAK:
			CloseError(CBS_ERROR_BREAK, _T("BlockingSend"));
			break;

		case COverlappedSocket::OVRS_ERROR_OTHER:
			CloseError(CBS_ERROR_OTHER, _T("BlockingSend"));
			break;

		case COverlappedSocket::OVRS_ERROR_NO_ERROR:
		default:
			CloseError(CSPS_ERROR_SSL_OTHER, _T("BlockingSend"));
			break;
		}
		return(-1);
	}

	//	Record the amount of data sent
	m_sendBuf.SetReadedLen(sended);
	return(m_sendBuf.GetInBuf());
}


/*!
	M(CBufferdSocketI[o[Ch)
*/
int CSSLProtocolSocket::BlockingRecv(int timeOut)
{
	//	`FbN
	if(!m_sslInit)
	{
		CloseError(CBS_ERROR_CLOSE, "TryRecv");
		return(-1);
	}

	//	obt@邩H
	if(m_recvBuf.GetFreeBuf() == 0)
		return(m_recvBuf.GetInBuf());	//	f[^Mς݂ȂI

	//	^CAEgw
	m_socket.SetTimeOut(timeOut);

	//	Mpobt@擾
	int		bufLen;
	char	*recvBuf = m_recvBuf.GetWriteBuffer(&bufLen);

	//	M
	int recved = SSL_read(m_ssl, recvBuf, bufLen);

	//	G[`FbN
	int	sslErr = SSL_get_error(m_ssl, recved);
	if(sslErr == SSL_ERROR_ZERO_RETURN)
	{
		//	SSLvgRŐؒfm
		CloseError(CBS_ERROR_CLOSE, "BlockingSend");
		return(-1);
	}
	else if(sslErr != SSL_ERROR_NONE)
	{
		switch(m_socket.GetLastError())
		{
		case COverlappedSocket::OVRS_ERROR_TIMEOUT:
			CloseError(CBS_ERROR_TIMEOUT, "BlockingRecv");
			break;

		case COverlappedSocket::OVRS_ERROR_CLOSE:
			CloseError(CBS_ERROR_CLOSE, "BlockingSend");
			break;

		case COverlappedSocket::OVRS_ERROR_BREAK:
			CloseError(CBS_ERROR_BREAK, "BlockingRecv");
			break;

		case COverlappedSocket::OVRS_ERROR_OTHER:
			CloseError(CBS_ERROR_OTHER, "BlockingRecv");
			break;

		case COverlappedSocket::OVRS_ERROR_NO_ERROR:
		default:
			CloseError(CSPS_ERROR_SSL_OTHER, "BlockingRecv");
			break;
		}
		return(-1);
	}

	//	Mf[^ʂݒ
	m_recvBuf.SetWritedLen(recved);
	return(m_recvBuf.GetInBuf());
}


//////////////////////////////////////////////////////////////////////////////////
//	Errors
//////////////////////////////////////////////////////////////////////////////////
/*!
	Get the message string for an error code.

	Handles the three SSL-specific codes defined for this class
	(CSPS_ERROR_*); every other code is delegated to
	CProtocolSocket::GetErrorString().
*/
CString CSSLProtocolSocket::GetErrorString(int errorCode)
{
	//	SSL session could not be created
	if(errorCode == CSPS_ERROR_CREATE_SSL_CONNECTION)
		return(_T("SSLڑ̏Ɏs܂"));

	//	SSL handshake failed
	if(errorCode == CSPS_ERROR_SSL_ACCEPT)
		return(_T("SSLڑ̃lSVG[VɎs܂"));

	//	Any other SSL-layer failure
	if(errorCode == CSPS_ERROR_SSL_OTHER)
		return(_T("SSLʐMvgRɊ֘AG[܂"));

	//	Not an SSL-specific code: let the base class describe it
	return(CProtocolSocket::GetErrorString(errorCode));
}
