michael@0: // michael@0: // Copyright (c) 2002-2013 The ANGLE Project Authors. All rights reserved. michael@0: // Use of this source code is governed by a BSD-style license that can be michael@0: // found in the LICENSE file. michael@0: // michael@0: michael@0: #include "compiler/ForLoopUnroll.h" michael@0: michael@0: namespace { michael@0: michael@0: class IntegerForLoopUnrollMarker : public TIntermTraverser { michael@0: public: michael@0: michael@0: virtual bool visitLoop(Visit, TIntermLoop* node) michael@0: { michael@0: // This is called after ValidateLimitations pass, so all the ASSERT michael@0: // should never fail. michael@0: // See ValidateLimitations::validateForLoopInit(). michael@0: ASSERT(node); michael@0: ASSERT(node->getType() == ELoopFor); michael@0: ASSERT(node->getInit()); michael@0: TIntermAggregate* decl = node->getInit()->getAsAggregate(); michael@0: ASSERT(decl && decl->getOp() == EOpDeclaration); michael@0: TIntermSequence& declSeq = decl->getSequence(); michael@0: ASSERT(declSeq.size() == 1); michael@0: TIntermBinary* declInit = declSeq[0]->getAsBinaryNode(); michael@0: ASSERT(declInit && declInit->getOp() == EOpInitialize); michael@0: ASSERT(declInit->getLeft()); michael@0: TIntermSymbol* symbol = declInit->getLeft()->getAsSymbolNode(); michael@0: ASSERT(symbol); michael@0: TBasicType type = symbol->getBasicType(); michael@0: ASSERT(type == EbtInt || type == EbtFloat); michael@0: if (type == EbtInt) michael@0: node->setUnrollFlag(true); michael@0: return true; michael@0: } michael@0: michael@0: }; michael@0: michael@0: } // anonymous namepsace michael@0: michael@0: void ForLoopUnroll::FillLoopIndexInfo(TIntermLoop* node, TLoopIndexInfo& info) michael@0: { michael@0: ASSERT(node->getType() == ELoopFor); michael@0: ASSERT(node->getUnrollFlag()); michael@0: michael@0: TIntermNode* init = node->getInit(); michael@0: ASSERT(init != NULL); michael@0: TIntermAggregate* decl = init->getAsAggregate(); michael@0: ASSERT((decl != NULL) && (decl->getOp() == EOpDeclaration)); michael@0: TIntermSequence& declSeq = decl->getSequence(); michael@0: ASSERT(declSeq.size() == 1); michael@0: TIntermBinary* declInit = declSeq[0]->getAsBinaryNode(); michael@0: ASSERT((declInit != NULL) && (declInit->getOp() == EOpInitialize)); michael@0: TIntermSymbol* symbol = declInit->getLeft()->getAsSymbolNode(); michael@0: ASSERT(symbol != NULL); michael@0: ASSERT(symbol->getBasicType() == EbtInt); michael@0: michael@0: info.id = symbol->getId(); michael@0: michael@0: ASSERT(declInit->getRight() != NULL); michael@0: TIntermConstantUnion* initNode = declInit->getRight()->getAsConstantUnion(); michael@0: ASSERT(initNode != NULL); michael@0: michael@0: info.initValue = evaluateIntConstant(initNode); michael@0: info.currentValue = info.initValue; michael@0: michael@0: TIntermNode* cond = node->getCondition(); michael@0: ASSERT(cond != NULL); michael@0: TIntermBinary* binOp = cond->getAsBinaryNode(); michael@0: ASSERT(binOp != NULL); michael@0: ASSERT(binOp->getRight() != NULL); michael@0: ASSERT(binOp->getRight()->getAsConstantUnion() != NULL); michael@0: michael@0: info.incrementValue = getLoopIncrement(node); michael@0: info.stopValue = evaluateIntConstant( michael@0: binOp->getRight()->getAsConstantUnion()); michael@0: info.op = binOp->getOp(); michael@0: } michael@0: michael@0: void ForLoopUnroll::Step() michael@0: { michael@0: ASSERT(mLoopIndexStack.size() > 0); michael@0: TLoopIndexInfo& info = mLoopIndexStack[mLoopIndexStack.size() - 1]; michael@0: info.currentValue += info.incrementValue; michael@0: } michael@0: michael@0: bool ForLoopUnroll::SatisfiesLoopCondition() michael@0: { michael@0: ASSERT(mLoopIndexStack.size() > 0); michael@0: TLoopIndexInfo& info = mLoopIndexStack[mLoopIndexStack.size() - 1]; michael@0: // Relational operator is one of: > >= < <= == or !=. michael@0: switch (info.op) { michael@0: case EOpEqual: michael@0: return (info.currentValue == info.stopValue); michael@0: case EOpNotEqual: michael@0: return (info.currentValue != info.stopValue); michael@0: case EOpLessThan: michael@0: return (info.currentValue < info.stopValue); michael@0: case EOpGreaterThan: michael@0: return (info.currentValue > info.stopValue); michael@0: case EOpLessThanEqual: michael@0: return (info.currentValue <= info.stopValue); michael@0: case EOpGreaterThanEqual: michael@0: return (info.currentValue >= info.stopValue); michael@0: default: michael@0: UNREACHABLE(); michael@0: } michael@0: return false; michael@0: } michael@0: michael@0: bool ForLoopUnroll::NeedsToReplaceSymbolWithValue(TIntermSymbol* symbol) michael@0: { michael@0: for (TVector::iterator i = mLoopIndexStack.begin(); michael@0: i != mLoopIndexStack.end(); michael@0: ++i) { michael@0: if (i->id == symbol->getId()) michael@0: return true; michael@0: } michael@0: return false; michael@0: } michael@0: michael@0: int ForLoopUnroll::GetLoopIndexValue(TIntermSymbol* symbol) michael@0: { michael@0: for (TVector::iterator i = mLoopIndexStack.begin(); michael@0: i != mLoopIndexStack.end(); michael@0: ++i) { michael@0: if (i->id == symbol->getId()) michael@0: return i->currentValue; michael@0: } michael@0: UNREACHABLE(); michael@0: return false; michael@0: } michael@0: michael@0: void ForLoopUnroll::Push(TLoopIndexInfo& info) michael@0: { michael@0: mLoopIndexStack.push_back(info); michael@0: } michael@0: michael@0: void ForLoopUnroll::Pop() michael@0: { michael@0: mLoopIndexStack.pop_back(); michael@0: } michael@0: michael@0: // static michael@0: void ForLoopUnroll::MarkForLoopsWithIntegerIndicesForUnrolling( michael@0: TIntermNode* root) michael@0: { michael@0: ASSERT(root); michael@0: michael@0: IntegerForLoopUnrollMarker marker; michael@0: root->traverse(&marker); michael@0: } michael@0: michael@0: int ForLoopUnroll::getLoopIncrement(TIntermLoop* node) michael@0: { michael@0: TIntermNode* expr = node->getExpression(); michael@0: ASSERT(expr != NULL); michael@0: // for expression has one of the following forms: michael@0: // loop_index++ michael@0: // loop_index-- michael@0: // loop_index += constant_expression michael@0: // loop_index -= constant_expression michael@0: // ++loop_index michael@0: // --loop_index michael@0: // The last two forms are not specified in the spec, but I am assuming michael@0: // its an oversight. michael@0: TIntermUnary* unOp = expr->getAsUnaryNode(); michael@0: TIntermBinary* binOp = unOp ? NULL : expr->getAsBinaryNode(); michael@0: michael@0: TOperator op = EOpNull; michael@0: TIntermConstantUnion* incrementNode = NULL; michael@0: if (unOp != NULL) { michael@0: op = unOp->getOp(); michael@0: } else if (binOp != NULL) { michael@0: op = binOp->getOp(); michael@0: ASSERT(binOp->getRight() != NULL); michael@0: incrementNode = binOp->getRight()->getAsConstantUnion(); michael@0: ASSERT(incrementNode != NULL); michael@0: } michael@0: michael@0: int increment = 0; michael@0: // The operator is one of: ++ -- += -=. michael@0: switch (op) { michael@0: case EOpPostIncrement: michael@0: case EOpPreIncrement: michael@0: ASSERT((unOp != NULL) && (binOp == NULL)); michael@0: increment = 1; michael@0: break; michael@0: case EOpPostDecrement: michael@0: case EOpPreDecrement: michael@0: ASSERT((unOp != NULL) && (binOp == NULL)); michael@0: increment = -1; michael@0: break; michael@0: case EOpAddAssign: michael@0: ASSERT((unOp == NULL) && (binOp != NULL)); michael@0: increment = evaluateIntConstant(incrementNode); michael@0: break; michael@0: case EOpSubAssign: michael@0: ASSERT((unOp == NULL) && (binOp != NULL)); michael@0: increment = - evaluateIntConstant(incrementNode); michael@0: break; michael@0: default: michael@0: ASSERT(false); michael@0: } michael@0: michael@0: return increment; michael@0: } michael@0: michael@0: int ForLoopUnroll::evaluateIntConstant(TIntermConstantUnion* node) michael@0: { michael@0: ASSERT((node != NULL) && (node->getUnionArrayPointer() != NULL)); michael@0: return node->getIConst(0); michael@0: } michael@0: