#include "stdafx.h"
#include "ControlFlowDiff.h"

namespace storm {

	// Decide whether two control-flow items are calls to "the same" function.
	// - Both items must be calls, and either both or neither must refer to a Function.
	// - Functions match when their names are equal, they have the same number of
	//   parameters, and every parameter of the new function 'mayStore' the corresponding
	//   parameter of the old function.
	// - Built-in functions match when 'which()' is the same for both.
	static Bool sameFunction(ControlFlowItem o, ControlFlowItem n) {
		// Only calls can be compared.
		if (!(o.isCall() && n.isCall()))
			return false;

		Bool oldIsFunction = o.hasFunction();
		if (oldIsFunction != n.hasFunction())
			return false;

		if (!oldIsFunction) {
			// Both are built-in functions.
			return o.builtIn()->which() == n.builtIn()->which();
		}

		Function *oldFn = o.function();
		Function *newFn = n.function();
		if (*oldFn->name != *newFn->name)
			return false;

		Nat paramCount = oldFn->params->count();
		if (paramCount != newFn->params->count())
			return false;

		for (Nat p = 0; p < paramCount; p++) {
			if (!newFn->params->at(p).mayStore(oldFn->params->at(p)))
				return false;
		}

		return true;
	}

	// Produce a copy of 'l' where every loop's body has been flattened (via 'flatten()').
	// Calls are copied as-is. Items that are neither calls nor loops are dropped.
	// The resulting array is allocated alongside 'l'.
	static Array<ControlFlowItem> *flatSubLists(Array<ControlFlowItem> *l) {
		Nat total = l->count();
		Array<ControlFlowItem> *flat = new (l) Array<ControlFlowItem>();
		flat->reserve(total);

		for (Nat idx = 0; idx < total; idx++) {
			ControlFlowItem current = l->at(idx);
			if (current.isCall())
				flat->push(current);
			else if (current.isLoop())
				flat->push(ControlFlowItem(current.offset(), current.endOffset(), current.flatten()));
		}

		return flat;
	}

	// Element to use in the DP table of the minimum edit distance algorithm:
	struct DpElem {
		enum {
			none,
			add,
			remove,
			keep
		};

		// Keep within 32 bits to keep it compact.
		Nat distance : 30;
		Nat mode : 2;

		// Default ctor:
		DpElem() : distance(0), mode(none) {}
		DpElem(Nat distance, Nat mode) : distance(distance), mode(mode) {}

		// Update if better distance:
		void update(Nat distance, Nat mode) {
			if (this->distance >= distance) {
				this->distance = distance;
				this->mode = mode;
			}
		}
	};

	// Check if 'in' contains 'find'.
	// Precondition: 'in' is sorted in increasing order. This holds for the equivalence
	// lists built in 'diff', since indices are always appended in increasing order there.
	static bool contains(const vector<Nat> &in, Nat find) {
		// This is queried once per cell of the DP table in 'minEdits'. A binary search
		// keeps each query logarithmic instead of linear in the list length.
		return std::binary_search(in.begin(), in.end(), find);
	}

	// Compute the minimum edit distance between two flattened loop bodies 'a' and 'b'.
	// Only deletions and insertions cost 1; two elements may be kept for free when
	// 'sameFunction' considers them equal. There are no substitutions.
	static Nat editDistance(Array<ControlFlowItem> *a, Array<ControlFlowItem> *b) {
		const Nat rows = a->count();
		const Nat cols = b->count();

		// cost[i][j] is the minimum edit distance when the first 'i' elements of 'a' have
		// been transformed into the first 'j' elements of 'b'.
		vector<vector<Nat>> cost(rows + 1, vector<Nat>(cols + 1, 0));

		for (Nat i = 0; i <= rows; i++) {
			for (Nat j = 0; j <= cols; j++) {
				if (i == 0) {
					// Only insertions are possible from an empty prefix of 'a'.
					cost[i][j] = j;
				} else if (j == 0) {
					// Only deletions are possible towards an empty prefix of 'b'.
					cost[i][j] = i;
				} else {
					Nat viaDelete = cost[i - 1][j] + 1;
					Nat viaInsert = cost[i][j - 1] + 1;
					Nat cheapest = min(viaDelete, viaInsert);
					if (sameFunction(a->at(i - 1), b->at(j - 1)))
						cheapest = min(cheapest, cost[i - 1][j - 1]);
					cost[i][j] = cheapest;
				}
			}
		}

		return cost[rows][cols];
	}

	// Compute a minimal sequence of edits (remove / add / keep) that transforms a list of
	// 'oldCount' elements into a list of 'newCount' elements. Old element 'i' may be kept as
	// new element 'j' only if 'j' is in 'equivalence[i]'. Same recurrence as 'editDistance',
	// but each cell remembers the edit that produced it so the edits can be recovered.
	// The edits are returned in order, from the start of the lists to the end.
	static vector<DpElem> minEdits(const vector<vector<Nat>> &equivalence, Nat oldCount, Nat newCount) {
		// table[o][n]: cheapest way to turn the first 'o' old elements into the first 'n' new ones.
		vector<vector<DpElem>> table(oldCount + 1, vector<DpElem>(newCount + 1));

		for (Nat o = 0; o <= oldCount; o++) {
			for (Nat n = 0; n <= newCount; n++) {
				if (o == 0 && n == 0)
					continue; // The origin keeps distance 0 and mode 'none'.

				if (n == 0) {
					table[o][n] = DpElem(o, DpElem::remove);
				} else if (o == 0) {
					table[o][n] = DpElem(n, DpElem::add);
				} else {
					// Candidates in increasing order of preference, since 'update' lets later
					// candidates win ties: remove, then add, then keep.
					DpElem cell(table[o - 1][n].distance + 1, DpElem::remove);
					cell.update(table[o][n - 1].distance + 1, DpElem::add);
					if (contains(equivalence[o - 1], n - 1))
						cell.update(table[o - 1][n - 1].distance, DpElem::keep);
					table[o][n] = cell;
				}
			}
		}

		// Walk back from the last cell to the origin, collecting the edits.
		vector<DpElem> path;
		Nat o = oldCount;
		Nat n = newCount;
		while (o > 0 || n > 0) {
			const DpElem step = table[o][n];
			path.push_back(step);

			if (step.mode == DpElem::remove) {
				o--;
			} else if (step.mode == DpElem::add) {
				n--;
			} else if (step.mode == DpElem::keep) {
				o--;
				n--;
			} else {
				// 'none' should only appear at the origin. Stop rather than loop forever.
				break;
			}
		}

		std::reverse(path.begin(), path.end());
		return path;
	}


	// Compute a diff between two control-flow lists (CFLs). Returns an array with the same
	// structure as 'oldFlow': one element per element in 'oldFlow', with loops diffed
	// recursively. Each element indicates what the corresponding old element maps to:
	// - a kept call maps to the matching call in 'newFlow',
	// - a kept loop maps to the matching loop in 'newFlow', with the bodies diffed recursively,
	// - a removed element is a copy of the previously emitted element (or a default-constructed
	//   ControlFlowItem if there is none) with status 'removed'.
	//
	// NOTE: the edit script is computed on the flattened lists, but 'oldPos'/'newPos' are used
	// to index 'oldFlow'/'newFlow'. This relies on both lists containing only calls and loops,
	// so that flattening preserves top-level indices. The assert at the end checks this.
	Array<ControlFlowItem> *diff(Array<ControlFlowItem> *oldFlow, Array<ControlFlowItem> *newFlow) {
		Engine &e = oldFlow->engine();

		// Versions of both lists where loop bodies are flattened.
		Array<ControlFlowItem> *oldFlat = flatSubLists(oldFlow);
		Array<ControlFlowItem> *newFlat = flatSubLists(newFlow);

		// Pre-compute an equivalence relation. equivalence[i] holds the indices in 'newFlat' that
		// element 'i' in 'oldFlat' may be kept as, in increasing order (as 'contains' expects).
		vector<vector<Nat>> equivalence;
		equivalence.reserve(oldFlat->count());
		for (Nat i = 0; i < oldFlat->count(); i++) {
			ControlFlowItem item = oldFlat->at(i);
			if (item.isCall()) {
				// All calls to a function with the same name and compatible parameters.
				vector<Nat> eq;
				for (Nat j = 0; j < newFlat->count(); j++) {
					if (sameFunction(item, newFlat->at(j)))
						eq.push_back(j);
				}
				equivalence.push_back(eq);
			} else if (item.isLoop()) {
				// All loops whose bodies have the minimum edit distance to this loop's body.
				vector<Nat> eq;
				Nat minEdit = nat(-1);
				for (Nat j = 0; j < newFlat->count(); j++) {
					if (!newFlat->at(j).isLoop())
						continue;

					Nat edit = editDistance(item.loop(), newFlat->at(j).loop());
					if (edit < minEdit) {
						minEdit = edit;
						eq.clear();
					}
					if (edit == minEdit) {
						eq.push_back(j);
					}
				}
				equivalence.push_back(eq);
			}
		}

		// Compute the minimum edit script.
		vector<DpElem> edits = minEdits(equivalence, oldFlat->count(), newFlat->count());

		// Translate the edits into the mapping.
		Array<ControlFlowItem> *result = new (e) Array<ControlFlowItem>();
		result->reserve(oldFlow->count());

		Nat oldPos = 0;
		Nat newPos = 0;
		for (size_t i = 0; i < edits.size(); i++) {
			Nat mode = edits[i].mode;
			if (mode == DpElem::remove) {
				// Removed: map it to the previously emitted item, if there is one. Build the
				// element in a local first instead of pushing a reference to an element of
				// 'result' into 'result' itself, which may reallocate during the push.
				ControlFlowItem mapped = result->any() ? result->last() : ControlFlowItem();
				mapped.status(ControlFlowItem::removed);
				result->push(mapped);

				oldPos++;
			} else if (mode == DpElem::add) {
				// Added: nothing in 'oldFlow' corresponds to it, so just skip it.
				newPos++;
			} else if (mode == DpElem::keep) {
				// Kept: map it to the corresponding new element.
				const ControlFlowItem &o = oldFlow->at(oldPos);
				const ControlFlowItem &n = newFlow->at(newPos);
				assert(o.isCall() == n.isCall());

				if (o.isCall()) {
					result->push(n);
				} else if (o.isLoop()) {
					// Diff the (unflattened) loop bodies recursively.
					result->push(ControlFlowItem(n.offset(), n.endOffset(), diff(o.loop(), n.loop())));
				}

				oldPos++;
				newPos++;
			}
		}

		// This should always be true, otherwise we have messed up some case above or in 'minEdits'.
		assert(result->count() == oldFlow->count());

		return result;
	}


	static void formatDiff(StrBuf *to, Array<ControlFlowItem> *oldFlow, Array<ControlFlowItem> *newFlow) {
		if (oldFlow->count() != newFlow->count()) {
			*to << S("<size mismatch: ") << oldFlow->count() << S(" vs ") << newFlow->count() << S(")");
			return;
		}

		*to << S("[");
		to->indent();

		for (Nat i = 0; i < oldFlow->count(); i++) {
			const ControlFlowItem &o = oldFlow->at(i);
			const ControlFlowItem &n = newFlow->at(i);

			*to << S("\n@") << o.offset() << S(",") << o.endOffset() << S("->")
				<< n.offset() << S(",") << n.endOffset() << S(" ");

			if (o.isLoop() && n.isLoop()) {
				formatDiff(to, o.loop(), n.loop());
			} else if (o.isCall() && n.isCall()) {
				if (o.hasFunction())
					*to << o.function()->identifier();
				else
					*to << o.builtIn()->title();
				*to << S(" -> ");
				if (n.hasFunction())
					*to << n.function()->identifier();
				else
					*to << n.builtIn()->title();
			} else if (o.isCall() && n.isStart()) {
				if (o.hasFunction())
					*to << o.function()->identifier();
				else
					*to << o.builtIn()->title();
				*to << S(" -> <start of function>");
			} else {
				*to << S("<non-displayable combination>");
			}

			ControlFlowItem::statusSuffix(to, n.status());
		}

		to->dedent();
		*to << S("\n]");
	}

	// Format 'oldFlow' and 'newFlow' side by side (see the StrBuf overload) and return the
	// text as a string. The buffer is allocated alongside 'oldFlow'.
	Str *formatDiff(Array<ControlFlowItem> *oldFlow, Array<ControlFlowItem> *newFlow) {
		StrBuf *buffer = new (oldFlow) StrBuf();
		formatDiff(buffer, oldFlow, newFlow);
		Str *text = buffer->toS();
		return text;
	}

}
