michael@0: // michael@0: // Copyright (c) 2002-2010 The ANGLE Project Authors. All rights reserved. michael@0: // Use of this source code is governed by a BSD-style license that can be michael@0: // found in the LICENSE file. michael@0: // michael@0: michael@0: #include "compiler/ValidateLimitations.h" michael@0: #include "compiler/InfoSink.h" michael@0: #include "compiler/InitializeParseContext.h" michael@0: #include "compiler/ParseHelper.h" michael@0: michael@0: namespace { michael@0: bool IsLoopIndex(const TIntermSymbol* symbol, const TLoopStack& stack) { michael@0: for (TLoopStack::const_iterator i = stack.begin(); i != stack.end(); ++i) { michael@0: if (i->index.id == symbol->getId()) michael@0: return true; michael@0: } michael@0: return false; michael@0: } michael@0: michael@0: void MarkLoopForUnroll(const TIntermSymbol* symbol, TLoopStack& stack) { michael@0: for (TLoopStack::iterator i = stack.begin(); i != stack.end(); ++i) { michael@0: if (i->index.id == symbol->getId()) { michael@0: ASSERT(i->loop != NULL); michael@0: i->loop->setUnrollFlag(true); michael@0: return; michael@0: } michael@0: } michael@0: UNREACHABLE(); michael@0: } michael@0: michael@0: // Traverses a node to check if it represents a constant index expression. michael@0: // Definition: michael@0: // constant-index-expressions are a superset of constant-expressions. michael@0: // Constant-index-expressions can include loop indices as defined in michael@0: // GLSL ES 1.0 spec, Appendix A, section 4. michael@0: // The following are constant-index-expressions: michael@0: // - Constant expressions michael@0: // - Loop indices as defined in section 4 michael@0: // - Expressions composed of both of the above michael@0: class ValidateConstIndexExpr : public TIntermTraverser { michael@0: public: michael@0: ValidateConstIndexExpr(const TLoopStack& stack) michael@0: : mValid(true), mLoopStack(stack) {} michael@0: michael@0: // Returns true if the parsed node represents a constant index expression. michael@0: bool isValid() const { return mValid; } michael@0: michael@0: virtual void visitSymbol(TIntermSymbol* symbol) { michael@0: // Only constants and loop indices are allowed in a michael@0: // constant index expression. michael@0: if (mValid) { michael@0: mValid = (symbol->getQualifier() == EvqConst) || michael@0: IsLoopIndex(symbol, mLoopStack); michael@0: } michael@0: } michael@0: michael@0: private: michael@0: bool mValid; michael@0: const TLoopStack& mLoopStack; michael@0: }; michael@0: michael@0: // Traverses a node to check if it uses a loop index. michael@0: // If an int loop index is used in its body as a sampler array index, michael@0: // mark the loop for unroll. michael@0: class ValidateLoopIndexExpr : public TIntermTraverser { michael@0: public: michael@0: ValidateLoopIndexExpr(TLoopStack& stack) michael@0: : mUsesFloatLoopIndex(false), michael@0: mUsesIntLoopIndex(false), michael@0: mLoopStack(stack) {} michael@0: michael@0: bool usesFloatLoopIndex() const { return mUsesFloatLoopIndex; } michael@0: bool usesIntLoopIndex() const { return mUsesIntLoopIndex; } michael@0: michael@0: virtual void visitSymbol(TIntermSymbol* symbol) { michael@0: if (IsLoopIndex(symbol, mLoopStack)) { michael@0: switch (symbol->getBasicType()) { michael@0: case EbtFloat: michael@0: mUsesFloatLoopIndex = true; michael@0: break; michael@0: case EbtInt: michael@0: mUsesIntLoopIndex = true; michael@0: MarkLoopForUnroll(symbol, mLoopStack); michael@0: break; michael@0: default: michael@0: UNREACHABLE(); michael@0: } michael@0: } michael@0: } michael@0: michael@0: private: michael@0: bool mUsesFloatLoopIndex; michael@0: bool mUsesIntLoopIndex; michael@0: TLoopStack& mLoopStack; michael@0: }; michael@0: } // namespace michael@0: michael@0: ValidateLimitations::ValidateLimitations(ShShaderType shaderType, michael@0: TInfoSinkBase& sink) michael@0: : mShaderType(shaderType), michael@0: mSink(sink), michael@0: mNumErrors(0) michael@0: { michael@0: } michael@0: michael@0: bool ValidateLimitations::visitBinary(Visit, TIntermBinary* node) michael@0: { michael@0: // Check if loop index is modified in the loop body. michael@0: validateOperation(node, node->getLeft()); michael@0: michael@0: // Check indexing. michael@0: switch (node->getOp()) { michael@0: case EOpIndexDirect: michael@0: validateIndexing(node); michael@0: break; michael@0: case EOpIndexIndirect: michael@0: #if defined(__APPLE__) michael@0: // Loop unrolling is a work-around for a Mac Cg compiler bug where it michael@0: // crashes when a sampler array's index is also the loop index. michael@0: // Once Apple fixes this bug, we should remove the code in this CL. michael@0: // See http://codereview.appspot.com/4331048/. michael@0: if ((node->getLeft() != NULL) && (node->getRight() != NULL) && michael@0: (node->getLeft()->getAsSymbolNode())) { michael@0: TIntermSymbol* symbol = node->getLeft()->getAsSymbolNode(); michael@0: if (IsSampler(symbol->getBasicType()) && symbol->isArray()) { michael@0: ValidateLoopIndexExpr validate(mLoopStack); michael@0: node->getRight()->traverse(&validate); michael@0: if (validate.usesFloatLoopIndex()) { michael@0: error(node->getLine(), michael@0: "sampler array index is float loop index", michael@0: "for"); michael@0: } michael@0: } michael@0: } michael@0: #endif michael@0: validateIndexing(node); michael@0: break; michael@0: default: break; michael@0: } michael@0: return true; michael@0: } michael@0: michael@0: bool ValidateLimitations::visitUnary(Visit, TIntermUnary* node) michael@0: { michael@0: // Check if loop index is modified in the loop body. michael@0: validateOperation(node, node->getOperand()); michael@0: michael@0: return true; michael@0: } michael@0: michael@0: bool ValidateLimitations::visitAggregate(Visit, TIntermAggregate* node) michael@0: { michael@0: switch (node->getOp()) { michael@0: case EOpFunctionCall: michael@0: validateFunctionCall(node); michael@0: break; michael@0: default: michael@0: break; michael@0: } michael@0: return true; michael@0: } michael@0: michael@0: bool ValidateLimitations::visitLoop(Visit, TIntermLoop* node) michael@0: { michael@0: if (!validateLoopType(node)) michael@0: return false; michael@0: michael@0: TLoopInfo info; michael@0: memset(&info, 0, sizeof(TLoopInfo)); michael@0: info.loop = node; michael@0: if (!validateForLoopHeader(node, &info)) michael@0: return false; michael@0: michael@0: TIntermNode* body = node->getBody(); michael@0: if (body != NULL) { michael@0: mLoopStack.push_back(info); michael@0: body->traverse(this); michael@0: mLoopStack.pop_back(); michael@0: } michael@0: michael@0: // The loop is fully processed - no need to visit children. michael@0: return false; michael@0: } michael@0: michael@0: void ValidateLimitations::error(TSourceLoc loc, michael@0: const char *reason, const char* token) michael@0: { michael@0: mSink.prefix(EPrefixError); michael@0: mSink.location(loc); michael@0: mSink << "'" << token << "' : " << reason << "\n"; michael@0: ++mNumErrors; michael@0: } michael@0: michael@0: bool ValidateLimitations::withinLoopBody() const michael@0: { michael@0: return !mLoopStack.empty(); michael@0: } michael@0: michael@0: bool ValidateLimitations::isLoopIndex(const TIntermSymbol* symbol) const michael@0: { michael@0: return IsLoopIndex(symbol, mLoopStack); michael@0: } michael@0: michael@0: bool ValidateLimitations::validateLoopType(TIntermLoop* node) { michael@0: TLoopType type = node->getType(); michael@0: if (type == ELoopFor) michael@0: return true; michael@0: michael@0: // Reject while and do-while loops. michael@0: error(node->getLine(), michael@0: "This type of loop is not allowed", michael@0: type == ELoopWhile ? "while" : "do"); michael@0: return false; michael@0: } michael@0: michael@0: bool ValidateLimitations::validateForLoopHeader(TIntermLoop* node, michael@0: TLoopInfo* info) michael@0: { michael@0: ASSERT(node->getType() == ELoopFor); michael@0: michael@0: // michael@0: // The for statement has the form: michael@0: // for ( init-declaration ; condition ; expression ) statement michael@0: // michael@0: if (!validateForLoopInit(node, info)) michael@0: return false; michael@0: if (!validateForLoopCond(node, info)) michael@0: return false; michael@0: if (!validateForLoopExpr(node, info)) michael@0: return false; michael@0: michael@0: return true; michael@0: } michael@0: michael@0: bool ValidateLimitations::validateForLoopInit(TIntermLoop* node, michael@0: TLoopInfo* info) michael@0: { michael@0: TIntermNode* init = node->getInit(); michael@0: if (init == NULL) { michael@0: error(node->getLine(), "Missing init declaration", "for"); michael@0: return false; michael@0: } michael@0: michael@0: // michael@0: // init-declaration has the form: michael@0: // type-specifier identifier = constant-expression michael@0: // michael@0: TIntermAggregate* decl = init->getAsAggregate(); michael@0: if ((decl == NULL) || (decl->getOp() != EOpDeclaration)) { michael@0: error(init->getLine(), "Invalid init declaration", "for"); michael@0: return false; michael@0: } michael@0: // To keep things simple do not allow declaration list. michael@0: TIntermSequence& declSeq = decl->getSequence(); michael@0: if (declSeq.size() != 1) { michael@0: error(decl->getLine(), "Invalid init declaration", "for"); michael@0: return false; michael@0: } michael@0: TIntermBinary* declInit = declSeq[0]->getAsBinaryNode(); michael@0: if ((declInit == NULL) || (declInit->getOp() != EOpInitialize)) { michael@0: error(decl->getLine(), "Invalid init declaration", "for"); michael@0: return false; michael@0: } michael@0: TIntermSymbol* symbol = declInit->getLeft()->getAsSymbolNode(); michael@0: if (symbol == NULL) { michael@0: error(declInit->getLine(), "Invalid init declaration", "for"); michael@0: return false; michael@0: } michael@0: // The loop index has type int or float. michael@0: TBasicType type = symbol->getBasicType(); michael@0: if ((type != EbtInt) && (type != EbtFloat)) { michael@0: error(symbol->getLine(), michael@0: "Invalid type for loop index", getBasicString(type)); michael@0: return false; michael@0: } michael@0: // The loop index is initialized with constant expression. michael@0: if (!isConstExpr(declInit->getRight())) { michael@0: error(declInit->getLine(), michael@0: "Loop index cannot be initialized with non-constant expression", michael@0: symbol->getSymbol().c_str()); michael@0: return false; michael@0: } michael@0: michael@0: info->index.id = symbol->getId(); michael@0: return true; michael@0: } michael@0: michael@0: bool ValidateLimitations::validateForLoopCond(TIntermLoop* node, michael@0: TLoopInfo* info) michael@0: { michael@0: TIntermNode* cond = node->getCondition(); michael@0: if (cond == NULL) { michael@0: error(node->getLine(), "Missing condition", "for"); michael@0: return false; michael@0: } michael@0: // michael@0: // condition has the form: michael@0: // loop_index relational_operator constant_expression michael@0: // michael@0: TIntermBinary* binOp = cond->getAsBinaryNode(); michael@0: if (binOp == NULL) { michael@0: error(node->getLine(), "Invalid condition", "for"); michael@0: return false; michael@0: } michael@0: // Loop index should be to the left of relational operator. michael@0: TIntermSymbol* symbol = binOp->getLeft()->getAsSymbolNode(); michael@0: if (symbol == NULL) { michael@0: error(binOp->getLine(), "Invalid condition", "for"); michael@0: return false; michael@0: } michael@0: if (symbol->getId() != info->index.id) { michael@0: error(symbol->getLine(), michael@0: "Expected loop index", symbol->getSymbol().c_str()); michael@0: return false; michael@0: } michael@0: // Relational operator is one of: > >= < <= == or !=. michael@0: switch (binOp->getOp()) { michael@0: case EOpEqual: michael@0: case EOpNotEqual: michael@0: case EOpLessThan: michael@0: case EOpGreaterThan: michael@0: case EOpLessThanEqual: michael@0: case EOpGreaterThanEqual: michael@0: break; michael@0: default: michael@0: error(binOp->getLine(), michael@0: "Invalid relational operator", michael@0: getOperatorString(binOp->getOp())); michael@0: break; michael@0: } michael@0: // Loop index must be compared with a constant. michael@0: if (!isConstExpr(binOp->getRight())) { michael@0: error(binOp->getLine(), michael@0: "Loop index cannot be compared with non-constant expression", michael@0: symbol->getSymbol().c_str()); michael@0: return false; michael@0: } michael@0: michael@0: return true; michael@0: } michael@0: michael@0: bool ValidateLimitations::validateForLoopExpr(TIntermLoop* node, michael@0: TLoopInfo* info) michael@0: { michael@0: TIntermNode* expr = node->getExpression(); michael@0: if (expr == NULL) { michael@0: error(node->getLine(), "Missing expression", "for"); michael@0: return false; michael@0: } michael@0: michael@0: // for expression has one of the following forms: michael@0: // loop_index++ michael@0: // loop_index-- michael@0: // loop_index += constant_expression michael@0: // loop_index -= constant_expression michael@0: // ++loop_index michael@0: // --loop_index michael@0: // The last two forms are not specified in the spec, but I am assuming michael@0: // its an oversight. michael@0: TIntermUnary* unOp = expr->getAsUnaryNode(); michael@0: TIntermBinary* binOp = unOp ? NULL : expr->getAsBinaryNode(); michael@0: michael@0: TOperator op = EOpNull; michael@0: TIntermSymbol* symbol = NULL; michael@0: if (unOp != NULL) { michael@0: op = unOp->getOp(); michael@0: symbol = unOp->getOperand()->getAsSymbolNode(); michael@0: } else if (binOp != NULL) { michael@0: op = binOp->getOp(); michael@0: symbol = binOp->getLeft()->getAsSymbolNode(); michael@0: } michael@0: michael@0: // The operand must be loop index. michael@0: if (symbol == NULL) { michael@0: error(expr->getLine(), "Invalid expression", "for"); michael@0: return false; michael@0: } michael@0: if (symbol->getId() != info->index.id) { michael@0: error(symbol->getLine(), michael@0: "Expected loop index", symbol->getSymbol().c_str()); michael@0: return false; michael@0: } michael@0: michael@0: // The operator is one of: ++ -- += -=. michael@0: switch (op) { michael@0: case EOpPostIncrement: michael@0: case EOpPostDecrement: michael@0: case EOpPreIncrement: michael@0: case EOpPreDecrement: michael@0: ASSERT((unOp != NULL) && (binOp == NULL)); michael@0: break; michael@0: case EOpAddAssign: michael@0: case EOpSubAssign: michael@0: ASSERT((unOp == NULL) && (binOp != NULL)); michael@0: break; michael@0: default: michael@0: error(expr->getLine(), "Invalid operator", getOperatorString(op)); michael@0: return false; michael@0: } michael@0: michael@0: // Loop index must be incremented/decremented with a constant. michael@0: if (binOp != NULL) { michael@0: if (!isConstExpr(binOp->getRight())) { michael@0: error(binOp->getLine(), michael@0: "Loop index cannot be modified by non-constant expression", michael@0: symbol->getSymbol().c_str()); michael@0: return false; michael@0: } michael@0: } michael@0: michael@0: return true; michael@0: } michael@0: michael@0: bool ValidateLimitations::validateFunctionCall(TIntermAggregate* node) michael@0: { michael@0: ASSERT(node->getOp() == EOpFunctionCall); michael@0: michael@0: // If not within loop body, there is nothing to check. michael@0: if (!withinLoopBody()) michael@0: return true; michael@0: michael@0: // List of param indices for which loop indices are used as argument. michael@0: typedef std::vector ParamIndex; michael@0: ParamIndex pIndex; michael@0: TIntermSequence& params = node->getSequence(); michael@0: for (TIntermSequence::size_type i = 0; i < params.size(); ++i) { michael@0: TIntermSymbol* symbol = params[i]->getAsSymbolNode(); michael@0: if (symbol && isLoopIndex(symbol)) michael@0: pIndex.push_back(i); michael@0: } michael@0: // If none of the loop indices are used as arguments, michael@0: // there is nothing to check. michael@0: if (pIndex.empty()) michael@0: return true; michael@0: michael@0: bool valid = true; michael@0: TSymbolTable& symbolTable = GetGlobalParseContext()->symbolTable; michael@0: TSymbol* symbol = symbolTable.find(node->getName()); michael@0: ASSERT(symbol && symbol->isFunction()); michael@0: TFunction* function = static_cast(symbol); michael@0: for (ParamIndex::const_iterator i = pIndex.begin(); michael@0: i != pIndex.end(); ++i) { michael@0: const TParameter& param = function->getParam(*i); michael@0: TQualifier qual = param.type->getQualifier(); michael@0: if ((qual == EvqOut) || (qual == EvqInOut)) { michael@0: error(params[*i]->getLine(), michael@0: "Loop index cannot be used as argument to a function out or inout parameter", michael@0: params[*i]->getAsSymbolNode()->getSymbol().c_str()); michael@0: valid = false; michael@0: } michael@0: } michael@0: michael@0: return valid; michael@0: } michael@0: michael@0: bool ValidateLimitations::validateOperation(TIntermOperator* node, michael@0: TIntermNode* operand) { michael@0: // Check if loop index is modified in the loop body. michael@0: if (!withinLoopBody() || !node->modifiesState()) michael@0: return true; michael@0: michael@0: const TIntermSymbol* symbol = operand->getAsSymbolNode(); michael@0: if (symbol && isLoopIndex(symbol)) { michael@0: error(node->getLine(), michael@0: "Loop index cannot be statically assigned to within the body of the loop", michael@0: symbol->getSymbol().c_str()); michael@0: } michael@0: return true; michael@0: } michael@0: michael@0: bool ValidateLimitations::isConstExpr(TIntermNode* node) michael@0: { michael@0: ASSERT(node != NULL); michael@0: return node->getAsConstantUnion() != NULL; michael@0: } michael@0: michael@0: bool ValidateLimitations::isConstIndexExpr(TIntermNode* node) michael@0: { michael@0: ASSERT(node != NULL); michael@0: michael@0: ValidateConstIndexExpr validate(mLoopStack); michael@0: node->traverse(&validate); michael@0: return validate.isValid(); michael@0: } michael@0: michael@0: bool ValidateLimitations::validateIndexing(TIntermBinary* node) michael@0: { michael@0: ASSERT((node->getOp() == EOpIndexDirect) || michael@0: (node->getOp() == EOpIndexIndirect)); michael@0: michael@0: bool valid = true; michael@0: TIntermTyped* index = node->getRight(); michael@0: // The index expression must have integral type. michael@0: if (!index->isScalar() || (index->getBasicType() != EbtInt)) { michael@0: error(index->getLine(), michael@0: "Index expression must have integral type", michael@0: index->getCompleteString().c_str()); michael@0: valid = false; michael@0: } michael@0: // The index expession must be a constant-index-expression unless michael@0: // the operand is a uniform in a vertex shader. michael@0: TIntermTyped* operand = node->getLeft(); michael@0: bool skip = (mShaderType == SH_VERTEX_SHADER) && michael@0: (operand->getQualifier() == EvqUniform); michael@0: if (!skip && !isConstIndexExpr(index)) { michael@0: error(index->getLine(), "Index expression must be constant", "[]"); michael@0: valid = false; michael@0: } michael@0: return valid; michael@0: } michael@0: