#!/usr/bin/env python

import sys
from electrum import Interface
from electrum import bitcoin, Transaction, Network



def get_transaction(network, tx_hash, tx_height):
    """Fetch one transaction from the server and parse it.

    Sends a single 'blockchain.transaction.get' request carrying
    ``[tx_hash, tx_height]`` through ``network.synchronous_get`` and wraps
    the first reply in a ``Transaction``.
    """
    request = ('blockchain.transaction.get', [tx_hash, tx_height])
    # synchronous_get takes a list of requests and returns a list of replies;
    # only one request is sent, so only the first reply is relevant.
    replies = network.synchronous_get([request])
    return Transaction(replies[0])


def get_history(network, addr):
    """Return the transaction history of ``addr`` ordered by block height.

    Sends a single 'blockchain.address.get_history' request through
    ``network.synchronous_get`` and uses the first reply.

    Returns:
        A list of ``(tx_hash, height)`` tuples sorted by ascending height,
        or the reply itself, unchanged, when it is the ``['*']`` marker.
    """
    history = network.synchronous_get([(
        'blockchain.address.get_history', [addr])])[0]
    # The '*' marker is a plain string, not a dict, so sorting it by
    # entry["height"] would raise TypeError. Hand it back as-is so callers
    # can test for it.
    # NOTE(review): presumably the server's "history pruned" reply -- confirm.
    if history == ['*']:
        return history
    # sorted() rather than list.sort() so the server reply is not mutated.
    ordered = sorted(history, key=lambda entry: entry["height"])
    return [(entry["tx_hash"], entry["height"]) for entry in ordered]


def get_addr_balance(network, address):
    """Compute the confirmed and unconfirmed balance of ``address``.

    The history of the address is fetched with ``get_history`` and every
    listed transaction with ``get_transaction``. Each transaction's net
    effect on the address is the sum of its outputs paying ``address``
    minus the value of its inputs that spend coins previously received at
    ``address``.

    Returns:
        A ``(confirmed, unconfirmed)`` tuple in the same units as the
        transaction output values. Transactions with a positive height add
        to ``confirmed``; all others add to ``unconfirmed``.
    """
    history = get_history(network, address)
    # NOTE(review): ['*'] is presumably the server's marker for a history it
    # cannot list -- confirm. No balance can be computed from it.
    if history == ['*']:
        return 0, 0

    # Fetch every transaction once, keeping the history order.
    fetched = [
        (tx_hash, tx_height, get_transaction(network, tx_hash, tx_height))
        for tx_hash, tx_height in history
    ]

    prevout_values = {}       # "txid:n" -> output value
    received_coins = set()    # "txid:n" of every output paying `address`

    # First pass: index all outputs. This must finish before any spend is
    # evaluated, because history is sorted by height and unconfirmed entries
    # (height 0) come before the confirmed transactions they may spend from.
    for tx_hash, tx_height, tx in fetched:
        if not tx:
            continue
        update_tx_outputs(tx, prevout_values)
        for i, (addr, value) in enumerate(tx.outputs):
            if addr == address:
                received_coins.add(tx_hash + ':%d' % i)

    # Second pass: net effect of each transaction on the address.
    confirmed = unconfirmed = 0
    for tx_hash, tx_height, tx in fetched:
        if not tx:
            continue
        delta = 0
        for item in tx.inputs:
            if item.get('address') != address:
                continue
            key = item['prevout_hash'] + ':%d' % item['prevout_n']
            # Only subtract coins that were seen being received above.
            if key in received_coins:
                delta -= prevout_values.get(key)
        delta += sum(value for addr, value in tx.outputs if addr == address)

        # Only a positive height means the transaction is in a block. A
        # plain truthiness test would count a negative mempool height as
        # confirmed.
        if tx_height > 0:
            confirmed += delta
        else:
            unconfirmed += delta
    return confirmed, unconfirmed


def update_tx_outputs(tx, prevout_values):
    """Record the value of every output of ``tx`` in ``prevout_values``.

    ``prevout_values`` is updated in place: for output number ``i`` the key
    is ``tx.hash() + ':<i>'`` and the value is the output's value. Nothing
    is returned.
    """
    # The hash is the same for every output, so compute it once rather than
    # on each iteration.
    prefix = tx.hash()
    for i, (addr, value) in enumerate(tx.outputs):
        prevout_values[prefix + ':%d' % i] = value


def main(address):
    """Print the confirmed and unconfirmed balance of ``address``.

    Creates a ``Network``, starts it with ``wait=True``, computes the
    balance with ``get_addr_balance`` and prints each amount both as an
    integer and divided by 1e8 as BTC.
    """
    units_per_btc = 100000000.

    net = Network()
    net.start(wait=True)
    confirmed, unconfirmed = get_addr_balance(net, address)

    report = "Balance - confirmed: %d (%.8f BTC), unconfirmed: %d (%.8f BTC)" % (
        confirmed, confirmed / units_per_btc,
        unconfirmed, unconfirmed / units_per_btc)
    print(report)

if __name__ == "__main__":
    # The script takes one positional argument: the address to query.
    if len(sys.argv) < 2:
        # Call form with a single argument prints the same text on Python 2
        # and Python 3, and matches the print style used in main().
        print("usage: get_balance <bitcoin_address>")
        sys.exit(1)
    address = sys.argv[1]
    main(address)
