# Copyright 2019 VMware, Inc.
# All rights reserved. -- VMware Confidential

"""Master Boot Record (MBR) partitioning.

Definitions:

  Primary    A primary partition is a data partition that is defined in the main
             partition table (contained in the MBR). There can be at most 4
             primary partitions (or 3 if an extended partition is also present).
             Primary partitions are numbered from 1 to 4.

  Extended   An extended partition is a "container" partition whose purpose is
             to wrap logical partitions. There can be zero or one extended
             partition. The extended partition can be assigned any number from
             1 to 4.

  Logical    A logical partition is a data partition which is embedded within
             the extended partition. There may be as many logical partitions as
             the disk size allows. Logical partition numbers start at 5 and grow
             incrementally with the number of logical partitions.

  MBR        The Master Boot Record (MBR) is the first sector of the disk, which
             contains the first-stage bootloader code, as well as the disk
             partition table. Note that MBR was deprecated in favor of the GUID
             Partition Table (GPT). However MBR is still used - sometimes in
             addition to GPT - to ensure backward compatibility with legacy
             operating systems and firmware.

  EBR        Extended Boot Records (EBR) provide an extension to the MBR
             partition table which can only describe 4 primary partitions. Each
             EBR is contained within the extended partition, and is located
             before the logical partition that it describes. EBRs follow the
             same format as the MBR, but have at most 2 entries in the partition
             table: one to describe the logical partition, and another optional
             entry which serves as a pointer to the next EBR.
"""
from struct import pack, unpack, unpack_from

from systemStorage import *
from systemStorage.blockdev import CHS, Partition

# Size in bytes of one entry in an MBR/EBR partition table.
SIZEOF_PARTITION_TABLE_ENTRY = 16
# Offset in bytes of the first partition table entry within the boot sector;
# updateBootSector() treats everything before this offset as boot code.
SIZEOF_MBR_BOOT_CODE = 446
# Two-byte signature that closes a valid MBR/EBR sector (checked by
# Mbr.unpack(), written by Mbr.pack() and Ebr.pack()).
MBR_SIGNATURE = bytes((0x55, 0xaa))

# Maps the partition type byte of a partition table entry to a filesystem type.
# Several type bytes map to the same filesystem type. MbrPartition.pack() walks
# this dict in insertion order and uses the FIRST type byte whose value
# matches, so the order of the entries below is significant.
MBR_FS_TYPES = {0x04: FS_TYPE_FAT16,
                0x05: FS_TYPE_EXTENDED,
                0x06: FS_TYPE_VFAT,
                0x0b: FS_TYPE_FAT32,
                0x0c: FS_TYPE_FAT32,
                0x0e: FS_TYPE_FAT32,
                0xde: FS_TYPE_DELL_UTILITY,
                0xee: FS_TYPE_EFI_GPT,
                0xef: FS_TYPE_UEFI_SYSTEM,
                0xfb: FS_TYPE_VMFS,
                0xf8: FS_TYPE_VMFS_L,
                0xfc: FS_TYPE_VMKCORE}

def readBootSector(disk):
   """Read and return the first sector (boot code) from the disk.

   @disk: device object providing open(), readBlock() and close()

   Returns whatever disk.readBlock(0, 1) returns, i.e. the block at LBA 0.
   Any exception raised while opening or reading the disk is propagated.
   """
   # Open before entering the try block (same pattern as updateBootSector) so
   # that close() is only ever called on a disk that was successfully opened.
   disk.open(os.O_RDONLY)
   try:
      return disk.readBlock(0, 1)
   finally:
      disk.close()

def updateBootSector(disk, bootSector):
   """Updates/Restore first sector (boot code) on the disk.

   MBR spec specifies that the first partition entry starts at offset 446
   (SIZEOF_MBR_BOOT_CODE), so restore original bytes (boot-code) that exists
   before that + updated partition table bytes.

   @disk: device object providing open(), readBlock(), writeBlock(), close()
   @bootSector: original boot code bytes; must hold at least
                SIZEOF_MBR_BOOT_CODE bytes, anything past that offset is
                ignored

   The sector currently on disk is read back first, so that everything from
   offset SIZEOF_MBR_BOOT_CODE onwards (the partition table and whatever
   follows it in the block) is written back unchanged.
   """
   # Exactly SIZEOF_MBR_BOOT_CODE bytes is enough to fill the boot code area,
   # hence '>=' (a strict '>' would reject a buffer holding just the boot code).
   assert len(bootSector) >= SIZEOF_MBR_BOOT_CODE, \
          ("invalid MBR boot code size: %s < %s" %
           (len(bootSector), SIZEOF_MBR_BOOT_CODE))
   disk.open(os.O_RDWR)
   try:
      onDisk = disk.readBlock(0, 1)
      merged = bootSector[:SIZEOF_MBR_BOOT_CODE] + onDisk[SIZEOF_MBR_BOOT_CODE:]
      disk.writeBlock(0, merged)
   finally:
      disk.close()

class MbrPartition(object):
   """One 16-byte entry of an MBR (or EBR) partition table.

   Attributes:
      fsType    filesystem type (one of the values of MBR_FS_TYPES, or
                FS_TYPE_UNKNOWN for an unrecognized partition type byte)
      start     LBA of the first sector of the partition
      end       LBA of the last sector of the partition (inclusive)
      bootable  True if the entry carries the bootable flag
   """
   # On-disk layout, little-endian:
   #   B   bootable flag
   #   3s  start CHS
   #   B   partition type
   #   3s  end CHS
   #   I   start LBA
   #   I   number of sectors in partition
   _FMT = '<B3sB3sII'

   def __init__(self, fsType, start, end, bootable):
      self.fsType = fsType
      self.start = start
      self.end = end
      self.bootable = bootable

   @classmethod
   def fromGenericPart(cls, part):
      """Instantiate a new MbrPartition object from a generic Partition object.

      Only the fsType, start, end and bootable attributes of @part are used.
      """
      return cls(part.fsType, part.start, part.end, part.bootable)

   @staticmethod
   def _unpackChs(buf):
      """Unpack a CHS value from an MBR partition table entry.

      @buf: the 3 raw CHS bytes. Byte 0 is the head, the low 6 bits of byte 1
            are the sector, and the top 2 bits of byte 1 combined with byte 2
            form the 10-bit cylinder.
      """
      sectorByte = buf[1]
      cylinderLow = buf[2]
      cylinderHigh = (sectorByte << 2) & 0x300
      return CHS(cylinderHigh | cylinderLow, buf[0], sectorByte & 0x3f)

   @staticmethod
   def _packChs(chs):
      """Convert a CHS address into the MBR partition entry format.

      Raises OverflowError if the head or sector does not fit in the entry.
      """
      if chs.cylinder > 1023:
         # The sector address is too large to be addressed with a 3-byte CHS
         # value. Per MBR convention, use CHS(1023, 254, 63) to indicate to the
         # BIOS to use the LBA values instead.
         chs = CHS(1023, 254, 63)

      head, sector, cylinder = chs.head, chs.sector, chs.cylinder
      if head > 255:
         raise OverflowError("C/H/S out of range (H == %u > 255)" % head)
      if sector > 63:
         raise OverflowError("C/H/S out of range (S == %u > 63)" % sector)

      # The two high bits of the cylinder share a byte with the sector.
      return bytes((head,
                    ((cylinder & 0x300) >> 2) | sector,
                    cylinder & 0xff))

   @classmethod
   def unpack(cls, buf, blockDev):
      """Unpack an MBR partition table entry from memory.

      @buf: buffer starting with the 16-byte entry (extra bytes are ignored)
      @blockDev: not used by this method

      Returns None if the entry is empty (partition type 0) or if its flags
      byte is neither 0x00 nor 0x80. Raises ValueError when the entry is a GPT
      protective partition.
      """
      assert len(buf) >= SIZEOF_PARTITION_TABLE_ENTRY, \
         ("bad MBR partition table entry (size=%u < %u)" %
          (len(buf), SIZEOF_PARTITION_TABLE_ENTRY))

      flags, _, partType, _, startLba, sizeInLba = unpack_from(
         MbrPartition._FMT, buf)

      # Entry not present, or marked as invalid.
      if partType == 0 or flags not in (0x0, 0x80):
         return None

      fsType = MBR_FS_TYPES.get(partType, FS_TYPE_UNKNOWN)
      if fsType == FS_TYPE_EFI_GPT:
         # Protective MBR, the partition table is GPT.
         raise ValueError('Protective MBR is seen, the partition table '
                          'appears to be GPT')

      lastLba = startLba + sizeInLba - 1
      return cls(fsType, startLba, lastLba, flags == 0x80)

   def pack(self, blockDev):
      """Construct an MBR partition table entry.

      @blockDev: device object whose lbaToChs() converts an LBA to a CHS

      Raises NotImplementedError if the filesystem type has no MBR partition
      type byte in MBR_FS_TYPES.
      """
      flags = 0x80 if self.bootable else 0x0
      startChs = self._packChs(blockDev.lbaToChs(self.start))
      endChs = self._packChs(blockDev.lbaToChs(self.end))

      # First partition type byte (in table order) mapped to our filesystem.
      fsId = next((typeByte for typeByte, fsType in MBR_FS_TYPES.items()
                   if self.fsType == fsType), None)
      if fsId is None:
         raise NotImplementedError("%s: filesystem is not supported" %
                                   self.fsType)

      numSectors = self.end - self.start + 1
      return pack(MbrPartition._FMT, flags, startChs, fsId, endChs,
                  self.start, numSectors)


class Mbr(object):
   """Class that represents a Master Boot Record (MBR).

   Attributes:
      partitions  dict mapping a partition number (1-based slot index in the
                  partition table) to the MbrPartition found in that slot;
                  only filled in by unpack()
   """
   _SIZEOF_BOOT_CODE = 440

   _FMT = ('<'     # Little-endian
           '440s'  # Boot Code
           'I'     # Unique disk signature
           '2x'    # Reserved, set to zero
           '64s'   # Partition array
           '2s')   # MBR signature

   def __init__(self, bootCode=None, uid=None):
      """Create an MBR with no partitions.

      @bootCode: boot code bytes; defaults to _SIZEOF_BOOT_CODE zero bytes
      @uid: unique disk signature (32-bit integer); defaults to 0
      """
      if bootCode is None:
         self._bootCode = bytes(self._SIZEOF_BOOT_CODE)
      else:
         self._bootCode = bootCode
      self._uid = 0 if uid is None else uid
      self.partitions = {}

   @classmethod
   def unpack(cls, buf, blockDev, maxPartitionsNum=4):
      """Unpack an MBR from memory.

      @buf: buffer holding at least the 512 bytes of the MBR sector
      @blockDev: passed through to MbrPartition.unpack()
      @maxPartitionsNum: number of partition table slots to parse

      Returns a new object carrying the disk signature and the partitions
      found in @buf. The boot code of @buf is not retained: the returned object
      gets the default (all-zero) boot code.

      Raises ValueError if @buf does not end with the MBR signature, or if
      MbrPartition.unpack() rejects one of the entries.
      """
      SIZEOF_MBR = 512
      assert len(buf) >= SIZEOF_MBR, ("bad MBR block (size=%u < %u)" %
                                      (len(buf), SIZEOF_MBR))

      # The '2x' pad bytes of _FMT produce no value, so exactly four fields
      # come back: boot code, disk signature, partition array, MBR signature.
      _, uid, partArray, signature = unpack_from(cls._FMT, buf)
      if signature != MBR_SIGNATURE:
         raise ValueError("not an MBR")

      mbr = cls(None, uid)

      for i in range(maxPartitionsNum):
         offset = i * SIZEOF_PARTITION_TABLE_ENTRY
         part = MbrPartition.unpack(partArray[offset:], blockDev)
         if part is not None:
            # Partition numbers are 1-based.
            mbr.partitions[i + 1] = part

      return mbr

   def pack(self, blockDev, partitions):
      """Construct the MBR containing the primary partition table.

      @blockDev: passed through to MbrPartition.pack()
      @partitions: dict mapping partition numbers to generic partition
                   objects; only numbers 1 to 4 are written, missing numbers
                   produce an all-zero (empty) entry

      Returns the packed MBR sector as bytes.
      """
      entries = []
      for partNum in range(1, 5):
         if partNum in partitions:
            part = MbrPartition.fromGenericPart(partitions[partNum])
            entries.append(part.pack(blockDev))
         else:
            entries.append(bytes(SIZEOF_PARTITION_TABLE_ENTRY))

      return pack(self._FMT, self._bootCode, self._uid, b"".join(entries),
                  MBR_SIGNATURE)

class Ebr(Mbr):
   """Class that represents an Extended Boot Record (EBR).
   """

   @classmethod
   def unpack(cls, buf, blockDev, extOffset, ebrOffset):
      """Unpack an EBR from memory.

      @buf: buffer holding the EBR sector
      @blockDev: passed through to Mbr.unpack()
      @extOffset: start LBA of the main extended partition
      @ebrOffset: LBA of this EBR

      Returns the object built by Mbr.unpack() (an Mbr instance), with the
      start/end of its partitions converted from EBR-relative values to
      absolute LBAs.
      """
      EBR_NUM_PARTITION_ENTRIES = 2  # Per MBR spec
      mbr = Mbr.unpack(buf, blockDev,
                       maxPartitionsNum=EBR_NUM_PARTITION_ENTRIES)

      for part in mbr.partitions.values():
         if part.fsType != FS_TYPE_EXTENDED:
            # Logical partition start is relative to this EBR start LBA
            part.start += ebrOffset
            part.end += ebrOffset
         else:
            # Next EBR LBA is relative to main extended partition
            part.start += extOffset
            part.end += extOffset

      return mbr

   def pack(self, blockDev, logicalPart, nextLogicalPart, extOffset, ebrOffset):
      """Construct a logical partition EBR.

      Each EBR is placed in the sector immediately preceding the partition which
      it describes (EBR_LBA = LOGICAL_PART_START_LBA - 1). This allows users to
      specify a partition location and size, and be assured that the partition
      will actually start at the specified offset, and be of the specified size.
      This is particularly important to enforce proper partition alignment, or
      in case existing partitions must be preserved during re-partitioning.
      However this also assumes that the user can describe the partition table
      accounting for at least one free sector before each logical partition.

      @blockDev: passed through to MbrPartition.pack()
      @logicalPart: the logical partition described by this EBR
      @nextLogicalPart: the following logical partition, or None if
                        @logicalPart is the last one
      @extOffset: start LBA of the main extended partition
      @ebrOffset: LBA of this EBR

      Returns the packed EBR sector as bytes.
      """
      EBR_TABLE_NUM_ENTRIES = 4  # The table area is as large as in an MBR

      # First entry: the logical partition, relative to this EBR. The EBR sits
      # in the sector right before the partition, hence a relative start of 1.
      part = MbrPartition.fromGenericPart(logicalPart)
      part.start = 1
      part.end -= ebrOffset
      entries = [part.pack(blockDev)]

      if nextLogicalPart is not None:
         # Second entry: link to the next EBR, relative to the main extended
         # partition. It covers the next EBR sector (one sector before the
         # next logical partition) plus that logical partition.
         nextPart = MbrPartition.fromGenericPart(nextLogicalPart)
         nextPart.fsType = FS_TYPE_EXTENDED
         nextPart.start -= extOffset + 1
         nextPart.end -= extOffset
         entries.append(nextPart.pack(blockDev))

      # Fill the remaining slots with empty entries so that the table is
      # exactly as large as the partition array field of _FMT, instead of
      # relying on struct to pad or truncate a wrongly-sized table.
      numEmptyEntries = EBR_TABLE_NUM_ENTRIES - len(entries)
      partitionTable = (b"".join(entries) +
                        bytes(numEmptyEntries * SIZEOF_PARTITION_TABLE_ENTRY))

      return pack(self._FMT, self._bootCode, self._uid, partitionTable,
                  MBR_SIGNATURE)


class MbrPartitionTable(object):
   """Class that represents a disk MBR partition table.

   Attributes:
      partitions  dict mapping a partition number to its Partition object.
                  Numbers 1 to 4 are entries of the main partition table,
                  numbers above 4 are logical partitions.
   """

   def __init__(self):
      self.partitions = {}
      # Number of the last partition that was set as bootable, if any.
      self._bootPart = None

   def setPartition(self, partNum, fsType, start, end, bootable=False):
      """Add/modify a partition entry in the MBR partition table.

      @partNum: partition number; an existing entry with that number is
                replaced
      @fsType: filesystem type of the partition
      @start: LBA of the first sector of the partition
      @end: LBA of the last sector of the partition
      @bootable: whether the partition is flagged as bootable
      """
      part = Partition(partNum, fsType, start, end, None, bootable=bootable)
      self.partitions[partNum] = part

      if bootable:
         self._bootPart = partNum

   def scan(self, blockDev):
      """Read a block device's MBR partition table.

      Primary partitions keep the number of their slot in the MBR; logical
      partitions are numbered from 5 in the order they are met while walking
      the chain of EBRs.

      @blockDev: device object providing readBlock()

      Raises ValueError if a sector of the chain is not a valid MBR/EBR, or if
      the chain of EBRs loops back onto an EBR that was already visited.
      """
      MBR_LOGICAL_BASE_ID = 5

      lba0 = blockDev.readBlock(0, 1)
      mbr = Mbr.unpack(lba0, blockDev)

      logicalId = MBR_LOGICAL_BASE_ID

      for primaryId, part in mbr.partitions.items():
         self.setPartition(primaryId, part.fsType, part.start, part.end,
                           bootable=part.bootable)

         if part.fsType != FS_TYPE_EXTENDED:
            continue

         # Walk the chain of EBRs; the first one sits at the very start of
         # the extended partition.
         visited = set()
         offset = part.start

         while offset is not None:
            if offset in visited:
               # A corrupted chain would otherwise be followed forever.
               raise ValueError("EBR chain loops back to LBA %u" % offset)
            visited.add(offset)

            lba = blockDev.readBlock(offset, 1)
            ebr = Ebr.unpack(lba, blockDev, part.start, offset)

            # The chain ends unless this EBR holds a link to the next one.
            # Starting from None also terminates the walk when the EBR holds
            # no entry at all (e.g. an empty extended partition).
            nextOffset = None
            for logicalPart in ebr.partitions.values():
               if logicalPart.fsType == FS_TYPE_EXTENDED:
                  nextOffset = logicalPart.start
                  break

               self.setPartition(logicalId, logicalPart.fsType,
                                 logicalPart.start, logicalPart.end,
                                 bootable=logicalPart.bootable)
               logicalId += 1

            offset = nextOffset

   def sync(self, blockDev, bootCode=None):
      """Write the partition table to disk.

      If logical partitions (numbers above 4) are present and no partition is
      of the extended type yet, an extended partition spanning all the logical
      partitions is added in the first free primary slot. One EBR is written
      in the sector preceding each logical partition.

      @blockDev: device object providing writeBlock() and lbaToChs()
      @bootCode: boot code to embed in the MBR, or None for the default
      """
      logicals = sorted(((n, p) for n, p in self.partitions.items() if n > 4),
                        key=lambda x: x[1].start)

      # Create primary partitions MBR, add extended partition if applicable
      if logicals:
         # The first EBR sits in the sector before the first logical partition
         extStart = logicals[0][1].start - 1
         extEnd = logicals[-1][1].end

         hasExtended = any(p.fsType == FS_TYPE_EXTENDED
                           for p in self.partitions.values())
         if not hasExtended:
            # NOTE(review): if slots 1 to 4 are all taken, no extended
            # partition is created and nothing is reported - confirm whether
            # this should be an error.
            for i in range(1, 5):
               if i not in self.partitions:
                  self.setPartition(i, FS_TYPE_EXTENDED, extStart, extEnd)
                  break
         # NOTE(review): a pre-existing extended partition is left untouched,
         # while the EBRs below are still built relative to extStart - confirm
         # that callers keep both consistent.

      mbr = Mbr(bootCode, 0)
      lba0 = mbr.pack(blockDev, self.partitions)

      # Create logical partitions EBRs, if applicable
      ebrBlocks = []
      for i, (_, part) in enumerate(logicals):
         nextPart = logicals[i + 1][1] if i + 1 < len(logicals) else None
         ebrOffset = part.start - 1
         block = Ebr().pack(blockDev, part, nextPart, extStart, ebrOffset)
         ebrBlocks.append((ebrOffset, block))

      # Write partition table to disk
      blockDev.writeBlock(0, lba0)
      for lba, ebr in ebrBlocks:
         blockDev.writeBlock(lba, ebr)
