import sys
import multiprocessing
import pybedtools

# Paths to the example annotation (GFF) and alignment (BAM) files looked up
# through pybedtools' example-file helper.
gff = pybedtools.example_filename('gdc.gff')
bam = pybedtools.example_filename('gdc.bam')

# Module-level annotation set shared by the worker functions below: load the
# GFF, drop invalid features, then save the cleaned result so that `g` refers
# to a saved copy.
g = pybedtools.BedTool(gff)
g = g.remove_invalid()
g = g.saveas()


def featuretype_filter(feature, featuretype):
    """
    Return True if `feature` is of type `featuretype`.

    The type is read from field index 2 of `feature` (the third column,
    which holds the feature type in GFF-style records) and compared for
    equality with `featuretype`.
    """
    return feature[2] == featuretype


def subset_featuretypes(featuretype):
    """
    Return a BedTool holding only the features of `g` of type `featuretype`.

    The filtered features are saved first and the returned BedTool is built
    from the saved file's name (presumably so that the value handed back
    from a pool worker refers to a file on disk -- confirm against the
    pybedtools documentation).
    """
    matching = g.filter(featuretype_filter, featuretype)
    saved = matching.saveas()
    return pybedtools.BedTool(saved.fn)


def count_reads_in_features(features_fn):
    """
    Pool callback: count the records obtained by intersecting the
    module-level BAM file with the features stored in `features_fn`.

    `features_fn` is passed as the `b` argument of the intersection, and the
    intersection is streamed (`stream=True`) before being counted.
    """
    reads = pybedtools.BedTool(bam)
    overlapping = reads.intersect(b=features_fn, stream=True)
    return overlapping.count()


# Only run the driver code when executed as a script.  Worker processes
# started with the "spawn" method re-import this module; without the guard
# each of them would try to create its own pool while still bootstrapping.
if __name__ == '__main__':
    pool = multiprocessing.Pool()
    try:
        # Extract introns and exons from the annotations in parallel.
        featuretypes = ('intron', 'exon')
        introns, exons = pool.map(subset_featuretypes, featuretypes)

        # Build the three region sets.  Only the saved filenames are kept,
        # since that is what gets handed to the workers below.
        exon_only = (
            exons.subtract(introns).merge().remove_invalid().saveas().fn
        )
        intron_only = (
            introns.subtract(exons).merge().remove_invalid().saveas().fn
        )
        intron_and_exon = (
            exons.intersect(introns).merge().remove_invalid().saveas().fn
        )

        # Count reads in each region set in parallel.
        features = (exon_only, intron_only, intron_and_exon)
        results = pool.map(count_reads_in_features, features)
    finally:
        # Always shut the pool down so worker processes are not left behind.
        pool.close()
        pool.join()

    # Labels are right-aligned so the counts line up in the output.
    labels = ('      exon only:',
              '    intron only:',
              'intron and exon:')

    for label, reads in zip(labels, results):
        sys.stdout.write('%s %s\n' % (label, reads))
