#!/usr/bin/env python
# Program version string; handed to optparse.OptionParser as its "version"
# argument in the command-line entry point at the bottom of this file.
VERSION = "1.4.3"

import nxs
import numpy
import math
import os
import sys

# Module-wide not-a-number constant (a plain Python float NaN).
NAN = float("nan")

class NXSdata:
    """Values read for one path through an already opened NeXus handle.

    Attributes filled in by the constructor:
      dims  -- first item returned by getinfo(); when it has exactly one
               entry it is replaced by that single entry
      type  -- second item returned by getinfo()
      attrs -- whatever getattrs() returns
      data  -- whatever getdata() returns
    """
    def __init__(self, nxs, path, **kwargs):
        # NOTE: inside this method "nxs" is the handle passed by the caller,
        # not the module imported at the top of the file.
        nxs.openpath(path)
        shape, self.type = nxs.getinfo()
        self.dims = shape[0] if len(shape) == 1 else shape
        self.attrs = nxs.getattrs()
        self.data = nxs.getdata()

def getDiffSym(diffType, format):
    """Return the symbol used to print a Diff summary code.

    For the "unified" format the left-only and right-only codes are shown
    as '-' and '+'; the other codes map to themselves, and a code that is
    not one of the Diff constants raises KeyError.  For any other format
    the code is returned unchanged.
    """
    if format != "unified":
        return diffType
    unifiedSymbols = {
        Diff.SAME: Diff.SAME,
        Diff.NEWLEFT: '-',
        Diff.NEWRIGHT: '+',
        Diff.DIFF: Diff.DIFF,
        Diff.UNKNOWN: Diff.UNKNOWN,
    }
    return unifiedSymbols[diffType]

def getPercentDiff(left, right, nandiff=float("nan")):
    """Return 100 * |left - right| / |left| element by element.

    Parameters:
      left, right -- numbers or array-likes of matching shape
      nandiff     -- unused; kept so existing callers keep working

    Returns a float64 numpy array for array input.  For scalar
    (zero-dimensional) input the single value is wrapped in a one-element
    list so that callers can always index or iterate the result.

    Entries where left is zero come out as inf or NaN, following numpy's
    floating point division rules.
    """
    # Work in float64: unsigned integer data would otherwise wrap around on
    # subtraction, and integer data would use floor division on Python 2.
    leftValues = numpy.asarray(left, dtype=numpy.float64)
    rightValues = numpy.asarray(right, dtype=numpy.float64)

    diffs = 100. * numpy.fabs(
        numpy.true_divide(leftValues - rightValues, leftValues))

    # only genuine scalars get wrapped; an empty array stays an array
    if numpy.ndim(diffs) == 0:
        return [diffs]
    return diffs

class Diff:
    """Result of comparing one path between a left and a right file.

    Attributes:
      path      -- path of the field in the left file
      rightpath -- path in the right file when it differs from path
                   (set for renamed/"retype" comparisons), otherwise None
      summary   -- one of the class constants below
      details   -- list of human readable mismatch descriptions

    Keyword arguments understood by the constructor:
      left, right -- objects with a getData(path) method returning an
                     object with type, dims, attrs and data attributes
      rightpath   -- see above
      epsilon     -- relative tolerance used when one side is float32 and
                     the other float64 (default 0.0000001)
      diff        -- "<"/"-" or ">"/"+" to mark a path that exists on one
                     side only; no data comparison is performed then
      format      -- "unified" or anything else for the standard format
    """
    SAME     = " "
    NEWLEFT  = "<"
    NEWRIGHT = ">"
    DIFF     = "|"
    UNKNOWN  = "?"

    def __init__(self, path, **kwargs):
        """Build the diff; raises ValueError for an unknown "diff" marker."""
        self.path = path
        left = kwargs.get("left")
        right = kwargs.get("right")
        self.rightpath = kwargs.get("rightpath")
        epsilon = kwargs.get("epsilon", 0.0000001)

        # a path of the form "<marker> <path>" carries a leading marker;
        # only the path part is kept, the marker comes from the "diff" kwarg
        pieces = self.path.split(" ")
        if len(pieces) == 2:
            self.path = pieces[1]

        if "diff" in kwargs:
            marker = kwargs["diff"]
            if marker == "<" or marker == "-":
                self.summary = Diff.NEWLEFT
            elif marker == ">" or marker == "+":
                self.summary = Diff.NEWRIGHT
            else:
                # string exceptions are not valid; raise a real one
                raise ValueError("Do not understand diff \"%s\"" % marker)
        else:
            self.summary = Diff.UNKNOWN

        self.setFormat(kwargs.get("format", "standard"))
        self.details = []

        # one-sided paths have nothing to compare
        if self.summary != Diff.UNKNOWN:
            return

        self.__cmpData(left, right, self.rightpath, epsilon)
        if len(self.details) > 0:
            self.summary = Diff.DIFF
        else:
            self.summary = Diff.SAME

    def __shortPath(self):
        """Return the path with the ":class" suffix of each element removed."""
        return '/'.join([item.split(":")[0] for item in self.path.split("/")])

    def __str__(self):
        result = "%s %s" % (getDiffSym(self.summary, self.__format),
                            self.__shortPath())
        if self.summary == Diff.DIFF:
            result += " " + " ".join(self.details)
        return result

    def __isRawFramesRetype(self, leftType, rightType):
        """True when a raw_frames field switched between int32 and uint64."""
        if self.path.find("raw_frames") == -1:
            return False
        return (leftType, rightType) in (("int32", "uint64"),
                                         ("uint64", "int32"))

    def __cmpMixedFloats(self, left, right, epsilon):
        """Compare float32 data against float64 data with a relative epsilon.

        Two values match when |l - r| <= max(|l|, |r|) * epsilon.  Every
        element that does not match (NaN included) adds one entry to
        self.details.  Works for data of any number of dimensions.
        """
        leftValues = numpy.asarray(left.data, dtype=numpy.float64)
        rightValues = numpy.asarray(right.data, dtype=numpy.float64)
        tolerance = numpy.maximum(numpy.fabs(leftValues),
                                  numpy.fabs(rightValues)) * epsilon
        # "not <=" instead of ">" so that NaN entries count as mismatches
        mismatch = numpy.logical_not(
            numpy.fabs(leftValues - rightValues) <= tolerance)

        if left.dims == 1:
            # single element field: report the data as a whole
            if mismatch.any():
                self.details.append("DATA MISMATCH: %s %s != %s %s" \
                    % (left.data, left.type, right.data, right.type))
            return

        for position in numpy.argwhere(mismatch):
            index = tuple(position)
            self.details.append("DATA MISMATCH: %s %s != %s %s" \
                % (left.data[index], left.type,
                   right.data[index], right.type))

    def __cmpData(self, left, right, rightpath, epsilon):
        """Compare type, attributes, dimensions and data of the two fields.

        Every mismatch found is appended to self.details.
        """
        left = left.getData(self.path)
        if rightpath is None:
            rightpath = self.path
        right = right.getData(rightpath)

        # type changes that are tolerated here and handled further down
        rawFramesRetype = self.__isRawFramesRetype(left.type, right.type)
        floatWidthChange = (left.type, right.type) in \
            (("float32", "float64"), ("float64", "float32"))

        if left.type != right.type \
               and not (rawFramesRetype or floatWidthChange):
            self.details.append("TYPE MISMATCH: %s != %s" \
                                % (left.type, right.type))
        if left.attrs != right.attrs:
            self.details.append("ATTRIBUTES MISMATCH: %s != %s" % \
                                (left.attrs, right.attrs))

        charTypes = ("CHAR", "char")
        if left.type in charTypes or right.type in charTypes:
            if left.data != right.data:
                self.details.append("DATA MISMATCH: %s != %s" % \
                                    (left.data, right.data))
            return

        if not numpy.all(left.dims == right.dims):
            self.details.append("DIMENSION MISMATCH: %s != %s" \
                                % (left.dims, right.dims))
            return

        # data without a usable dtype is compared by value only
        try:
            dtypesDiffer = left.data.dtype != right.data.dtype
        except (TypeError, AttributeError):
            dtypesDiffer = False

        if dtypesDiffer:
            if rawFramesRetype:
                # the change is reported together with both values
                self.details.append("IGNORED TYPE MISMATCH: %s != %s" \
                                    % (left.type, right.type))
                self.details.append("LEFT DATA: " + str(left.data))
                self.details.append("RIGHT DATA: " + str(right.data))
            elif floatWidthChange:
                self.__cmpMixedFloats(left, right, epsilon)
            else:
                self.details.append("DATA MISMATCH: %s != %s" \
                                    % (left.data.dtype.name, \
                                       right.data.dtype.name))
            return

        # special case when the array is only nans
        if numpy.all(numpy.isnan(left.data)) and \
           numpy.all(numpy.isnan(right.data)):
            return

        if numpy.all(left.data == right.data):
            return

        if left.data.size == 1 and right.data.size == 1:
            self.details.append("DATA MISMATCH %s != %s" \
                                % (str(left.data), str(right.data)))
            return

        diffs = getPercentDiff(left.data, right.data)
        if numpy.nanmax(diffs) <= 0.:
            return
        stats = getStats(diffs)
        self.details.append("MISMATCH [min%s,max%s,med%s,avg%s,dev%s]" \
                            % (stats[0], stats[1], stats[2], stats[3],
                               stats[4]))

    def setFormat(self, format):
        """Select the output format: "unified", anything else is standard."""
        if format == "unified":
            self.__format = "unified"
        else:
            self.__format = "standard"

class NXSfile:
    """One NeXus file together with the sorted list of its data paths.

    Keyword arguments understood by the constructor (all optional):
      ignorelinks     -- skip fields whose "target" attribute points
                         somewhere other than the field itself
      ignorenotes     -- skip "author" and "date" fields inside NXnote groups
      ignorets        -- skip fields named "SNStranslation_service"
      ignorerunnumber -- skip fields named "entry_identifier" or "run_number"
      retype          -- list of "leftname=rightname" renames applied when
                         pairing up paths that exist on one side only
      epsilon         -- tolerance handed to Diff (default 0.0000001)

    The file is opened in the constructor and the handle stays open for the
    lifetime of the object.
    """
    def __init__(self, name, **kwargs):
        self.__ignorelinks = kwargs.get("ignorelinks", False)
        self.__ignorenotes = kwargs.get("ignorenotes", False)
        self.__ignorets = kwargs.get("ignorets", False)
        self.__retype = kwargs.get("retype", [])
        self.__epsilon = kwargs.get("epsilon", 0.0000001)
        self.__ignorerunnumber = kwargs.get("ignorerunnumber", False)
        self.name = os.path.abspath(name)
        self.__initPaths()

    def __initPaths(self, **kwargs):
        """Open the file and collect its data paths, sorted, without '/'."""
        self.__paths = []
        self.__nxs = nxs.open(self.name)
        self.__paths = sorted([path for path in self.__getPaths()
                               if path != '/'])

    def __skipField(self, name, inNote):
        """Tell whether the ignore options exclude the field called name."""
        # each option is checked on its own so that enabling one of them
        # never disables another
        if self.__ignorenotes and inNote and name in ("author", "date"):
            return True
        if self.__ignorets and name == "SNStranslation_service":
            return True
        if self.__ignorerunnumber \
               and name in ("entry_identifier", "run_number"):
            return True
        return False

    def __getPaths(self, **kwargs):
        """Return the long paths of all fields below the currently open group.

        Groups are descended into recursively; the handle is left positioned
        where it was on entry.
        """
        result = []
        #result.append(self.__nxs.longpath) # add the groups to the tree
        parent = self.__nxs.longpath.split('/')[-1]
        inNote = parent.endswith("NXnote")
        for (name, nxclass) in self.__nxs.getentries().items():
            if nxclass != "SDS":
                # anything that is not a data field is treated as a group
                self.__nxs.opengroup(name, nxclass)
                result.extend(self.__getPaths())
                self.__nxs.closegroup()
                continue

            if self.__skipField(name, inNote):
                continue

            self.__nxs.opendata(name)
            attrs = self.__nxs.getattrs()
            longpath = self.__nxs.longpath
            shortpath = self.__nxs.path
            self.__nxs.closedata()

            # a link is a field whose target is not its own path
            if self.__ignorelinks and "target" in attrs \
                   and attrs["target"] != shortpath:
                continue
            result.append(longpath)
        return result

    def __str__(self):
        return self.name

    def __eq__(self, other):
        """This provides only the simplest of cases for comparing equality"""
        return self.name == other.name

    def __getMissing(self, left, right):
        """Return, as a list, the items of left that are not in right."""
        return [item for item in left if item not in right]

    def __pairRenamed(self, other, candidates, partners, renames,
                      candidatesAreLeft, diffPaths, result):
        """Pair one-sided paths that match after applying the renames.

        candidates and partners are the two "missing" lists; every path of
        candidates is rewritten with renames (a list of (old, new) pairs)
        and, when the rewritten path is found in partners, a Diff comparing
        the two is appended to result and both paths are removed from
        candidates, partners and diffPaths.
        """
        for path in candidates[:]:
            renamed = path
            for (old, new) in renames:
                if path.find(old) != -1:
                    renamed = renamed.replace(old, new)
            if renamed not in partners:
                continue
            if candidatesAreLeft:
                (leftpath, rightpath) = (path, renamed)
            else:
                (leftpath, rightpath) = (renamed, path)
            result.append(Diff(leftpath, left=self, right=other,
                               rightpath=rightpath, epsilon=self.__epsilon))
            candidates.remove(path)
            partners.remove(renamed)
            diffPaths.remove(path)
            diffPaths.remove(renamed)

    def cmpPaths(self, other, **kwargs):
        """Compare every path of this (left) file with other (right).

        Returns a list of Diff objects: first the pairs matched through the
        retype renames, then one entry per remaining path in sorted order.
        """
        result = []
        if self.__paths == other.__paths:
            diffPaths = self.__paths[:]
        else:
            # find what one has that the other is missing
            missingOther = self.__getMissing(self.__paths, other.__paths)
            missingSelf = self.__getMissing(other.__paths, self.__paths)

            # merge the results: sorted, every path once
            diffPaths = sorted(set(self.__paths) | set(other.__paths))

            # handle retype; items without "=" are ignored
            # TODO: add check for "NX*" if limit to type only
            renames = []
            for item in self.__retype:
                if item.find("=") != -1:
                    (lname, rname) = item.split("=")
                    renames.append((lname, rname))
            if renames:
                # right-only paths, rewritten right name -> left name
                self.__pairRenamed(other, missingSelf, missingOther,
                                   [(rname, lname)
                                    for (lname, rname) in renames],
                                   False, diffPaths, result)
                # left-only paths, rewritten left name -> right name
                self.__pairRenamed(other, missingOther, missingSelf,
                                   renames, True, diffPaths, result)

            # mark what is still one-sided
            for path in missingSelf:
                index = diffPaths.index(path)
                diffPaths[index] = "> %s" % diffPaths[index]
            for path in missingOther:
                index = diffPaths.index(path)
                diffPaths[index] = "< %s" % diffPaths[index]

        for path in diffPaths:
            if path.startswith("<") or path.startswith(">"):
                result.append(Diff(path, diff=path[0]))
            else:
                result.append(Diff(path, left=self, right=other,
                                   epsilon=self.__epsilon))

        return result

    def getData(self, path, **kwargs):
        """Read the field at path and return it as an NXSdata object."""
        return NXSdata(self.__nxs, path)

def isNaN(x):
    """Return True when the scalar x is not-a-number, otherwise False.

    NaN is the only value that compares unequal to itself, which makes the
    test independent of how the platform orders NaN against other numbers.
    """
    return bool(x != x)

def removeNaN(array):
    """Return array without its NaN entries.

    When array holds no NaN it is returned unchanged (the same object).
    Otherwise a new one-dimensional numpy array with the remaining values,
    in their original (row-major) order, is returned.
    """
    nanMask = numpy.isnan(array)
    if not numpy.any(nanMask):
        return array
    # boolean selection picks the right elements whatever the array's shape
    return numpy.asarray(array)[numpy.logical_not(nanMask)]

def getStats(array, **kwargs):
    """Summarize the values of array, ignoring NaN entries.

    Returns a list of seven items:
      [0]-[4] minimum, maximum, median, average and standard deviation,
              each formatted as "%.2f%%" (all "nan%" when no non-NaN value
              is present)
      [5]     total number of elements
      [6]     number of NaN elements that were ignored
    """
    values = numpy.copy(array).ravel()
    origLength = values.size
    values = values[numpy.logical_not(numpy.isnan(values))]
    length = values.size

    if length == 0:
        # nothing left to summarize (empty input or NaN only)
        nan = float("nan")
        summary = (nan, nan, nan, nan, nan)
    else:
        summary = (values.min(), values.max(), numpy.median(values),
                   numpy.average(values), values.std())

    result = ["%.2f%%" % value for value in summary]
    result.append(origLength) # add the number of elements
    result.append(origLength - length) # number of NaNs found
    return result

def cmpData(left, right, path, **kwargs):
    """Compare the field at path in two files and describe the difference.

    left and right must offer getData(path) returning an object with type,
    dims, attrs and data attributes.  Returns an empty string when the
    fields are equal, otherwise a description of the first mismatch found
    (type, dimensions, attributes, then data).
    """
    left = left.getData(path)
    right = right.getData(path)
    if left.type != right.type:
        return "TYPE MISMATCH: %s != %s" % (left.type, right.type)
    if left.dims != right.dims:
        return "DIMENSION MISMATCH: %s != %s" % (left.dims, right.dims)
    if left.attrs != right.attrs:
        return "ATTRIBUTES MISMATCH: %s != %s" % (left.attrs, right.attrs)
    if left.type in ("CHAR", "char"):
        if left.data == right.data:
            return ""
        return "DATA MISMATCH: %s != %s" % (left.data, right.data)

    leftData = numpy.asarray(left.data)
    rightData = numpy.asarray(right.data)
    if numpy.all(leftData == rightData):
        return ""

    # report percentages, which is what getStats labels its output with
    stats = getStats(getPercentDiff(leftData, rightData))
    return "MISMATCH [min%s,max%s,med%s,avg%s,dev%s]" \
           % (stats[0], stats[1], stats[2], stats[3], stats[4])

def delinearIndex(linear, dims):
    """Convert a linear (row-major) element number into an array index.

    Parameters:
      linear -- position of the element in the flattened data
      dims   -- a single number or a one-element sequence for
                one-dimensional data, otherwise the sequence of dimensions

    Returns linear itself for one-dimensional data and a tuple with one
    entry per dimension otherwise.  Raises ValueError for an empty dims.
    """
    try:
        length = len(dims)
    except TypeError:
        # NXSdata collapses one-dimensional dims to a plain number
        return linear
    if length == 0:
        raise ValueError("Do not know how to deal with dimension %d" % length)
    if length == 1:
        return linear

    # peel the index off from the fastest varying (last) dimension; what is
    # left over at the end is the index along the first dimension
    index = []
    remainder = linear
    for extent in reversed(list(dims)[1:]):
        (remainder, position) = divmod(remainder, extent)
        index.append(position)
    index.append(remainder)
    index.reverse()
    return tuple(index)

def _elementCount(dims):
    """Return the product of dims, which may also be a single number."""
    count = 1
    try:
        for dim in dims:
            count *= dim
    except TypeError:
        count *= dims
    return count

def _dataText(data, last):
    """Return data as text, limited to "last" items when data supports it."""
    try:
        return data.__str__(last=last)
    except TypeError:
        # a __str__ without the "last" keyword: show everything
        return str(data)

def printDataDiff(left, right, diff, **kwargs):
    """Print the values behind one Diff to standard output.

    Parameters:
      left, right -- objects offering getData(path)
      diff        -- the Diff to show; diff.rightpath, when set, names the
                     path to read from right
    Keyword arguments:
      format       -- "unified" or "standard" (default), selects the symbols
                      in front of left and right values
      numitems     -- number of values to show: 0 (default) means ten and a
                      negative number means all; only used without threshold
      threshold    -- when not None, list the elements whose percent
                      difference exceeds it instead of the raw values
      shownandiffs -- with threshold, also list the indices whose percent
                      difference is NaN
    """
    # determine the symbols for right and left
    format = kwargs.get("format", "standard")
    leftSym = getDiffSym(Diff.NEWLEFT, format)
    rightSym = getDiffSym(Diff.NEWRIGHT, format)

    # get the data
    left = left.getData(diff.path)
    # use diff.rightpath for --retype
    if diff.rightpath is not None:
        right = right.getData(diff.rightpath)
    else:
        right = right.getData(diff.path)

    # the number of elements on each side
    leftLength = _elementCount(left.dims)
    rightLength = _elementCount(right.dims)

    threshold = kwargs.get("threshold")
    shownandiffs = kwargs.get("shownandiffs", False)

    # how many items to show
    if not threshold:
        numItems = kwargs.get("numitems", 0)
        if numItems == 0:
            leftLength = 10
            rightLength = 10
        elif numItems > 0:
            leftLength = numItems
            rightLength = numItems

    # print out type and dims
    if (left.type == right.type) and (left.dims == right.dims):
        print("%s %s" % (left.type, left.dims))
    else:
        print("%s %s %s" % (leftSym, left.type, left.dims))
        print("%s %s %s" % (rightSym, right.type, right.dims))

    if left.attrs == right.attrs:
        print(str(left.attrs))
    else:
        print("%s %s" % (leftSym, left.attrs))
        print("%s %s" % (rightSym, right.attrs))

    if threshold is None:
        print("%s %s" % (leftSym, _dataText(left.data, leftLength)))
        print("%s %s" % (rightSym, _dataText(right.data, rightLength)))
        return

    # percent differences, element by element over the flattened data
    percent = numpy.ravel(getPercentDiff(left.data, right.data))
    leftFlat = numpy.ravel(left.data)
    rightFlat = numpy.ravel(right.data)
    try:
        dims = list(left.dims)
    except TypeError:
        dims = [left.dims]

    nanMask = numpy.isnan(percent)
    overMask = numpy.logical_and(numpy.logical_not(nanMask),
                                 percent > threshold)
    nanIndices = [delinearIndex(int(i), dims)
                  for i in numpy.nonzero(nanMask)[0]]
    changeInfo = [(delinearIndex(int(i), dims), percent[i],
                   leftFlat[i], rightFlat[i])
                  for i in numpy.nonzero(overMask)[0]]

    print("%d values changed between number and NaN" % len(nanIndices))
    print("%d values changed more than %.2f%%" % (len(changeInfo),
                                                threshold))
    if shownandiffs and len(nanIndices) > 0:
        print("Indices changed between number and NaN: %s" % (nanIndices,))
    if len(changeInfo) > 0:
        print("%10s  %5s  %10s  %10s" % ("index", "%diff", "left", "right"))
        for item in changeInfo:
            print("%10s  %5.2f  %10f  %10f" % item)

if __name__ == "__main__":
    # Command-line entry point: compare two files and print the differences.
    import optparse

    info = []
    info.append("This utility compares two files that are readable by the ")
    info.append("NeXus API.")

    info.append("In the difference the '<' or '-' symbol means that the ")
    info.append("field exists in the left file but not the right.")
    info.append("'>' or '+' symbol means that the field exists in the right ")
    info.append("file but not the left.")
    info.append("'|' symbol means that the field has changed between the ")
    info.append("two files.")

    parser = optparse.OptionParser("usage %prog [options] <left> <right>",
                                   None, optparse.Option, VERSION, 'error',
                                   " ".join(info))
    parser.add_option("-v", "--verbose", action="count", dest="verbose",
                      help="Enable verbose print statements", default=0)
    parser.add_option("-q", "--brief", action="store_true", dest="quiet",
                      help="Disable verbose print statements")
    parser.add_option("-u", "", action="store_true", dest="unifieddiff",
                      help="Use the unified output format.")
    parser.add_option("-s", "--report-identical-files", action="store_true",
                      dest="reportidenticalfiles",
                      help="Report when two files are the same.")
    parser.add_option("-S", "--suppress-common-lines", action="store_true",
                      dest="suppresscommon", help="Do not output common lines")
    parser.add_option("", "--show-values", action="store_true",
                      dest="showvalues",
                      help="Show the values of arrays that do not match")
    parser.add_option("", "--num-values", dest="numvalues",
                      help="Set the number of values to show in "
                      + "\"--show-values\" mode. If not specified the default "
                      + "is ten (10). To show all values specify minus one "
                      + "(-1).")
    parser.add_option("", "--show-percent", dest="threshold",
                      help="Set a threshold for the minimum percentage "
                      + "difference shown in values. This turns on "
                      + "\"--show-values\" mode and overrides "
                      + "\"--num-values\".")
    parser.add_option("", "--show-nan-diffs", dest="shownandiffs",
                      action="store_true",
                      help="Whether or not to show the indices that changed "
                      + "to nan when using \"--show-percent\" mode.")
    parser.add_option("-L", "--ignore-links", dest="ignorelinks",
                      action="store_true",
                      help="Only compare the original copy of links")
    parser.add_option("", "--ignore-note-meta", dest="ignorenotes",
                      action="store_true",
                      help="Ignore the author and date fields in notes")
    parser.add_option("", "--ignore-ts", dest="ignorets", action="store_true",
                      help="Ignore differences in the version of ts used")
    parser.add_option("", "--retype", dest="retype", action="append", help="Ignore "
                      + "differences after renaming. Please be consistent." +
                      "Usage: --retype leftname1=rightname1  --retype leftname2=rightname2 ...")
    parser.add_option("", "--compareFloat32Float64", dest="epsilon", action="store",
                      type="float", default=0.0000001, help = "Set the precision (float) for "
                      + "relative epsilon comparison. The default is 0.0000001.")
    parser.add_option("", "--ignore-run-number", dest="ignorerunnumber", action="store_true",
                                        help="Ignore differences of /entry/entry_identifier and "
                                        + "/entry/run_number.")
    parser.set_defaults(verbose=False,
                        ignorelinks=False,
                        numvalues=None,
                        threshold=None,
                        shownandiffs=False,
                        ignorenotes=False,
                        ignorets=False,
                        retype=[],
                        ignorerunnumber=False)

    # parse and fix values
    (options, args) = parser.parse_args()
    if options.quiet:
        options.verbose = -1
    if options.unifieddiff:
        format = "unified"
    else:
        format = "standard"
    # asking for a number of values or a threshold implies --show-values
    if options.numvalues is not None:
        options.showvalues = True
        options.numvalues = int(options.numvalues)
    else:
        options.numvalues = 0
    if options.threshold is not None:
        options.showvalues = True
        options.threshold = float(options.threshold)
    if len(args) != 2:
        parser.error("Must compare two files")

    fileOptions = dict(ignorelinks=options.ignorelinks,
                       ignorenotes=options.ignorenotes,
                       ignorets=options.ignorets,
                       retype=options.retype, epsilon=options.epsilon,
                       ignorerunnumber=options.ignorerunnumber)
    left = NXSfile(args[0], **fileOptions)
    right = NXSfile(args[1], **fileOptions)

    # direct comparison of the two files
    if left == right:
        if options.reportidenticalfiles:
            print("Files %s and %s are identical" % (left, right))
        sys.exit()

    # create the full diff
    diffs = left.cmpPaths(right)

    # determine if there were differences; every Diff has a one character
    # summary, so only entries other than Diff.SAME count
    different = any(diff.summary != Diff.SAME for diff in diffs)

    # take the easy way out if quiet or identical
    if not different:
        if options.reportidenticalfiles:
            print("Files %s and %s are identical" % (left, right))
        sys.exit()
    elif options.verbose == -1:
        print("Files %s and %s differ" % (left, right))
        sys.exit()

    # remove common lines if requested
    if options.suppresscommon:
        diffs = [diff for diff in diffs if diff.summary != Diff.SAME]

    # reformat the diff
    for diff in diffs:
        diff.setFormat(format)

    # print out the result
    for diff in diffs:
        print(str(diff))
        if options.showvalues and diff.summary == Diff.DIFF:
            printDataDiff(left, right, diff, format=format,
                          numitems=options.numvalues,
                          threshold=options.threshold,
                          shownandiffs=options.shownandiffs)
