Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 2153183a authored by xunchang's avatar xunchang
Browse files

blockimgdiff: Factor out the diff_worker

We will call it at an earlier time to compute the patch size; and
choose the transfers to convert to 'new'.

Bug: 120561199
Test: Generate an incremental update on shiner
Change-Id: I29a0c8e75c9e5b66a266c1387186692a86fcbe43
parent 19e65028
Loading
Loading
Loading
Loading
+151 −100
Original line number Diff line number Diff line
@@ -26,7 +26,8 @@ import os.path
import re
import sys
import threading
from collections import deque, OrderedDict
import zlib
from collections import deque, namedtuple, OrderedDict
from hashlib import sha1

import common
@@ -36,8 +37,12 @@ __all__ = ["EmptyImage", "DataImage", "BlockImageDiff"]

logger = logging.getLogger(__name__)

# The tuple contains the style and bytes of a bsdiff|imgdiff patch.
PatchInfo = namedtuple("PatchInfo", ["imgdiff", "content"])


def compute_patch(srcfile, tgtfile, imgdiff=False):
  """Calls bsdiff|imgdiff to compute the patch data, returns a PatchInfo."""
  patchfile = common.MakeTempFile(prefix='patch-')

  cmd = ['imgdiff', '-z'] if imgdiff else ['bsdiff']
@@ -52,7 +57,7 @@ def compute_patch(srcfile, tgtfile, imgdiff=False):
    raise ValueError(output)

  with open(patchfile, 'rb') as f:
    return f.read()
    return PatchInfo(imgdiff, f.read())


class Image(object):
@@ -203,17 +208,17 @@ class Transfer(object):
    self.id = len(by_id)
    by_id.append(self)

    self._patch = None
    self._patch_info = None

  @property
  def patch(self):
    return self._patch
  def patch_info(self):
    return self._patch_info

  @patch.setter
  def patch(self, patch):
    if patch:
  @patch_info.setter
  def patch_info(self, info):
    if info:
      assert self.style == "diff"
    self._patch = patch
    self._patch_info = info

  def NetStashChange(self):
    return (sum(sr.size() for (_, sr) in self.stash_before) -
@@ -224,7 +229,7 @@ class Transfer(object):
    self.use_stash = []
    self.style = "new"
    self.src_ranges = RangeSet()
    self.patch = None
    self.patch_info = None

  def __str__(self):
    return (str(self.id) + ": <" + str(self.src_ranges) + " " + self.style +
@@ -462,16 +467,7 @@ class BlockImageDiff(object):
    self.AbbreviateSourceNames()
    self.FindTransfers()

    # Find the ordering dependencies among transfers (this is O(n^2)
    # in the number of transfers).
    self.GenerateDigraph()
    # Find a sequence of transfers that satisfies as many ordering
    # dependencies as possible (heuristically).
    self.FindVertexSequence()
    # Fix up the ordering dependencies that the sequence didn't
    # satisfy.
    self.ReverseBackwardEdges()
    self.ImproveVertexSequence()
    self.FindSequenceForTransfers()

    # Ensure the runtime stash size is under the limit.
    if common.OPTIONS.cache_size is not None:
@@ -829,7 +825,7 @@ class BlockImageDiff(object):
            # These are identical; we don't need to generate a patch,
            # just issue copy commands on the device.
            xf.style = "move"
            xf.patch = None
            xf.patch_info = None
            tgt_size = xf.tgt_ranges.size() * self.tgt.blocksize
            if xf.src_ranges != xf.tgt_ranges:
              logger.info(
@@ -839,11 +835,10 @@ class BlockImageDiff(object):
                      xf.tgt_name + " (from " + xf.src_name + ")"),
                  str(xf.tgt_ranges), str(xf.src_ranges))
          else:
            if xf.patch:
              # We have already generated the patch with imgdiff, while
              # splitting large APKs (i.e. in FindTransfers()).
              assert not self.disable_imgdiff
              imgdiff = True
            if xf.patch_info:
              # We have already generated the patch (e.g. during split of large
              # APKs or reduction of stash size)
              imgdiff = xf.patch_info.imgdiff
            else:
              imgdiff = self.CanUseImgdiff(
                  xf.tgt_name, xf.tgt_ranges, xf.src_ranges)
@@ -854,85 +849,16 @@ class BlockImageDiff(object):
        else:
          assert False, "unknown style " + xf.style

    if diff_queue:
      if self.threads > 1:
        logger.info("Computing patches (using %d threads)...", self.threads)
      else:
        logger.info("Computing patches...")

      diff_total = len(diff_queue)
      patches = [None] * diff_total
      error_messages = []

      # Using multiprocessing doesn't give additional benefits, due to the
      # pattern of the code. The diffing work is done by subprocess.call, which
      # already runs in a separate process (not affected much by the GIL -
      # Global Interpreter Lock). Using multiprocess also requires either a)
      # writing the diff input files in the main process before forking, or b)
      # reopening the image file (SparseImage) in the worker processes. Doing
      # neither of them further improves the performance.
      lock = threading.Lock()
      def diff_worker():
        while True:
          with lock:
            if not diff_queue:
              return
            xf_index, imgdiff, patch_index = diff_queue.pop()
            xf = self.transfers[xf_index]

          patch = xf.patch
          if not patch:
            src_ranges = xf.src_ranges
            tgt_ranges = xf.tgt_ranges

            src_file = common.MakeTempFile(prefix="src-")
            with open(src_file, "wb") as fd:
              self.src.WriteRangeDataToFd(src_ranges, fd)

            tgt_file = common.MakeTempFile(prefix="tgt-")
            with open(tgt_file, "wb") as fd:
              self.tgt.WriteRangeDataToFd(tgt_ranges, fd)

            message = []
            try:
              patch = compute_patch(src_file, tgt_file, imgdiff)
            except ValueError as e:
              message.append(
                  "Failed to generate %s for %s: tgt=%s, src=%s:\n%s" % (
                      "imgdiff" if imgdiff else "bsdiff",
                      xf.tgt_name if xf.tgt_name == xf.src_name else
                      xf.tgt_name + " (from " + xf.src_name + ")",
                      xf.tgt_ranges, xf.src_ranges, e.message))
            if message:
              with lock:
                error_messages.extend(message)

          with lock:
            patches[patch_index] = (xf_index, patch)

      threads = [threading.Thread(target=diff_worker)
                 for _ in range(self.threads)]
      for th in threads:
        th.start()
      while threads:
        threads.pop().join()

      if error_messages:
        logger.error('ERROR:')
        logger.error('\n'.join(error_messages))
        logger.error('\n\n\n')
        sys.exit(1)
    else:
      patches = []
    patches = self.ComputePatchesForInputList(diff_queue, False)

    offset = 0
    with open(prefix + ".patch.dat", "wb") as patch_fd:
      for index, patch in patches:
      for index, patch_info, _ in patches:
        xf = self.transfers[index]
        xf.patch_len = len(patch)
        xf.patch_len = len(patch_info.content)
        xf.patch_start = offset
        offset += xf.patch_len
        patch_fd.write(patch)
        patch_fd.write(patch_info.content)

        tgt_size = xf.tgt_ranges.size() * self.tgt.blocksize
        logger.info(
@@ -999,6 +925,32 @@ class BlockImageDiff(object):
      for i in range(s, e):
        assert touched[i] == 1

  def FindSequenceForTransfers(self):
    """Finds a sequence for the given transfers.

     The goal is to minimize the violation of order dependencies between these
     transfers, so that fewer blocks are stashed when applying the update.
    """

    # Clear the existing dependency between transfers
    for xf in self.transfers:
      xf.goes_before = OrderedDict()
      xf.goes_after = OrderedDict()

      xf.stash_before = []
      xf.use_stash = []

    # Find the ordering dependencies among transfers (this is O(n^2)
    # in the number of transfers).
    self.GenerateDigraph()
    # Find a sequence of transfers that satisfies as many ordering
    # dependencies as possible (heuristically).
    self.FindVertexSequence()
    # Fix up the ordering dependencies that the sequence didn't
    # satisfy.
    self.ReverseBackwardEdges()
    self.ImproveVertexSequence()

  def ImproveVertexSequence(self):
    logger.info("Improving vertex order...")

@@ -1248,6 +1200,105 @@ class BlockImageDiff(object):
          b.goes_before[a] = size
          a.goes_after[b] = size

  def ComputePatchesForInputList(self, diff_queue, compress_target):
    """Returns a list of patch information for the input list of transfers.

      Args:
        diff_queue: a list of transfers with style 'diff'
        compress_target: If True, compresses the target ranges of each
            transfers; and save the size.

      Returns:
        A list of (transfer order, patch_info, compressed_size) tuples.
    """

    if not diff_queue:
      return []

    if self.threads > 1:
      logger.info("Computing patches (using %d threads)...", self.threads)
    else:
      logger.info("Computing patches...")

    diff_total = len(diff_queue)
    patches = [None] * diff_total
    error_messages = []

    # Using multiprocessing doesn't give additional benefits, due to the
    # pattern of the code. The diffing work is done by subprocess.call, which
    # already runs in a separate process (not affected much by the GIL -
    # Global Interpreter Lock). Using multiprocess also requires either a)
    # writing the diff input files in the main process before forking, or b)
    # reopening the image file (SparseImage) in the worker processes. Doing
    # neither of them further improves the performance.
    lock = threading.Lock()

    def diff_worker():
      while True:
        with lock:
          if not diff_queue:
            return
          xf_index, imgdiff, patch_index = diff_queue.pop()
          xf = self.transfers[xf_index]

        message = []
        compressed_size = None

        patch_info = xf.patch_info
        if not patch_info:
          src_file = common.MakeTempFile(prefix="src-")
          with open(src_file, "wb") as fd:
            self.src.WriteRangeDataToFd(xf.src_ranges, fd)

          tgt_file = common.MakeTempFile(prefix="tgt-")
          with open(tgt_file, "wb") as fd:
            self.tgt.WriteRangeDataToFd(xf.tgt_ranges, fd)

          try:
            patch_info = compute_patch(src_file, tgt_file, imgdiff)
          except ValueError as e:
            message.append(
                "Failed to generate %s for %s: tgt=%s, src=%s:\n%s" % (
                    "imgdiff" if imgdiff else "bsdiff",
                    xf.tgt_name if xf.tgt_name == xf.src_name else
                    xf.tgt_name + " (from " + xf.src_name + ")",
                    xf.tgt_ranges, xf.src_ranges, e.message))

        if compress_target:
          tgt_data = self.tgt.ReadRangeSet(xf.tgt_ranges)
          try:
            # Compresses with the default level
            compress_obj = zlib.compressobj(6, zlib.DEFLATED, -zlib.MAX_WBITS)
            compressed_data = (compress_obj.compress("".join(tgt_data))
                               + compress_obj.flush())
            compressed_size = len(compressed_data)
          except zlib.error as e:
            message.append(
                "Failed to compress the data in target range {} for {}:\n"
                "{}".format(xf.tgt_ranges, xf.tgt_name, e.message))

        if message:
          with lock:
            error_messages.extend(message)

        with lock:
          patches[patch_index] = (xf_index, patch_info, compressed_size)

    threads = [threading.Thread(target=diff_worker)
               for _ in range(self.threads)]
    for th in threads:
      th.start()
    while threads:
      threads.pop().join()

    if error_messages:
      logger.error('ERROR:')
      logger.error('\n'.join(error_messages))
      logger.error('\n\n\n')
      sys.exit(1)

    return patches

  def FindTransfers(self):
    """Parse the file_map to generate all the transfers."""

@@ -1585,7 +1636,7 @@ class BlockImageDiff(object):
                                self.tgt.RangeSha1(tgt_ranges),
                                self.src.RangeSha1(src_ranges),
                                "diff", self.transfers)
      transfer_split.patch = patch
      transfer_split.patch_info = PatchInfo(True, patch)

  def AbbreviateSourceNames(self):
    for k in self.src.file_map.keys():