#!/usr/bin/python3 -BbbEIsSttW all
# This software is provided by the copyright owner "as is"
# and WITHOUT ANY EXPRESSED OR IMPLIED WARRANTIES, including,
# but not limited to, the implied warranties of merchantability
# and fitness for a particular purpose are disclaimed. In no
# event shall the copyright owner be liable for any direct,
# indirect, incidential, special, exemplary or consequential
# damages, including, but not limited to, procurement of substitute
# goods or services, loss of use, data or profits or business
# interruption, however caused and on any theory of liability,
# whether in contract, strict liability, or tort, including
# negligence or otherwise, arising in any way out of the use
# of this software, even if advised of the possibility of such
# damage.
#
# Copyright (c) 2021 Unparalleled IT Services e.U.
#
# Permission to use, copy, modify, and distribute this software
# according to GNU Lesser General Public License (LGPL-3.0)
# purpose is hereby granted, provided that the above copyright
# notice and this permission notice appear in all copies.

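# This tool listens on 127.0.0.1:6001 (X display :1) for incoming X11
# connections and forwards each of them to the X server at 127.0.0.1:6000
# (display :0), logging selected requests on stderr. Only little-endian
# clients that send no authorization data are supported. As some modern
# extensions, e.g. MIT-SHM, make interception harder, only whitelisted
# extensions are reported to clients as available; see the "QueryExtension"
# handling below.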


import errno
import os
import select
import socket
import sys

class ConnectionContext():
  """One proxied X11 connection: a client socket and its server socket.

  Bytes read from either socket are collected in that side's read buffer,
  cut into complete protocol units and appended to the opposite side's
  write buffer. Client requests pass through handleClientRead, which logs
  colour-name lookups and hides non-whitelisted extensions; server output
  passes through handleServerRead unmodified, but only in whole units.

  A read buffer set to None marks that its peer reached EOF. The EOF is
  passed on to the other peer once the data pending for it is written, and
  when both peers are done both sockets are closed and isAlive() returns
  False.
  """

  # The only accepted connection setup request: little-endian byte order
  # ('l'), protocol version 11.0, no authorization name or data.
  CLIENT_SETUP_REQUEST = b'l\x00\x0b\x00\x00\x00\x00\x00\x00\x00\x00\x00'
  # Prefix of the only accepted setup reply: status 1 (success), 11.0.
  SERVER_SETUP_SUCCESS = b'\x01\x00\x0b\x00\x00\x00'
  # Extensions whose QueryExtension requests are forwarded unchanged.
  EXTENSION_WHITELIST = (b'BIG-REQUESTS', b'XKEYBOARD')
  # Sent instead of queries for any other extension: QueryExtension
  # (opcode 98, 3 words) for the name 'XXXX', presumably unknown to the
  # server, so its reply reports the extension as absent.
  DISABLED_EXTENSION_QUERY = b'\x62\x00\x03\x00\x04\x00\x00\x00\x58\x58\x58\x58'
  # Largest accepted BIG-REQUESTS extended request length, in bytes.
  MAX_REQUEST_LENGTH = 1 << 24

  def __init__(self, clientSocket, serverSocket):
    self.clientSocket = clientSocket
    self.serverSocket = serverSocket
    self.clientReadBuffer = b''
    self.clientWriteBuffer = b''
    self.serverReadBuffer = b''
    self.serverWriteBuffer = b''
    # 0: connection setup not yet forwarded, 1: setup done.
    self.clientState = 0
    self.serverState = 0
    # Set once both sockets have been closed, see isAlive().
    self.closed = False

  def checkHandleRead(self, readFd, readFds, writeFds):
    """Handle a readable descriptor if it belongs to this connection.

    Args:
      readFd: descriptor reported readable by select().
      readFds: descriptors polled for reading; readFd is removed when its
        peer reaches EOF or resets the connection.
      writeFds: descriptors polled for writing; the opposite socket is
        added when it newly has data pending.

    Returns:
      True if readFd belongs to this connection, False otherwise.
    """
    if readFd == self.clientSocket.fileno():
      inputData = self._readAvailable(readFd)
      if inputData is None:
        readFds.remove(readFd)
        if self.clientReadBuffer:
          print('Incomplete data in read buffer.', file=sys.stderr)
        self.clientReadBuffer = None
        # Tell the server the client is gone (so it can release the client's
        # resources) as soon as all queued requests have been sent.
        if not self.serverWriteBuffer:
          self._shutdownWrite(self.serverSocket)
        self._closeIfFinished()
        return True
      hadPendingData = bool(self.serverWriteBuffer)
      self.clientReadBuffer += inputData
      self.handleClientRead()
      if self.serverWriteBuffer and not hadPendingData:
        writeFds.append(self.serverSocket.fileno())
      return True

    if readFd == self.serverSocket.fileno():
      inputData = self._readAvailable(readFd)
      if inputData is None:
        # Without this branch an EOF would leave the descriptor permanently
        # readable and make select() spin.
        readFds.remove(readFd)
        if self.serverReadBuffer:
          print('Incomplete data in read buffer.', file=sys.stderr)
        self.serverReadBuffer = None
        if not self.clientWriteBuffer:
          self._shutdownWrite(self.clientSocket)
        self._closeIfFinished()
        return True
      hadPendingData = bool(self.clientWriteBuffer)
      self.serverReadBuffer += inputData
      self.handleServerRead()
      if self.clientWriteBuffer and not hadPendingData:
        writeFds.append(self.clientSocket.fileno())
      return True

    return False

  @staticmethod
  def _readAvailable(readFd):
    """Return up to 1 MiB read from readFd, or None on EOF/connection reset."""
    try:
      inputData = os.read(readFd, 1<<20)
    except ConnectionResetError:
      return None
    return inputData if inputData else None

  def handleClientRead(self):
    """Move complete client data from clientReadBuffer to serverWriteBuffer.

    First the 12-byte connection setup request is forwarded once fully
    received. After that each request is framed by its 16-bit length field,
    or by the 32-bit BIG-REQUESTS length when the former is zero, and passed
    through _filterRequest. Incomplete data stays buffered.

    Raises:
      Exception: on an unsupported setup request or an invalid length.
    """
    if self.clientState == 0:
      setupLength = len(self.CLIENT_SETUP_REQUEST)
      prefix = self.clientReadBuffer[:setupLength]
      # Fail early on a mismatch, but wait for a setup split over reads.
      if not self.CLIENT_SETUP_REQUEST.startswith(prefix):
        raise Exception('Unsupported connection setup %s' % repr(prefix))
      if len(prefix) < setupLength:
        return
      self.clientState = 1
      self.serverWriteBuffer += prefix
      self.clientReadBuffer = self.clientReadBuffer[setupLength:]
      # Fall through to process any requests sent along with the setup.

    while len(self.clientReadBuffer) >= 4:
      length = int.from_bytes(self.clientReadBuffer[2:4], 'little') * 4
      bigRequest = (length == 0)
      if bigRequest:
        if len(self.clientReadBuffer) < 8:
          break
        length = int.from_bytes(self.clientReadBuffer[4:8], 'little') * 4
        if (length < 8) or (length > self.MAX_REQUEST_LENGTH):
          raise Exception('Invalid extended request length %d' % length)
      if len(self.clientReadBuffer) < length:
        # Print those errors, they may indicate loss of frame.
        print(
            'Insufficient client data, got %d, want %d' % (
                len(self.clientReadBuffer), length), file=sys.stderr)
        break
      request = self.clientReadBuffer[:length]
      self.clientReadBuffer = self.clientReadBuffer[length:]
      self.serverWriteBuffer += self._filterRequest(request, bigRequest)

  def _filterRequest(self, request, bigRequest):
    """Log or rewrite one complete client request.

    Args:
      request: the whole request, header included.
      bigRequest: True if the request uses the 8-byte BIG-REQUESTS header.

    Returns:
      The bytes to forward to the server in place of request.

    Raises:
      Exception: on a malformed LookupColor or QueryExtension request.
    """
    opcode = request[0]
    # Request-specific fields follow the 4-byte header, or the 8-byte header
    # when a BIG-REQUESTS extended length is present.
    body = request[8:] if bigRequest else request[4:]
    if opcode == 85:
      # AllocNamedColor: cmap(4), name length(2), unused(2), name.
      nameLength = int.from_bytes(body[4:6], 'little')
      print(
          'AllocNamedColor %s' % repr(body[8:8+nameLength]),
          file=sys.stderr)
    elif opcode == 92:
      # LookupColor: same field layout as AllocNamedColor.
      if len(body) < 8:
        raise Exception(repr(request))
      nameLength = int.from_bytes(body[4:6], 'little')
      if 8 + nameLength > len(body):
        raise Exception('LookupColor name exceeds request %s' % repr(request))
      print(
          'LookupColor %s' % repr(body[8:8+nameLength]),
          file=sys.stderr)
    elif opcode == 98:
      # QueryExtension: name length(2), unused(2), name.
      if len(body) < 4:
        raise Exception(repr(request))
      nameLength = int.from_bytes(body[0:2], 'little')
      extName = body[4:4+nameLength]
      print('Extension query %s' % repr(extName), file=sys.stderr)
      if extName not in self.EXTENSION_WHITELIST:
        print('Disabling extension request.', file=sys.stderr)
        return self.DISABLED_EXTENSION_QUERY
    return request

  def checkHandleWrite(self, writeFd, readFds, writeFds):
    """Flush pending data to writeFd if it belongs to this connection.

    writeFd is removed from writeFds once its buffer is empty. If the peer is
    gone (EPIPE/ECONNRESET) the pending data is discarded. readFds is not
    used; it is accepted for symmetry with checkHandleRead.

    Returns:
      True if writeFd belongs to this connection, False otherwise.
    """
    if writeFd == self.clientSocket.fileno():
      self.clientWriteBuffer = self._writeSome(writeFd, self.clientWriteBuffer)
      if self.clientWriteBuffer:
        return True
      writeFds.remove(writeFd)
      # The server already closed: pass the EOF on now that its data is out.
      if self.serverReadBuffer is None:
        self._shutdownWrite(self.clientSocket)
    elif writeFd == self.serverSocket.fileno():
      self.serverWriteBuffer = self._writeSome(writeFd, self.serverWriteBuffer)
      if self.serverWriteBuffer:
        return True
      writeFds.remove(writeFd)
      # The client already closed: pass the EOF on now that its data is out.
      if self.clientReadBuffer is None:
        self._shutdownWrite(self.serverSocket)
    else:
      return False
    self._closeIfFinished()
    return True

  @staticmethod
  def _writeSome(writeFd, writeBuffer):
    """Write as much of writeBuffer as possible and return the unsent rest.

    If the peer is gone the whole buffer is discarded and b'' is returned.
    """
    try:
      written = os.write(writeFd, writeBuffer)
    except (BrokenPipeError, ConnectionResetError):
      if writeBuffer:
        print(
            'Write failed, discarding buffer %s' % (repr(writeBuffer)),
            file=sys.stderr)
      return b''
    return writeBuffer[written:]

  @staticmethod
  def _shutdownWrite(peerSocket):
    """Half-close peerSocket so that its peer reads EOF; best effort."""
    try:
      peerSocket.shutdown(socket.SHUT_WR)
    except OSError:
      # The peer may already be gone (e.g. ENOTCONN); the socket is still
      # closed later by _closeIfFinished.
      pass

  def _closeIfFinished(self):
    """Close both sockets once both peers hit EOF and nothing is pending."""
    if self.closed:
      return
    if self.clientReadBuffer is not None or self.serverReadBuffer is not None:
      return
    if self.clientWriteBuffer or self.serverWriteBuffer:
      return
    # Both descriptors are already out of readFds (removed at EOF) and out of
    # writeFds (removed when their write buffers drained).
    self.clientSocket.close()
    self.serverSocket.close()
    self.closed = True

  def handleServerRead(self):
    """Move complete server data from serverReadBuffer to clientWriteBuffer.

    First the connection setup reply is forwarded once complete. Afterwards
    the stream is cut into replies and GenericEvents (32 bytes plus their
    32-bit length field in 4-byte units) and errors/other events (32 bytes).
    Incomplete units stay buffered.

    Raises:
      Exception: if the setup reply is not a successful 11.0 reply.
    """
    if self.serverState == 0:
      if len(self.serverReadBuffer) < 8:
        return
      if self.serverReadBuffer[0:6] != self.SERVER_SETUP_SUCCESS:
        raise Exception(
            'Unsupported setup reply %s' % repr(self.serverReadBuffer[0:8]))
      length = int.from_bytes(self.serverReadBuffer[6:8], 'little') * 4 + 8
      if length > len(self.serverReadBuffer):
        return
      self.serverState = 1
      self.clientWriteBuffer += self.serverReadBuffer[0:length]
      self.serverReadBuffer = self.serverReadBuffer[length:]
      # Fall through to state 1.
    while len(self.serverReadBuffer) >= 32:
      code = self.serverReadBuffer[0]
      if code in (1, 35):
        # Reply (1) or GenericEvent (35): variable length.
        length = int.from_bytes(self.serverReadBuffer[4:8], 'little') * 4 + 32
      else:
        # Errors (0) and all other events, including SendEvent-generated
        # ones (code with bit 0x80 set), are exactly 32 bytes.
        length = 32
      if len(self.serverReadBuffer) < length:
        break
      self.clientWriteBuffer += self.serverReadBuffer[0:length]
      self.serverReadBuffer = self.serverReadBuffer[length:]

  def isAlive(self):
    """Return False once both sockets of this connection have been closed."""
    return not self.closed


class ApplicationContext():
  """Accepts clients on a TCP listen socket and relays each of them to a
  fixed target address through its own ConnectionContext, multiplexing all
  sockets with select()."""

  def __init__(self):
    self.listenSocket = None
    self.targetAddress = None
    self.connectionList = []

  def createTcpListenSocket(self, port, host='127.0.0.1'):
    """Create the listening socket on host:port (loopback by default).

    Raises:
      Exception: if a listen socket was already created.
      OSError: if the socket cannot be bound or put into listening state.
    """
    if self.listenSocket is not None:
      raise Exception('Listen socket already created')
    listenSocket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
      listenSocket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
      listenSocket.bind((host, port))
      listenSocket.listen(16)
    except OSError:
      # Do not leak the half-configured socket or keep it registered.
      listenSocket.close()
      raise
    self.listenSocket = listenSocket

  def setTargetAddress(self, targetHost, targetPort):
    """Set the (host, port) every accepted client is forwarded to."""
    self.targetAddress = (targetHost, targetPort)

  def _acceptConnection(self, readFds):
    """Accept one client, connect it to the target and register both fds.

    If the target cannot be reached, only this client is dropped instead of
    terminating the whole proxy.
    """
    clientSocket, _ = self.listenSocket.accept()
    serverSocket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
      serverSocket.connect(self.targetAddress)
    except OSError as connectError:
      print(
          'Cannot connect to %s: %s' % (repr(self.targetAddress), connectError),
          file=sys.stderr)
      serverSocket.close()
      clientSocket.close()
      return
    clientSocket.setblocking(False)
    serverSocket.setblocking(False)
    readFds.append(clientSocket.fileno())
    readFds.append(serverSocket.fileno())
    self.connectionList.append(ConnectionContext(clientSocket, serverSocket))

  def _dispatch(self, fd, isWrite, readFds, writeFds):
    """Pass fd to the connection owning it; drop that connection if it died.

    Raises:
      Exception: if no connection owns fd.
    """
    for connection in self.connectionList:
      if isWrite:
        handled = connection.checkHandleWrite(fd, readFds, writeFds)
      else:
        handled = connection.checkHandleRead(fd, readFds, writeFds)
      if handled:
        if not connection.isAlive():
          # Safe while iterating: the loop is left right after the removal.
          self.connectionList.remove(connection)
        return
    raise Exception('No connection owns file descriptor %d' % fd)

  def run(self):
    """Serve forever, relaying data for all accepted connections.

    Raises:
      Exception: if createTcpListenSocket/setTargetAddress were not called,
        if select() reports a descriptor no connection owns, or if it reports
        an exceptional condition.
    """
    if self.listenSocket is None or self.targetAddress is None:
      raise Exception(
          'createTcpListenSocket and setTargetAddress must be called first')
    listenFd = self.listenSocket.fileno()
    readFds = [listenFd]
    writeFds = []
    exceptFds = []

    while True:
      readSelectFds, writeSelectFds, excpSelectFds = select.select(
          readFds, writeFds, exceptFds)

      for readFd in readSelectFds:
        if readFd == listenFd:
          self._acceptConnection(readFds)
        else:
          self._dispatch(readFd, False, readFds, writeFds)

      for writeFd in writeSelectFds:
        self._dispatch(writeFd, True, readFds, writeFds)

      if excpSelectFds:
        raise Exception(
            'Unexpected exceptional condition on %s' % repr(excpSelectFds))


def main():
  """Run the proxy: accept X11 clients on local port 6001 and relay them to
  the X server listening on 127.0.0.1:6000."""
  listenPort = 6001
  targetHost, targetPort = '127.0.0.1', 6000

  proxy = ApplicationContext()
  proxy.createTcpListenSocket(listenPort)
  proxy.setTargetAddress(targetHost, targetPort)
  proxy.run()


if __name__ == '__main__':
  # Only start serving when executed as a script, not when imported.
  main()
