#include "HandWrittenDigits.h"
// HandWrittenDigits.cpp

#include "HandWrittenDigits.h"
#include "GrayCode.h"

// Default constructor.
// Sets the per-image pixel count, marks both data sets as "not open", clears
// the four FILE handles and stores the default data file names (relative paths,
// so they are resolved against the current working directory when opened).
HandWrittenDigits :: HandWrittenDigits()
{
	// Every digit image is read as 28 x 28 pixel values.
	m_nTotalPixels = 28 * 28;

	// Training set: nothing open yet, default file names.
	m_bTrainFiles = false;
	m_pTrainImages = nullptr;
	m_pTrainLabels = nullptr;
	m_sTrainImages = "train-images.idx3-ubyte";
	m_sTrainLabels = "train-labels.idx1-ubyte";

	// Test set: nothing open yet, default file names.
	m_bTestFiles = false;
	m_pTestImages = nullptr;
	m_pTestLabels = nullptr;
	m_sTestImages = "t10k-images.idx3-ubyte";
	m_sTestLabels = "t10k-labels.idx1-ubyte";
}

// Destructor: releases the training and test files through CloseTrainFiles()
// and CloseTestFiles().
// Note: GetTrainPixelsIndexClass() / GetTestPixelsIndexClass() already call the
// same Close...() functions, so these calls may run for files that were closed
// earlier in the object's lifetime.
HandWrittenDigits :: ~HandWrittenDigits()
{
	CloseTrainFiles();
	CloseTestFiles();
}

// Opens the training image file and the training label file for binary reading.
//
// m_bTrainFiles is set to true only when BOTH files opened, and to false
// otherwise. For every file that could not be opened an error message naming
// the file is written to cout. A file that did open stays open even when the
// other one failed; CloseTrainFiles() releases it.
void HandWrittenDigits::OpenTrainFiles()
{
	// Do not rely on fopen_s clearing the handle on failure: force it to null so
	// the checks below (and the readers) never see a stale pointer.
	if (fopen_s(&m_pTrainImages, m_sTrainImages.c_str(), "rb") != 0)
		m_pTrainImages = nullptr;
	if (fopen_s(&m_pTrainLabels, m_sTrainLabels.c_str(), "rb") != 0)
		m_pTrainLabels = nullptr;

	// Recompute the flag on every call. Leaving a 'true' from an earlier
	// successful open in place would let ReadTrainHeaders()/ReadTrainDigits()
	// call fgetc() on a null handle after a later failed open.
	m_bTrainFiles = (m_pTrainImages != nullptr) && (m_pTrainLabels != nullptr);

	if (m_pTrainImages == nullptr)
	{
		cout << "\n"
			<< "Error - can't open Train Images for reading: "
			<< "\n"
			<< m_sTrainImages
			<< "\n\n";
	}
	if (m_pTrainLabels == nullptr)
	{
		cout << "\n"
			<< "Error - can't open Train Labels for reading: "
			<< "\n"
			<< m_sTrainLabels
			<< "\n\n";
	}
}

// Advances both training files past their fixed-size headers so the next read
// returns the first label / first pixel. The header bytes are discarded, not
// interpreted. Does nothing when the training files are not open; stops early
// on a file that ends before its header is complete.
void HandWrittenDigits::ReadTrainHeaders()
{
	if (!m_bTrainFiles) return;

	const int nImageHeaderBytes = 16;	// from web site and file
	const int nLabelHeaderBytes = 8;	// from web site and file

	// Consume the image-file header one byte at a time.
	for (int nSkipped = 0; nSkipped < nImageHeaderBytes; ++nSkipped)
	{
		if (fgetc(m_pTrainImages) == EOF) break;
	}

	// Consume the label-file header one byte at a time.
	for (int nSkipped = 0; nSkipped < nLabelHeaderBytes; ++nSkipped)
	{
		if (fgetc(m_pTrainLabels) == EOF) break;
	}
}

// Reads up to nSize consecutive digits from the open training files into m.
//
// For each digit one label byte and m_nTotalPixels pixel bytes are read. The
// map entry is  pixels -> { running index (0-based), label }.
// Because the pixel vector is the map key, a digit whose pixels are identical
// to an earlier one is NOT stored (map::insert keeps the first entry).
//
// nSize : number of digits to read.
// m     : receives the digits; existing entries are kept.
// Returns true when all nSize digits were read. Returns false when the training
// files are not open, or when either file ends early; complete digits read
// before that point remain in m, and no partial or fabricated digit is added.
bool HandWrittenDigits::ReadTrainDigits(uint nSize, map<vector<uint>, vector<uint>>& m)
{
	if (!m_bTrainFiles) return false;

	for (uint nIndex = 0; nIndex < nSize; nIndex++)
	{
		// fgetc returns an int so that EOF can be told apart from a data byte.
		const int nLabel = fgetc(m_pTrainLabels);
		if (nLabel == EOF) return false;

		vector<uint> vImage(m_nTotalPixels, 0);
		for (uint nPixel = 0; nPixel < vImage.size(); nPixel++)
		{
			const int nValue = fgetc(m_pTrainImages);
			if (nValue == EOF) return false;	// truncated image: do not store it
			vImage[nPixel] = static_cast<uint>(nValue);
		}

		vector<uint> vIndexClass(2, 0);
		vIndexClass[0] = nIndex;
		vIndexClass[1] = static_cast<uint>(nLabel);

		m.insert(make_pair(vImage, vIndexClass));
	}
	return true;
}

void HandWrittenDigits::CloseTrainFiles()
{
	if (m_pTrainImages != NULL) fclose(m_pTrainImages);
	if (m_pTrainLabels != NULL) fclose(m_pTrainLabels);
}

// Convenience wrapper: opens the training files, skips their headers, reads up
// to nSize digits into m (pixels -> { index, label }) and closes the files.
// ReadTrainHeaders() and ReadTrainDigits() check m_bTrainFiles themselves; the
// bool returned by ReadTrainDigits() is ignored here, so the caller gets no
// failure indication other than what OpenTrainFiles() prints and what ends up
// in m.
void HandWrittenDigits::GetTrainPixelsIndexClass(uint nSize, map<vector<uint>, vector<uint>>& m)
{
	OpenTrainFiles();
	ReadTrainHeaders();
	ReadTrainDigits(nSize, m);
	CloseTrainFiles();
}

// Opens the test image file and the test label file for binary reading.
//
// m_bTestFiles is set to true only when BOTH files opened, and to false
// otherwise. For every file that could not be opened an error message naming
// the file is written to cout. A file that did open stays open even when the
// other one failed; CloseTestFiles() releases it.
void HandWrittenDigits::OpenTestFiles()
{
	// Do not rely on fopen_s clearing the handle on failure: force it to null so
	// the checks below (and the readers) never see a stale pointer.
	if (fopen_s(&m_pTestImages, m_sTestImages.c_str(), "rb") != 0)
		m_pTestImages = nullptr;
	if (fopen_s(&m_pTestLabels, m_sTestLabels.c_str(), "rb") != 0)
		m_pTestLabels = nullptr;

	// Recompute the flag on every call. Leaving a 'true' from an earlier
	// successful open in place would let ReadTestHeaders()/ReadTestDigits()
	// call fgetc() on a null handle after a later failed open.
	m_bTestFiles = (m_pTestImages != nullptr) && (m_pTestLabels != nullptr);

	if (m_pTestImages == nullptr)
	{
		cout << "\n"
			<< "Error - can't open Test Images for reading: "
			<< "\n"
			<< m_sTestImages
			<< "\n\n";
	}
	if (m_pTestLabels == nullptr)
	{
		cout << "\n"
			<< "Error - can't open Test Labels for reading: "
			<< "\n"
			<< m_sTestLabels
			<< "\n\n";
	}
}

// Advances both test files past their fixed-size headers so the next read
// returns the first label / first pixel. The header bytes are discarded, not
// interpreted. Does nothing when the test files are not open; stops early on a
// file that ends before its header is complete.
void HandWrittenDigits::ReadTestHeaders()
{
	if (!m_bTestFiles) return;

	const int nImageHeaderBytes = 16;	// from web site and file
	const int nLabelHeaderBytes = 8;	// from web site and file

	// Consume the image-file header one byte at a time.
	for (int nSkipped = 0; nSkipped < nImageHeaderBytes; ++nSkipped)
	{
		if (fgetc(m_pTestImages) == EOF) break;
	}

	// Consume the label-file header one byte at a time.
	for (int nSkipped = 0; nSkipped < nLabelHeaderBytes; ++nSkipped)
	{
		if (fgetc(m_pTestLabels) == EOF) break;
	}
}

// Reads up to nSize consecutive digits from the open test files into m.
//
// For each digit one label byte and m_nTotalPixels pixel bytes are read. The
// map entry is  pixels -> { running index (0-based), label }.
// Because the pixel vector is the map key, a digit whose pixels are identical
// to an earlier one is NOT stored (map::insert keeps the first entry).
//
// nSize : number of digits to read.
// m     : receives the digits; existing entries are kept.
// Returns true when all nSize digits were read. Returns false when the test
// files are not open, or when either file ends early; complete digits read
// before that point remain in m, and no partial or fabricated digit is added.
bool HandWrittenDigits::ReadTestDigits(uint nSize, map<vector<uint>, vector<uint>>& m)
{
	if (!m_bTestFiles) return false;

	for (uint nIndex = 0; nIndex < nSize; nIndex++)
	{
		// fgetc returns an int so that EOF can be told apart from a data byte.
		const int nLabel = fgetc(m_pTestLabels);
		if (nLabel == EOF) return false;

		vector<uint> vImage(m_nTotalPixels, 0);
		for (uint nPixel = 0; nPixel < vImage.size(); nPixel++)
		{
			const int nValue = fgetc(m_pTestImages);
			if (nValue == EOF) return false;	// truncated image: do not store it
			vImage[nPixel] = static_cast<uint>(nValue);
		}

		vector<uint> vIndexClass(2, 0);
		vIndexClass[0] = nIndex;
		vIndexClass[1] = static_cast<uint>(nLabel);

		m.insert(make_pair(vImage, vIndexClass));
	}
	return true;
}

void HandWrittenDigits::CloseTestFiles()
{
	if (m_pTestImages != NULL) fclose(m_pTestImages);
	if (m_pTestLabels != NULL) fclose(m_pTestLabels);
}

// Convenience wrapper: opens the test files, skips their headers, reads up to
// nSize digits into m (pixels -> { index, label }) and closes the files.
// ReadTestHeaders() and ReadTestDigits() check m_bTestFiles themselves; the
// bool returned by ReadTestDigits() is ignored here, so the caller gets no
// failure indication other than what OpenTestFiles() prints and what ends up
// in m.
void HandWrittenDigits::GetTestPixelsIndexClass(uint nSize, map<vector<uint>, vector<uint>>& m)
{
	OpenTestFiles();
	ReadTestHeaders();
	ReadTestDigits(nSize, m);
	CloseTestFiles();
}

// Adds nTrainingSize to element nIndexPosition of every mapped vector in m
// (bool-keyed overload), in place.
//
// nIndexPosition : slot of the mapped vector that holds the index.
// nTrainingSize  : amount added to that slot.
// m              : map whose values are updated.
// Entries whose vector has no element nIndexPosition are left untouched instead
// of being written out of bounds.
void HandWrittenDigits::AddTrainingSizeToTestIndex(uint nIndexPosition, uint nTrainingSize, map<vector<bool>, vector<uint>>& m)
{
	for (auto& entry : m)
	{
		vector<uint>& vIndexClass = entry.second;
		if (nIndexPosition < vIndexClass.size())
			vIndexClass[nIndexPosition] += nTrainingSize;
	}
}

// Adds nTrainingSize to element nIndexPosition of every mapped vector in m
// (uint-keyed overload), in place.
//
// nIndexPosition : slot of the mapped vector that holds the index.
// nTrainingSize  : amount added to that slot.
// m              : map whose values are updated.
// Entries whose vector has no element nIndexPosition are left untouched instead
// of being written out of bounds.
void HandWrittenDigits::AddTrainingSizeToTestIndex(uint nIndexPosition, uint nTrainingSize, map<vector<uint>, vector<uint>>& m)
{
	for (auto& entry : m)
	{
		vector<uint>& vIndexClass = entry.second;
		if (nIndexPosition < vIndexClass.size())
			vIndexClass[nIndexPosition] += nTrainingSize;
	}
}

// Thresholds a vector of values into bits.
//
// nMin : threshold.
// v    : input values.
// Returns a vector of the same length where bit n is true exactly when
// v[n] is strictly greater than nMin.
vector<bool> HandWrittenDigits::GrayInputs(uint nMin, vector<uint> v)
{
	vector<bool> vBits;
	vBits.reserve(v.size());
	for (const uint nValue : v)
	{
		vBits.push_back(nValue > nMin);
	}
	return vBits;
}

// Thresholds the key of every entry in m with GrayInputs(nMin, key) and returns
// the results as a new map; the mapped values are copied unchanged.
//
// nMin : threshold forwarded to the vector overload.
// m    : map from value vectors to their associated data.
// If two keys threshold to the same bit vector, only the first one (in the
// iteration order of m) is kept, because map::insert does not overwrite.
map<vector<bool>, vector<uint>> HandWrittenDigits::GrayInputs(uint nMin, map<vector<uint>, vector<uint>> m)
{
	// The GrayCode helper object the earlier version constructed here was never
	// used, so it has been dropped.
	map< vector<bool>, vector<uint> > mGray;
	for (const auto& entry : m)
	{
		mGray.insert(make_pair(GrayInputs(nMin, entry.first), entry.second));
	}
	return mGray;
}

// Expands every value of v into a fixed-width run of bits, concatenates the
// runs, reverses the whole bit string and passes it through
// GrayCode::GrayToBase2.
//
// nMaxBase10 : width in bits of each value's run.
// v          : input values.
// Each run consists of v[n] true bits followed by false bits up to nMaxBase10.
// Values larger than nMaxBase10 are clamped to nMaxBase10 (a run of all true
// bits) rather than written past the end of the run.
// Returns whatever GrayToBase2 produces for the reversed bit string
// (presumably the plain-binary form of a Gray-coded input - confirm in GrayCode).
vector<bool> HandWrittenDigits::LinearGrayInputs(uint nMaxBase10, vector<uint> v)
{
	GrayCode cGC = GrayCode();

	vector<bool> vGray;
	vGray.reserve(v.size() * nMaxBase10);
	for (uint n = 0; n < v.size(); n++)
	{
		// Clamp so the run never exceeds its nMaxBase10 bits.
		const uint nOnes = (v[n] < nMaxBase10) ? v[n] : nMaxBase10;
		vGray.insert(vGray.end(), nOnes, true);
		vGray.insert(vGray.end(), nMaxBase10 - nOnes, false);
	}
	reverse(vGray.begin(), vGray.end());

	return cGC.GrayToBase2(vGray);
}

// Applies LinearGrayInputs(nMaxBase10, key) to the key of every entry in m and
// returns the converted keys with their original mapped values.
// When two keys convert to the same bit vector, only the first one (in the
// iteration order of m) is kept.
map<vector<bool>, vector<uint>> HandWrittenDigits::LinearGrayInputs(uint nMaxBase10, map<vector<uint>, vector<uint>> m)
{
	map< vector<bool>, vector<uint> > mConverted;
	for (const auto& entry : m)
	{
		const vector<bool> vKey = LinearGrayInputs(nMaxBase10, entry.first);
		mConverted.insert(make_pair(vKey, entry.second));
	}
	return mConverted;
}

// Returns a copy of m in which the class slot of each 2-element value becomes
// a 0/1 flag: 1 when the value's index slot equals nIndex, 0 otherwise.
//
// nIndexPosition : slot of the mapped vector that holds the index.
// nClassPosition : slot of the mapped vector that receives the 0/1 flag.
// nIndex         : index to single out.
// m              : input map (not modified).
// Values that do not have exactly two elements, or for which either position
// lies outside the vector, are copied unchanged instead of being accessed out
// of bounds.
map<vector<bool>, vector<uint>> HandWrittenDigits::BinaryOutputIndex(uint nIndexPosition, uint nClassPosition, uint nIndex, map<vector<bool>, vector<uint>> m)
{
	map< vector<bool>, vector<uint> > mBinaryOutputIndex;
	for (const auto& entry : m)
	{
		vector<uint> vIndexClass = entry.second;
		const bool bPositionsValid =
			nIndexPosition < vIndexClass.size() && nClassPosition < vIndexClass.size();
		if (vIndexClass.size() == 2 && bPositionsValid)
		{
			vIndexClass[nClassPosition] = (vIndexClass[nIndexPosition] == nIndex) ? 1 : 0;
		}
		mBinaryOutputIndex.insert(make_pair(entry.first, vIndexClass));
	}
	return mBinaryOutputIndex;
}

// Returns a copy of m in which element 1 of every 2-element value is replaced
// by a 0/1 flag: 1 when it was equal to nClass, 0 otherwise.
// Values that do not have exactly two elements are copied unchanged.
map<vector<bool>, vector<uint>> HandWrittenDigits::BinaryOutput(uint nClass, map<vector<bool>, vector<uint>> m)
{
	map< vector<bool>, vector<uint> > mOneVsRest;
	for (const auto& entry : m)
	{
		vector<uint> vIndexClass = entry.second;
		if (vIndexClass.size() == 2)
		{
			vIndexClass[1] = (vIndexClass[1] == nClass) ? 1 : 0;
		}
		mOneVsRest.insert(make_pair(entry.first, vIndexClass));
	}
	return mOneVsRest;
}

// Returns a copy of m in which element nClassPosition of every value is
// replaced by a 0/1 flag: 1 when it was equal to nClass, 0 otherwise.
//
// nClassPosition : slot of the mapped vector that holds the class.
// nClass         : class to single out.
// m              : input map (not modified).
// Values too short to have element nClassPosition are copied unchanged.
//
// The flag is written to the same slot that is tested. The earlier version
// tested slot nClassPosition but always wrote slot 1, which overwrote the wrong
// element for nClassPosition != 1 and wrote out of bounds for one-element
// vectors. For nClassPosition == 1 the result is unchanged.
map<vector<bool>, vector<uint>> HandWrittenDigits::BinaryOutput(uint nClassPosition, uint nClass, map<vector<bool>, vector<uint>> m)
{
	map< vector<bool>, vector<uint> > mBinaryOutput;
	for (const auto& entry : m)
	{
		vector<uint> vIndexClass = entry.second;
		if (nClassPosition < vIndexClass.size())
		{
			vIndexClass[nClassPosition] = (vIndexClass[nClassPosition] == nClass) ? 1 : 0;
		}
		mBinaryOutput.insert(make_pair(entry.first, vIndexClass));
	}
	return mBinaryOutput;
}

// Filters m down to the entries of one class.
// An entry is kept only when its value has exactly two elements and element 1
// equals nClass; kept entries are copied unchanged.
map<vector<bool>, vector<uint>> HandWrittenDigits::FocusClass(uint nClass, map<vector<bool>, vector<uint>> m)
{
	map< vector<bool>, vector<uint> > mSelected;
	for (const auto& entry : m)
	{
		const vector<uint>& vIndexClass = entry.second;
		if (vIndexClass.size() != 2) continue;	// unexpected shape: drop
		if (vIndexClass[1] != nClass) continue;	// other class: drop
		mSelected.insert(entry);
	}
	return mSelected;
}

// Adds one to vBitCounts[n] for every position n at which vGray[n] is set.
//
// vGray      : bit vector to tally.
// vBitCounts : per-position counters, updated in place.
// Returns false, leaving vBitCounts untouched, when the two vectors differ in
// length; true otherwise.
bool HandWrittenDigits::IncrementBitCounts(vector<bool> vGray, vector<uint>& vBitCounts)
{
	const bool bSameLength = (vGray.size() == vBitCounts.size());
	if (bSameLength)
	{
		for (uint nBit = 0; nBit < vGray.size(); ++nBit)
		{
			vBitCounts[nBit] += vGray[nBit] ? 1 : 0;
		}
	}
	return bSameLength;
}

// Counts, for each of nBits bit positions, how many keys of mGray have that
// bit set. Only counting bits in gray format.
// Keys whose length is not nBits contribute nothing: IncrementBitCounts()
// rejects them and its result is not checked here.
vector<uint> HandWrittenDigits::GetBitCounts(uint nBits, map<vector<bool>, vector<uint> > mGray)
{
	vector<uint> vTally(nBits, 0);
	for (const auto& entry : mGray)
	{
		IncrementBitCounts(entry.first, vTally);
	}
	return vTally;
}

// Returns the bit positions 0..size-1 ordered by ascending count.
// Positions with the same count appear in ascending position order.
vector<uint> HandWrittenDigits::BitCountOrder(vector<uint> vBitCounts)
{
	// Group positions by their count; the map keeps the counts sorted.
	map<uint, vector<uint>> mPositionsByCount;
	for (uint nPos = 0; nPos < vBitCounts.size(); nPos++)
	{
		mPositionsByCount[vBitCounts[nPos]].push_back(nPos);
	}

	// Flatten the groups, smallest count first.
	vector<uint> vOrder;
	for (const auto& group : mPositionsByCount)
	{
		vOrder.insert(vOrder.end(), group.second.begin(), group.second.end());
	}
	return vOrder;
}

//vector<uint> HandWrittenDigits::BitCountOrder(uint nBits, map<vector<bool>, vector<uint>> mGray)
//{
//	vector<uint> vBitCounts = GetBitCounts(nBits, mGray);
//	return BitCountOrder(vBitCounts);
//}

// Same as BitCountOrderGray(), but for a map whose keys are in base-2 form:
// the map is first passed through GrayCode::Base2ToGray and the ordering is
// computed on the result.
vector<uint> HandWrittenDigits::BitCountOrderBase2(uint nBits, map<vector<bool>, vector<uint>> mBase2)
{
	return BitCountOrderGray(nBits, GrayCode().Base2ToGray(mBase2));
}

// Tallies the set bits per position over the keys of mGray (GetBitCounts) and
// returns the bit positions ordered by ascending tally (BitCountOrder).
vector<uint> HandWrittenDigits::BitCountOrderGray(uint nBits, map<vector<bool>, vector<uint>> mGray)
{
	return BitCountOrder(GetBitCounts(nBits, mGray));
}

// Scatters the bits of v according to vOrder: result[vOrder[n]] is set when
// v[n] is set. When vOrder is a permutation this undoes ReOrderGray() applied
// with the same vOrder.
//
// vOrder : target position for each source bit; must be as long as v.
// v      : bits to move.
// Returns a vector of v.size() bits. It is all false when the sizes differ.
// Order entries that fall outside the result are ignored rather than written
// out of bounds.
vector<bool> HandWrittenDigits::ReReOrder(vector<uint> vOrder, vector<bool> v)
{
	vector<bool> vReOrder(v.size(), false);
	if (vOrder.size() != v.size())
		return vReOrder;

	for (uint n = 0; n < vOrder.size(); n++)
	{
		const uint nPos = vOrder[n];
		if (v[n] && nPos < vReOrder.size())
		{
			vReOrder[nPos] = true;
		}
	}
	return vReOrder;
}

// Gathers the bits of vGray according to vOrder: result[n] = vGray[vOrder[n]].
// Only makes sense if the input is in GrayCode format, as this changes the
// base-2 order of the classifications.
//
// vOrder : source position for each result bit; must be as long as vGray.
// vGray  : bits to reorder.
// Returns a vector of vGray.size() bits. It is all false when the sizes differ.
// Order entries that point outside vGray yield a false bit rather than an
// out-of-bounds read.
vector<bool> HandWrittenDigits::ReOrderGray(vector<uint> vOrder, vector<bool> vGray)
{
	vector<bool> vReOrderGray(vGray.size(), false);
	if (vOrder.size() != vGray.size())
		return vReOrderGray;

	for (uint n = 0; n < vOrder.size(); n++)
	{
		const uint nPos = vOrder[n];
		if (nPos < vGray.size() && vGray[nPos])
		{
			vReOrderGray[n] = true;
		}
	}
	return vReOrderGray;
}

// Reorders the bits of a base-2 vector in its Gray-code representation:
// converts with GrayCode::Base2ToGray, applies ReOrderGray(vOrder, ...) and
// converts the result back with GrayCode::GrayToBase2.
vector<bool> HandWrittenDigits::ReOrderBase2(vector<uint> vOrder, vector<bool> vBase2)
{
	GrayCode cGC = GrayCode();

	// One working vector carried through the three stages.
	vector<bool> vBits = cGC.Base2ToGray(vBase2);	// base 2 -> Gray
	vBits = ReOrderGray(vOrder, vBits);				// permute in Gray form
	return cGC.GrayToBase2(vBits);					// Gray -> base 2
}

// Applies the vector form of ReOrderBase2(vOrder, key) to the key of every
// entry in mBase2 and returns the reordered keys with their original values.
// When two keys reorder to the same bit vector, only the first one (in the
// iteration order of mBase2) is kept.
map<vector<bool>, vector<uint>> HandWrittenDigits::ReOrderBase2(vector<uint> vOrder, map<vector<bool>, vector<uint>> mBase2)
{
	map<vector<bool>, vector<uint>> mResult;
	for (const auto& entry : mBase2)
	{
		const vector<bool> vKey = ReOrderBase2(vOrder, entry.first);
		mResult.insert(make_pair(vKey, entry.second));
	}
	return mResult;
}

// Applies the vector form of ReOrderGray(vOrder, key) to the key of every entry
// in mGray and returns the reordered keys with their original values.
// When two keys reorder to the same bit vector, only the first one (in the
// iteration order of mGray) is kept.
map<vector<bool>, vector<uint>> HandWrittenDigits::ReOrderGray(vector<uint> vOrder, map<vector<bool>, vector<uint>> mGray)
{
	map<vector<bool>, vector<uint>> mResult;
	for (const auto& entry : mGray)
	{
		const vector<bool> vKey = ReOrderGray(vOrder, entry.first);
		mResult.insert(make_pair(vKey, entry.second));
	}
	return mResult;
}

// Returns the position of the first set bit in v.
// Returns 0 when v is empty or has no set bit, so a result of 0 alone does not
// tell "bit 0 is set" apart from "no bit is set".
uint HandWrittenDigits::GetFirstOne(vector<bool> v)
{
	uint nPos = 0;
	while (nPos < v.size() && !v[nPos])
	{
		++nPos;
	}
	return (nPos < v.size()) ? nPos : 0;
}

// Strips a common prefix of leading bits from every key of m.
//
// nFirstLeftOne : out - number of bits removed from the front of each key.
//                 It is the position of the first set bit in the LAST key of
//                 m, or 0 when m is empty or that key has no set bit.
// m             : input map (not modified).
// Returns a new map with the shortened keys and the original values; if two
// shortened keys coincide only the first is kept.
//
// Why the last key: the map orders keys lexicographically, so among keys of
// equal length the greatest one has the fewest leading zeros, and every other
// key starts with at least that many zeros.
// A key shorter than nFirstLeftOne is emptied rather than erased past its end.
map<vector<bool>, vector<uint>> HandWrittenDigits::RemoveLeftZeros(uint& nFirstLeftOne, map<vector<bool>, vector<uint>> m)
{
	nFirstLeftOne = 0;
	if (!m.empty())
	{
		nFirstLeftOne = GetFirstOne(m.rbegin()->first);
	}

	map<vector<bool>, vector<uint>> mSansLeftZeros;
	for (const auto& entry : m)
	{
		vector<bool> v = entry.first;

		// Never erase beyond the end of a short key.
		const vector<bool>::size_type nErase =
			(nFirstLeftOne < v.size()) ? nFirstLeftOne : v.size();
		v.erase(v.begin(), v.begin() + static_cast<vector<bool>::difference_type>(nErase));

		mSansLeftZeros.insert(make_pair(v, entry.second));
	}

	return mSansLeftZeros;
}






