#  fsh - fast remote execution
#  Copyright (C) 1999-2001 by Per Cederqvist.
#
#  This program is free software; you can redistribute it and/or modify
#  it under the terms of the GNU General Public License as published by
#  the Free Software Foundation; either version 2 of the License, or
#  (at your option) any later version.
#
#  This program is distributed in the hope that it will be useful,
#  but WITHOUT ANY WARRANTY; without even the implied warranty of
#  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#  GNU General Public License for more details.
#
#  You should have received a copy of the GNU General Public License
#  along with this program; if not, write to the Free Software
#  Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.

import errno
import getopt
import getpass
import os
import select
import socket
import string
import sys
import time

import fshlib

class remote:
    """Client connection to the local fshd tunnel for one remote command.

    Outgoing protocol messages are queued with send() and written when
    select() reports the socket writable; incoming bytes accumulate in
    receive_queue and are parsed one message at a time by
    poll_response().  A sock attribute of -1 means the connection is
    closed.
    """

    def __init__(self, method, login, use_l_flag, server, command,
                 fshd_timeout):
        """Connect to the fshd serving SERVER (starting one if needed)
        and queue a request to run COMMAND.

        method -- remote-shell program for fshd to use (e.g. "ssh").
        login -- remote user name; passed to a newly started fshd as
                 "-l" only when use_l_flag is true.
        fshd_timeout -- value for fshd's --timeout option, or None to
                        not pass the option at all.

        Exits the process with status 1 if fshd fails to start or the
        tunnel cannot be reached within about a minute.
        """
        (sockname, unused_dir) = fshlib.fshd_socket(server, method, login)
        self.s = self._connect(sockname)
        if self.s is None:
            # Nobody is listening: start an fshd in daemon mode.
            self._start_fshd(method, login, use_l_flag, server,
                             fshd_timeout)
            # Repeatedly try to connect to it.  Give up after a minute.
            for _ in range(60):
                time.sleep(1)
                self.s = self._connect(sockname)
                if self.s is not None:
                    break
            else:
                sys.stderr.write("fsh: Failed to connect to tunnel.\n")
                sys.exit(1)

        self.s.setblocking(0)
        self.sock = self.s.fileno()
        self.send_queue = []
        self.receive_queue = []
        self.eos_seen = 0
        self.send("new %s\n" % fshlib.hollerith(command))

    def _connect(self, sockname):
        """Return a new Unix-domain stream socket connected to SOCKNAME,
        or None if nothing is listening there (ECONNREFUSED or ENOENT).

        A fresh socket is created for every attempt because POSIX leaves
        the state of a socket unspecified after a failed connect().
        Any other socket.error is propagated.
        """
        s = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
        try:
            s.connect(sockname)
        except socket.error:
            # sys.exc_info() works on every Python version; the
            # exception's args may be a single message string rather
            # than an (errno, message) pair, so don't tuple-unpack it.
            err = sys.exc_info()[1]
            s.close()
            if not err.args or err.args[0] not in (errno.ECONNREFUSED,
                                                   errno.ENOENT):
                raise
            return None
        return s

    def _start_fshd(self, method, login, use_l_flag, server, fshd_timeout):
        """Run "fshd -b" for SERVER via the shell; exit(1) if it fails.

        Every user-controlled value is shell-quoted so that unusual
        characters in a login, host name or timeout cannot inject shell
        syntax into the command line.
        """
        cmd = "fshd -b -r " + fshlib.shell_quote(method)
        if use_l_flag:
            cmd = cmd + " -l " + fshlib.shell_quote(login)
        if fshd_timeout is not None:
            cmd = cmd + " --timeout=" + fshlib.shell_quote(fshd_timeout)
        cmd = cmd + " " + fshlib.shell_quote(server)
        ret = os.system(cmd)
        if not (os.WIFEXITED(ret) and os.WEXITSTATUS(ret) == 0):
            sys.stderr.write("fsh: Failed to start fshd\n")
            sys.exit(1)

    def select_set(self):
        """Return (read, write, except) fd lists for select().

        The socket is always polled for reading while open, and for
        writing only while there is queued outgoing data.
        """
        if self.sock == -1:
            return [], [], []
        w = []
        if self.send_queue != []:
            w.append(self.sock)
        return [self.sock], w, []

    def select_action(self, r, w, e):
        """Read from and/or write to the tunnel socket as select() allows.

        A read or write failure (fshlib.read/write returning -1) closes
        the connection.  e is unused.
        """
        if self.sock in r:
            if fshlib.read(self.sock, self.receive_queue, 4096) == -1:
                self.close()
        if self.sock in w:
            if fshlib.write(self.sock, self.send_queue) == -1:
                self.close()

    def send(self, data):
        """Queue the protocol message DATA for the tunnel.

        Once the connection is closed, data is dropped and any unsent
        data still queued is discarded.
        """
        if self.sock == -1:
            self.send_queue = []
        else:
            self.send_queue.append(data)

    def poll_response(self):
        """Parse and return the next complete [cmd, data] message.

        cmd is None when no complete message is buffered.  After the
        connection has closed, a trailing partial message can never be
        completed and is discarded.  Receiving "eos" sets eos_seen.
        """
        [cmd, data] = fshlib.parse_line(self.receive_queue, 0)

        # If half a command got through when we lost the tunnel, remove it.
        if cmd is None and self.sock == -1:
            self.receive_queue = []

        if cmd == "eos":
            self.eos_seen = 1

        return [cmd, data]

    def close(self):
        """Close the tunnel socket and mark the connection closed."""
        self.s.close()
        self.sock = -1

    def drained(self):
        """True when the connection is closed and no input remains."""
        return self.receive_queue == [] and self.sock == -1

class local_side:
    """The local end of a session: this process's stdin, stdout and stderr.

    Data read from stdin is forwarded to the remote object as "stdin"
    messages; data received for stdout/stderr is queued here and written
    when select() reports the descriptor writable.  Flow control:

    * stdin is read only while stdin_counter < stdin_quota; the quota is
      set by stdin_flow().
    * each successful stdout/stderr write is added to that stream's
      counter; when less than QUOTA/2 of its quota remains, the quota is
      raised by QUOTA and advertised with a "stdout-flow"/"stderr-flow"
      message.

    A descriptor attribute equal to -1 means the stream is closed.
    """

    def __init__(self):
        """Initialise flow-control state for fds 0, 1 and 2 and pass each
        to fshlib.set_nonblocking()."""
        self.stdin_quota = fshlib.QUOTA
        self.stdout_quota = fshlib.QUOTA
        self.stderr_quota = fshlib.QUOTA
        self.stdin_counter = 0
        self.stdout_counter = 0
        self.stderr_counter = 0
        self.stdout_queue = []
        self.stderr_queue = []
        self.stdin_fd = 0
        self.stdout_fd = 1
        self.stderr_fd = 2
        fshlib.set_nonblocking(self.stdin_fd)
        fshlib.set_nonblocking(self.stdout_fd)
        fshlib.set_nonblocking(self.stderr_fd)
        # Set when fshlib.write() reported failure (-1) on stdout/stderr.
        self.epipe_seen = 0
        # Set once EOF was requested for the stream; the fd is closed as
        # soon as its output queue has been flushed.
        self.stdout_pending_close = 0
        self.stderr_pending_close = 0

    def select_set(self):
        """Return (read, write, except) fd lists for select().

        stdin is polled only while stdin credit remains; stdout and
        stderr only while they have queued output.
        """
        r = []
        w = []
        if self.stdin_fd != -1 and self.stdin_counter < self.stdin_quota:
            r.append(self.stdin_fd)
        if self.stdout_fd != -1 and self.stdout_queue != []:
            w.append(self.stdout_fd)
        if self.stderr_fd != -1 and self.stderr_queue != []:
            w.append(self.stderr_fd)
        return r, w, []

    def select_action(self, r, w, e, srv):
        """Service whichever of our descriptors select() reported ready.

        r, w, e -- the lists returned by select.select(); e is unused.
        srv -- the remote object; protocol messages are queued on it with
               srv.send().
        """
        if self.stdin_fd in r:
            self._read_stdin(srv)
        if self.stdout_fd in w:
            self._flush_output("stdout", srv)
        if self.stderr_fd in w:
            self._flush_output("stderr", srv)

    def _read_stdin(self, srv):
        """Read up to the remaining stdin credit (at most 4096 bytes) and
        forward it to SRV.  If fshlib.read() returns -1, close stdin and
        send "eof-stdin" instead."""
        wanted = min(4096, self.stdin_quota - self.stdin_counter)
        assert wanted > 0
        chunks = []
        if fshlib.read(self.stdin_fd, chunks, wanted) == -1:
            os.close(self.stdin_fd)
            self.stdin_fd = -1
            srv.send("eof-stdin\n")
        elif len(chunks) > 0:
            assert len(chunks) == 1
            srv.send("stdin %s\n" % fshlib.hollerith(chunks[0]))
            self.stdin_counter = self.stdin_counter + len(chunks[0])

    def _flush_output(self, name, srv):
        """Write queued data for the output stream NAME ("stdout" or
        "stderr").

        stdout and stderr share this one implementation so their
        accounting cannot drift apart.  On write failure the stream is
        closed, "eof-NAME" is sent to SRV and epipe_seen is set; nothing
        else is done for that stream.  Otherwise the bytes written are
        counted once, the fd is closed if a close is pending and the
        queue is now empty, and the flow-control quota is topped up when
        it runs low.
        """
        fd_attr = name + "_fd"
        fd = getattr(self, fd_attr)
        # fshlib.write consumes the queue list in place.
        queue = getattr(self, name + "_queue")
        sz = fshlib.write(fd, queue)
        if sz == -1:
            srv.send("eof-%s\n" % name)
            os.close(fd)
            setattr(self, fd_attr, -1)
            self.epipe_seen = 1
            # The fd is already closed: skip the pending-close and quota
            # steps (which would otherwise os.close(-1)).
            return

        counter = getattr(self, name + "_counter") + sz
        setattr(self, name + "_counter", counter)

        if getattr(self, name + "_pending_close") and queue == []:
            os.close(fd)
            setattr(self, fd_attr, -1)
            return

        quota = getattr(self, name + "_quota")
        if quota - counter < fshlib.QUOTA // 2:
            quota = quota + fshlib.QUOTA
            setattr(self, name + "_quota", quota)
            srv.send("%s-flow %s\n" % (name, fshlib.hollerith(quota)))

    def stdout_send(self, data):
        """Queue DATA for stdout; empty strings are ignored."""
        if data != "":
            self.stdout_queue.append(data)

    def stderr_send(self, data):
        """Queue DATA for stderr; empty strings are ignored."""
        if data != "":
            self.stderr_queue.append(data)

    def eof_stdin(self):
        """Close stdin if it is still open."""
        if self.stdin_fd != -1:
            os.close(self.stdin_fd)
            self.stdin_fd = -1

    def eof_stdout(self):
        """Request closing stdout: immediately if nothing is queued,
        otherwise once the queue has been flushed."""
        self.stdout_pending_close = 1
        if self.stdout_queue == [] and self.stdout_fd != -1:
            os.close(self.stdout_fd)
            self.stdout_fd = -1

    def eof_stderr(self):
        """Request closing stderr: immediately if nothing is queued,
        otherwise once the queue has been flushed."""
        self.stderr_pending_close = 1
        if self.stderr_queue == [] and self.stderr_fd != -1:
            os.close(self.stderr_fd)
            self.stderr_fd = -1

    def stdin_flow(self, data):
        """Set the stdin quota (total bytes we may read) to DATA."""
        self.stdin_quota = data

    def drained(self):
        """True once both stdout and stderr are closed (stdin is not
        considered)."""
        return self.stdout_fd == -1 and self.stderr_fd == -1

def docmd(method, login, use_l_flag, server, command, fshd_timeout):
    """Run COMMAND on SERVER through an fshd tunnel and exit with its status.

    method, login, use_l_flag, fshd_timeout -- passed to remote() to
        locate or start the fshd tunnel.
    server -- remote host name.
    command -- the command line to run remotely.

    Shuttles data between the local stdin/stdout/stderr and the tunnel
    until both sides are drained, then calls sys.exit() with the remote
    exit status.  The status is 1 instead if the tunnel was lost, or if
    a local write failed while the remote command reported success.

    Raises RuntimeError on an unrecognised message from the tunnel, or
    if everything drained without any exit status being received.
    """
    rem = remote(method, login, use_l_flag, server, command, fshd_timeout)
    loc = local_side()
    exitval = None
    tunnel_lost = 0
    while 1:
        # Act on every complete message currently buffered from the tunnel.
        while 1:
            cmd, data = rem.poll_response()
            if cmd is None:
                break
            if cmd == "stdout":
                loc.stdout_send(data)
            elif cmd == "stderr":
                loc.stderr_send(data)
            elif cmd == "eof-stdin":
                loc.eof_stdin()
            elif cmd == "eof-stdout":
                loc.eof_stdout()
            elif cmd == "eof-stderr":
                loc.eof_stderr()
            elif cmd == "exit" or cmd == "signal-exit":
                # NOTE(review): "signal-exit" data is used verbatim as the
                # exit status, exactly like "exit"; the original author
                # left open what should really be done here.
                exitval = int(data)
            elif cmd == "eos":
                # End of session: close every local stream (queued output
                # is still flushed before stdout/stderr close) and the tunnel.
                loc.eof_stdin()
                loc.eof_stdout()
                loc.eof_stderr()
                rem.close()
            elif cmd == "stdin-flow":
                loc.stdin_flow(int(data))
            else:
                # String exceptions are invalid since Python 2.6.
                raise RuntimeError("syntax error from server:" + str(cmd)
                                   + ":" + str(data))

        if loc.drained() and rem.drained():
            if exitval is None:
                raise RuntimeError("don't know what to exit with")
            if loc.epipe_seen and exitval == 0:
                # Some output could not be written; don't report success.
                sys.exit(1)
            sys.exit(exitval)

        if rem.drained() and not rem.eos_seen and not tunnel_lost:
            # We lost contact in mid-session.
            # Schedule a shutdown of the client, but allow any pending
            # data to be transmitted.
            loc.eof_stdin()
            loc.eof_stdout()
            loc.stderr_send("Connection to fsh tunnel lost.\n")
            loc.eof_stderr()
            exitval = 1
            tunnel_lost = 1

        r0, w0, e0 = rem.select_set()
        r1, w1, e1 = loc.select_set()
        r, w, e = select.select(r0 + r1, w0 + w1, e0 + e1, 30)
        rem.select_action(r, w, e)
        loc.select_action(r, w, e, rem)
        
def usage(ret):
    """Write the usage synopsis to stderr and exit with status RET.

    The option summary is included only when RET is 0, i.e. when help
    was explicitly requested rather than shown because of an error.
    """
    synopsis = (
        "fsh: usage: fsh [options] host command [args...]\n"
        "            fsh { -h | --help }\n"
        "            fsh { -V | --version }\n")
    option_help = (
            "Options:\n"
            "  -r method             Use ``method'' (e.g. ``rsh'') instead "
            "of ssh.\n"
            "  -l login              Log in as user ``login''.\n"
            "  -T --timeout=timeout  Set idle timeout for fshd (in seconds); exit when "
            "no\n                        session has existed for this long. "
            "0 disables.\n")
    err = sys.stderr
    err.write(synopsis)
    if ret == 0:
        err.write(option_help)
    sys.exit(ret)

def main():
    """Entry point: parse the fsh command line and run the command.

    Syntax (rsh-compatible):
        fsh [-r method] [-l login] [-T timeout] [-V] [-h] host
            [-l login] command [args...]
    Options may precede the host name; only "-l login" may also follow
    it.  Bad option syntax exits with status 1, a missing host or
    command goes to usage(1), and otherwise control passes to docmd().
    """
    method = "ssh"
    login = None
    use_l_flag = 0
    print_version = 0
    fshd_timeout = None

    # Options that precede the host name.
    try:
        leading, rest = getopt.getopt(sys.argv[1:], "hr:l:VT:",
                                      ["version", "help", "timeout="])
    except getopt.error:
        sys.stderr.write(str(sys.exc_info()[1]) + "\n")
        sys.exit(1)

    # For historical (rsh-compatibility) reasons "-l login" is also
    # accepted right after the host name; no other option is allowed
    # there.  getopt stops at the first non-option argument, so the
    # remote command and its own options are left untouched.
    host = None
    trailing = []
    if len(rest) >= 2:
        host = rest[0]
        try:
            trailing, rest = getopt.getopt(rest[1:], "l:")
        except getopt.error:
            sys.stderr.write(str(sys.exc_info()[1]) + "\n")
            sys.exit(1)

    for opt, val in leading + trailing:
        if opt == "-r":
            method = val
        elif opt == "-l":
            login = val
            use_l_flag = 1
        elif opt in ("-V", "--version"):
            print_version = 1
        elif opt in ("-h", "--help"):
            usage(0)
        elif opt in ("-T", "--timeout"):
            fshd_timeout = val
        else:
            assert 0

    if print_version:
        fshlib.print_version("fsh")
    if not rest or host is None:
        usage(1)
    if login is None:
        login = getpass.getuser()
    docmd(method, login, use_l_flag, host, " ".join(rest), fshd_timeout)

if __name__ == "__main__":
    # Run the fsh client when this file is executed as a script.
    main()
