js/src/devtools/jint/treesearch.py

changeset 0
6474c204b198
     1.1 --- /dev/null	Thu Jan 01 00:00:00 1970 +0000
     1.2 +++ b/js/src/devtools/jint/treesearch.py	Wed Dec 31 06:09:35 2014 +0100
     1.3 @@ -0,0 +1,411 @@
     1.4 +# vim: set ts=8 sts=4 et sw=4 tw=99:
     1.5 +# This Source Code Form is subject to the terms of the Mozilla Public
     1.6 +# License, v. 2.0. If a copy of the MPL was not distributed with this
     1.7 +# file, You can obtain one at http://mozilla.org/MPL/2.0/.
     1.8 +
     1.9 +import os, re
    1.10 +import tempfile
    1.11 +import subprocess
    1.12 +import sys, math
    1.13 +import datetime
    1.14 +import random
    1.15 +
    1.16 +def realpath(k):
    1.17 +    return os.path.realpath(os.path.normpath(k))
    1.18 +
    1.19 +class UCTNode:
    1.20 +    def __init__(self, loop):
    1.21 +        self.children = None
    1.22 +        self.loop = loop
    1.23 +        self.visits = 1
    1.24 +        self.score = 0
    1.25 +
    1.26 +    def addChild(self, child):
    1.27 +        if self.children == None:
    1.28 +            self.children = []
    1.29 +        self.children.append(child)
    1.30 +
    1.31 +    def computeUCB(self, coeff):
    1.32 +        return (self.score / self.visits) + math.sqrt(coeff / self.visits)
    1.33 +
    1.34 +class UCT:
    1.35 +    def __init__(self, benchmark, bestTime, enableLoops, loops, fd, playouts):
    1.36 +        self.bm = benchmark
    1.37 +        self.fd = fd
    1.38 +        self.numPlayouts = playouts
    1.39 +        self.maxNodes = self.numPlayouts * 20
    1.40 +        self.loops = loops
    1.41 +        self.enableLoops = enableLoops
    1.42 +        self.maturityThreshold = 20
    1.43 +        self.originalBest = bestTime
    1.44 +        self.bestTime = bestTime
    1.45 +        self.bias = 20
    1.46 +        self.combos = []
    1.47 +        self.zobrist = { }
    1.48 +        random.seed()
    1.49 +
    1.50 +    def expandNode(self, node, pending):
    1.51 +        for loop in pending:
    1.52 +            node.addChild(UCTNode(loop))
    1.53 +            self.numNodes += 1
    1.54 +            if self.numNodes >= self.maxNodes:
    1.55 +                return False
    1.56 +        return True
    1.57 +
    1.58 +    def findBestChild(self, node):
    1.59 +        coeff = self.bias * math.log(node.visits)
    1.60 +        bestChild = None
    1.61 +        bestUCB = -float('Infinity')
    1.62 +
    1.63 +        for child in node.children:
    1.64 +            ucb = child.computeUCB(coeff)
    1.65 +            if ucb >= bestUCB:
    1.66 +                bestUCB = ucb
    1.67 +                bestChild = child
    1.68 +
    1.69 +        return child
    1.70 +
    1.71 +    def playout(self, history):
    1.72 +        queue = []
    1.73 +        for i in range(0, len(self.loops)):
    1.74 +            queue.append(random.randint(0, 1))
    1.75 +        for node in history:
    1.76 +            queue[node.loop] = not self.enableLoops
    1.77 +        zash = 0
    1.78 +        for i in range(0, len(queue)):
    1.79 +            if queue[i]:
    1.80 +                zash |= (1 << i)
    1.81 +        if zash in self.zobrist:
    1.82 +            return self.zobrist[zash]
    1.83 +
    1.84 +        self.bm.generateBanList(self.loops, queue)
    1.85 +        result = self.bm.treeSearchRun(self.fd, ['-m', '-j'], 3)
    1.86 +        self.zobrist[zash] = result
    1.87 +        return result
    1.88 +
    1.89 +    def step(self, loopList):
    1.90 +        node = self.root
    1.91 +        pending = loopList[:]
    1.92 +        history = [node]
    1.93 +
    1.94 +        while True:
    1.95 +            # If this is a leaf node...
    1.96 +            if node.children == None:
    1.97 +                # And the leaf node is mature...
    1.98 +                if node.visits >= self.maturityThreshold:
    1.99 +                    # If the node can be expanded, keep spinning.
   1.100 +                    if self.expandNode(node, pending) and node.children != None:
   1.101 +                        continue
   1.102 +
   1.103 +                # Otherwise, this is a leaf node. Run a playout.
   1.104 +                score = self.playout(history)
   1.105 +                break
   1.106 +
   1.107 +            # Find the best child.
   1.108 +            node = self.findBestChild(node)
   1.109 +            history.append(node)
   1.110 +            pending.remove(node.loop)
   1.111 +
   1.112 +        # Normalize the score.
   1.113 +        origScore = score
   1.114 +        score = (self.originalBest - score) / self.originalBest
   1.115 +
   1.116 +        for node in history:
   1.117 +            node.visits += 1
   1.118 +            node.score += score
   1.119 +
   1.120 +        if int(origScore) < int(self.bestTime):
   1.121 +            print('New best score: {0:f}ms'.format(origScore))
   1.122 +            self.combos = [history]
   1.123 +            self.bestTime = origScore
   1.124 +        elif int(origScore) == int(self.bestTime):
   1.125 +            self.combos.append(history)
   1.126 +
   1.127 +    def run(self):
   1.128 +        loopList = [i for i in range(0, len(self.loops))]
   1.129 +        self.numNodes = 1
   1.130 +        self.root = UCTNode(-1)
   1.131 +        self.expandNode(self.root, loopList)
   1.132 +
   1.133 +        for i in range(0, self.numPlayouts):
   1.134 +            self.step(loopList)
   1.135 +
   1.136 +        # Build the expected combination vector.
   1.137 +        combos = [ ]
   1.138 +        for combo in self.combos:
   1.139 +            vec = [ ]
   1.140 +            for i in range(0, len(self.loops)):
   1.141 +                vec.append(int(self.enableLoops))
   1.142 +            for node in combo:
   1.143 +                vec[node.loop] = int(not self.enableLoops)
   1.144 +            combos.append(vec)
   1.145 +
   1.146 +        return [self.bestTime, combos]
   1.147 +
   1.148 +class Benchmark:
   1.149 +    def __init__(self, JS, fname):
   1.150 +        self.fname = fname
   1.151 +        self.JS = JS
   1.152 +        self.stats = { }
   1.153 +        self.runList = [ ]
   1.154 +
   1.155 +    def run(self, fd, eargs):
   1.156 +        args = [self.JS]
   1.157 +        args.extend(eargs)
   1.158 +        args.append(fd.name)
   1.159 +        return subprocess.check_output(args).decode()
   1.160 +
   1.161 +    #    self.stats[name] = { }
   1.162 +    #    self.runList.append(name)
   1.163 +    #    for line in output.split('\n'):
   1.164 +    #        m = re.search('line (\d+): (\d+)', line)
   1.165 +    #        if m:
   1.166 +    #            self.stats[name][int(m.group(1))] = int(m.group(2))
   1.167 +    #        else:
   1.168 +    #            m = re.search('total: (\d+)', line)
   1.169 +    #            if m:
   1.170 +    #                self.stats[name]['total'] = m.group(1)
   1.171 +
   1.172 +    def winnerForLine(self, line):
   1.173 +        best = self.runList[0]
   1.174 +        bestTime = self.stats[best][line]
   1.175 +        for run in self.runList[1:]:
   1.176 +            x = self.stats[run][line]
   1.177 +            if x < bestTime:
   1.178 +                best = run
   1.179 +                bestTime = x
   1.180 +        return best
   1.181 +
   1.182 +    def chart(self):
   1.183 +        sys.stdout.write('{0:7s}'.format(''))
   1.184 +        sys.stdout.write('{0:15s}'.format('line'))
   1.185 +        for run in self.runList:
   1.186 +            sys.stdout.write('{0:15s}'.format(run))
   1.187 +        sys.stdout.write('{0:15s}\n'.format('best'))
   1.188 +        for c in self.counters:
   1.189 +            sys.stdout.write('{0:10d}'.format(c))
   1.190 +            for run in self.runList:
   1.191 +                sys.stdout.write('{0:15d}'.format(self.stats[run][c]))
   1.192 +            sys.stdout.write('{0:12s}'.format(''))
   1.193 +            sys.stdout.write('{0:15s}'.format(self.winnerForLine(c)))
   1.194 +            sys.stdout.write('\n')
   1.195 +
   1.196 +    def preprocess(self, lines, onBegin, onEnd):
   1.197 +        stack = []
   1.198 +        counters = []
   1.199 +        rd = open(self.fname, 'rt')
   1.200 +        for line in rd:
   1.201 +            if re.search('\/\* BEGIN LOOP \*\/', line):
   1.202 +                stack.append([len(lines), len(counters)])
   1.203 +                counters.append([len(lines), 0])
   1.204 +                onBegin(lines, len(lines))
   1.205 +            elif re.search('\/\* END LOOP \*\/', line):
   1.206 +                old = stack.pop()
   1.207 +                onEnd(lines, old[0], len(lines))
   1.208 +                counters[old[1]][1] = len(lines)
   1.209 +            else:
   1.210 +                lines.append(line)
   1.211 +        return [lines, counters]
   1.212 +
   1.213 +    def treeSearchRun(self, fd, args, count = 5):
   1.214 +        total = 0
   1.215 +        for i in range(0, count):
   1.216 +            output = self.run(fd, args)
   1.217 +            total += int(output)
   1.218 +        return total / count
   1.219 +
   1.220 +    def generateBanList(self, counters, queue):
   1.221 +        if os.path.exists('/tmp/permabans'):
   1.222 +            os.unlink('/tmp/permabans')
   1.223 +        fd = open('/tmp/permabans', 'wt')
   1.224 +        for i in range(0, len(counters)):
   1.225 +            for j in range(counters[i][0], counters[i][1] + 1):
   1.226 +                fd.write('{0:d} {1:d}\n'.format(j, int(queue[i])))
   1.227 +        fd.close()
   1.228 +
   1.229 +    def internalExhaustiveSearch(self, params):
   1.230 +        counters = params['counters']
   1.231 +
   1.232 +        # iterative algorithm to explore every combination
   1.233 +        ncombos = 2 ** len(counters)
   1.234 +        queue = []
   1.235 +        for c in counters:
   1.236 +            queue.append(0)
   1.237 +
   1.238 +        fd = params['fd']
   1.239 +        bestTime = float('Infinity')
   1.240 +        bestCombos = []
   1.241 +
   1.242 +        i = 0
   1.243 +        while i < ncombos:
   1.244 +            temp = i
   1.245 +            for j in range(0, len(counters)):
   1.246 +                queue[j] = temp & 1
   1.247 +                temp = temp >> 1
   1.248 +            self.generateBanList(counters, queue)
   1.249 +
   1.250 +            t = self.treeSearchRun(fd, ['-m', '-j'])
   1.251 +            if (t < bestTime):
   1.252 +                bestTime = t
   1.253 +                bestCombos = [queue[:]]
   1.254 +                print('New best time: {0:f}ms'.format(t))
   1.255 +            elif int(t) == int(bestTime):
   1.256 +                bestCombos.append(queue[:])
   1.257 +
   1.258 +            i = i + 1
   1.259 +
   1.260 +        return [bestTime, bestCombos]
   1.261 +
   1.262 +    def internalTreeSearch(self, params):
   1.263 +        fd = params['fd']
   1.264 +        methodTime = params['methodTime']
   1.265 +        tracerTime = params['tracerTime']
   1.266 +        combinedTime = params['combinedTime']
   1.267 +        counters = params['counters']
   1.268 +
   1.269 +        # Build the initial loop data.
   1.270 +        # If the method JIT already wins, disable tracing by default.
   1.271 +        # Otherwise, enable tracing by default.
   1.272 +        if methodTime < combinedTime:
   1.273 +            enableLoops = True
   1.274 +        else:
   1.275 +            enableLoops = False
   1.276 +
   1.277 +        enableLoops = False
   1.278 +
   1.279 +        uct = UCT(self, combinedTime, enableLoops, counters[:], fd, 50000)
   1.280 +        return uct.run()
   1.281 +
   1.282 +    def treeSearch(self):
   1.283 +        fd, counters = self.ppForTreeSearch()
   1.284 +
   1.285 +        os.system("cat " + fd.name + " > /tmp/k.js")
   1.286 +
   1.287 +        if os.path.exists('/tmp/permabans'):
   1.288 +            os.unlink('/tmp/permabans')
   1.289 +        methodTime = self.treeSearchRun(fd, ['-m'])
   1.290 +        tracerTime = self.treeSearchRun(fd, ['-j'])
   1.291 +        combinedTime = self.treeSearchRun(fd, ['-m', '-j'])
   1.292 +
   1.293 +        #Get a rough estimate of how long this benchmark will take to fully compute.
   1.294 +        upperBound = max(methodTime, tracerTime, combinedTime)
   1.295 +        upperBound *= 2 ** len(counters)
   1.296 +        upperBound *= 5    # Number of runs
   1.297 +        treeSearch = False
   1.298 +        if (upperBound < 1000):
   1.299 +            print('Estimating {0:d}ms to test, so picking exhaustive '.format(int(upperBound)) +
   1.300 +                  'search.')
   1.301 +        else:
   1.302 +            upperBound = int(upperBound / 1000)
   1.303 +            delta = datetime.timedelta(seconds = upperBound)
   1.304 +            if upperBound < 180:
   1.305 +                print('Estimating {0:d}s to test, so picking exhaustive '.format(int(upperBound)))
   1.306 +            else:
   1.307 +                print('Estimating {0:s} to test, so picking tree search '.format(str(delta)))
   1.308 +                treeSearch = True
   1.309 +
   1.310 +        best = min(methodTime, tracerTime, combinedTime)
   1.311 +
   1.312 +        params = {
   1.313 +                    'fd': fd,
   1.314 +                    'counters': counters,
   1.315 +                    'methodTime': methodTime,
   1.316 +                    'tracerTime': tracerTime,
   1.317 +                    'combinedTime': combinedTime
   1.318 +                 }
   1.319 +
   1.320 +        print('Method JIT:  {0:d}ms'.format(int(methodTime)))
   1.321 +        print('Tracing JIT: {0:d}ms'.format(int(tracerTime)))
   1.322 +        print('Combined:    {0:d}ms'.format(int(combinedTime)))
   1.323 +
   1.324 +        if 1 and treeSearch:
   1.325 +            results = self.internalTreeSearch(params)
   1.326 +        else:
   1.327 +            results = self.internalExhaustiveSearch(params)
   1.328 +
   1.329 +        bestTime = results[0]
   1.330 +        bestCombos = results[1]
   1.331 +        print('Search found winning time {0:d}ms!'.format(int(bestTime)))
   1.332 +        print('Combos at this time: {0:d}'.format(len(bestCombos)))
   1.333 +
   1.334 +        #Find loops that traced every single time
   1.335 +        for i in range(0, len(counters)):
   1.336 +            start = counters[i][0]
   1.337 +            end = counters[i][1]
   1.338 +            n = len(bestCombos)
   1.339 +            for j in bestCombos:
   1.340 +                n -= j[i]
   1.341 +            print('\tloop @ {0:d}-{1:d} traced {2:d}% of the time'.format(
   1.342 +                    start, end, int(n / len(bestCombos) * 100)))
   1.343 +
   1.344 +    def ppForTreeSearch(self):
   1.345 +        def onBegin(lines, lineno):
   1.346 +            lines.append('GLOBAL_THINGY = 1;\n')
   1.347 +        def onEnd(lines, old, lineno):
   1.348 +            lines.append('GLOBAL_THINGY = 1;\n')
   1.349 +
   1.350 +        lines = ['var JINT_START_TIME = Date.now();\n',
   1.351 +                 'var GLOBAL_THINGY = 0;\n']
   1.352 +
   1.353 +        lines, counters = self.preprocess(lines, onBegin, onEnd)
   1.354 +        fd = tempfile.NamedTemporaryFile('wt')
   1.355 +        for line in lines:
   1.356 +            fd.write(line)
   1.357 +        fd.write('print(Date.now() - JINT_START_TIME);\n')
   1.358 +        fd.flush()
   1.359 +        return [fd, counters]
   1.360 +
   1.361 +    def preprocessForLoopCounting(self):
   1.362 +        def onBegin(lines, lineno):
   1.363 +            lines.append('JINT_TRACKER.line_' + str(lineno) + '_start = Date.now();\n')
   1.364 +
   1.365 +        def onEnd(lines, old, lineno):
   1.366 +            lines.append('JINT_TRACKER.line_' + str(old) + '_end = Date.now();\n')
   1.367 +            lines.append('JINT_TRACKER.line_' + str(old) + '_total += ' + \
   1.368 +                         'JINT_TRACKER.line_' + str(old) + '_end - ' + \
   1.369 +                         'JINT_TRACKER.line_' + str(old) + '_start;\n')
   1.370 +
   1.371 +        lines, counters = self.preprocess(onBegin, onEnd)
   1.372 +        fd = tempfile.NamedTemporaryFile('wt')
   1.373 +        fd.write('var JINT_TRACKER = { };\n')
   1.374 +        for c in counters:
   1.375 +            fd.write('JINT_TRACKER.line_' + str(c) + '_start = 0;\n')
   1.376 +            fd.write('JINT_TRACKER.line_' + str(c) + '_end = 0;\n')
   1.377 +            fd.write('JINT_TRACKER.line_' + str(c) + '_total = 0;\n')
   1.378 +        fd.write('JINT_TRACKER.begin = Date.now();\n')
   1.379 +        for line in lines:
   1.380 +            fd.write(line)
   1.381 +        fd.write('JINT_TRACKER.total = Date.now() - JINT_TRACKER.begin;\n')
   1.382 +        for c in self.counters:
   1.383 +            fd.write('print("line ' + str(c) + ': " + JINT_TRACKER.line_' + str(c) +
   1.384 +                           '_total);')
   1.385 +        fd.write('print("total: " + JINT_TRACKER.total);')
   1.386 +        fd.flush()
   1.387 +        return fd
   1.388 +
   1.389 +if __name__ == '__main__':
   1.390 +    script_path = os.path.abspath(__file__)
   1.391 +    script_dir = os.path.dirname(script_path)
   1.392 +    test_dir = os.path.join(script_dir, 'tests')
   1.393 +    lib_dir = os.path.join(script_dir, 'lib')
   1.394 +
   1.395 +    # The [TESTS] optional arguments are paths of test files relative
   1.396 +    # to the jit-test/tests directory.
   1.397 +
   1.398 +    from optparse import OptionParser
   1.399 +    op = OptionParser(usage='%prog [options] JS_SHELL test')
   1.400 +    (OPTIONS, args) = op.parse_args()
   1.401 +    if len(args) < 2:
   1.402 +        op.error('missing JS_SHELL and test argument')
   1.403 +    # We need to make sure we are using backslashes on Windows.
   1.404 +    JS = realpath(args[0])
   1.405 +    test = realpath(args[1])
   1.406 +
   1.407 +    bm = Benchmark(JS, test)
   1.408 +    bm.treeSearch()
   1.409 +    # bm.preprocess()
   1.410 +    # bm.run('mjit', ['-m'])
   1.411 +    # bm.run('tjit', ['-j'])
   1.412 +    # bm.run('m+tjit', ['-m', '-j'])
   1.413 +    # bm.chart()
   1.414 +

mercurial