1/*
2 * Copyright (C) 2011 Apple Inc. All rights reserved.
3 *
4 * Redistribution and use in source and binary forms, with or without
5 * modification, are permitted provided that the following conditions
6 * are met:
7 *
8 * 1.  Redistributions of source code must retain the above copyright
9 *     notice, this list of conditions and the following disclaimer.
10 * 2.  Redistributions in binary form must reproduce the above copyright
11 *     notice, this list of conditions and the following disclaimer in the
12 *     documentation and/or other materials provided with the distribution.
13 *
14 * THIS SOFTWARE IS PROVIDED BY APPLE AND ITS CONTRIBUTORS "AS IS" AND ANY
15 * EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
16 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
17 * DISCLAIMED. IN NO EVENT SHALL APPLE OR ITS CONTRIBUTORS BE LIABLE FOR ANY
18 * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
19 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
20 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
21 * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
22 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF
23 * THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
24 */
25
26#include "config.h"
27#include "core/dom/SelectorQuery.h"
28
29#include "bindings/v8/ExceptionState.h"
30#include "core/css/CSSParser.h"
31#include "core/css/CSSSelectorList.h"
32#include "core/css/SelectorChecker.h"
33#include "core/css/SelectorCheckerFastPath.h"
34#include "core/css/SiblingTraversalStrategies.h"
35#include "core/dom/Document.h"
36#include "core/dom/NodeTraversal.h"
37#include "core/dom/StaticNodeList.h"
38
39namespace WebCore {
40
41class SimpleNodeList {
42public:
43    virtual ~SimpleNodeList() { }
44    virtual bool isEmpty() const = 0;
45    virtual Node* next() = 0;
46};
47
48class SingleNodeList : public SimpleNodeList {
49public:
50    explicit SingleNodeList(Node* rootNode) : m_currentNode(rootNode) { }
51
52    bool isEmpty() const { return !m_currentNode; }
53
54    Node* next()
55    {
56        Node* current = m_currentNode;
57        m_currentNode = 0;
58        return current;
59    }
60
61private:
62    Node* m_currentNode;
63};
64
65class ClassRootNodeList : public SimpleNodeList {
66public:
67    explicit ClassRootNodeList(Node* rootNode, const AtomicString& className)
68        : m_className(className)
69        , m_rootNode(rootNode)
70        , m_currentElement(nextInternal(ElementTraversal::firstWithin(rootNode))) { }
71
72    bool isEmpty() const { return !m_currentElement; }
73
74    Node* next()
75    {
76        Node* current = m_currentElement;
77        ASSERT(current);
78        m_currentElement = nextInternal(ElementTraversal::nextSkippingChildren(m_currentElement, m_rootNode));
79        return current;
80    }
81
82private:
83    Element* nextInternal(Element* element)
84    {
85        for (; element; element = ElementTraversal::next(element, m_rootNode)) {
86            if (element->hasClass() && element->classNames().contains(m_className))
87                return element;
88        }
89        return 0;
90    }
91
92    const AtomicString& m_className;
93    Node* m_rootNode;
94    Element* m_currentElement;
95};
96
97class ClassElementList : public SimpleNodeList {
98public:
99    explicit ClassElementList(Node* rootNode, const AtomicString& className)
100        : m_className(className)
101        , m_rootNode(rootNode)
102        , m_currentElement(nextInternal(ElementTraversal::firstWithin(rootNode))) { }
103
104    bool isEmpty() const { return !m_currentElement; }
105
106    Node* next()
107    {
108        Node* current = m_currentElement;
109        ASSERT(current);
110        m_currentElement = nextInternal(ElementTraversal::next(m_currentElement, m_rootNode));
111        return current;
112    }
113
114private:
115    Element* nextInternal(Element* element)
116    {
117        for (; element; element = ElementTraversal::next(element, m_rootNode)) {
118            if (element->hasClass() && element->classNames().contains(m_className))
119                return element;
120        }
121        return 0;
122    }
123
124    const AtomicString& m_className;
125    Node* m_rootNode;
126    Element* m_currentElement;
127};
128
129void SelectorDataList::initialize(const CSSSelectorList& selectorList)
130{
131    ASSERT(m_selectors.isEmpty());
132
133    unsigned selectorCount = 0;
134    for (const CSSSelector* selector = selectorList.first(); selector; selector = CSSSelectorList::next(selector))
135        selectorCount++;
136
137    m_selectors.reserveInitialCapacity(selectorCount);
138    for (const CSSSelector* selector = selectorList.first(); selector; selector = CSSSelectorList::next(selector))
139        m_selectors.uncheckedAppend(SelectorData(selector, SelectorCheckerFastPath::canUse(selector)));
140}
141
142inline bool SelectorDataList::selectorMatches(const SelectorData& selectorData, Element* element, const Node* rootNode) const
143{
144    if (selectorData.isFastCheckable && !element->isSVGElement()) {
145        SelectorCheckerFastPath selectorCheckerFastPath(selectorData.selector, element);
146        if (!selectorCheckerFastPath.matchesRightmostSelector(SelectorChecker::VisitedMatchDisabled))
147            return false;
148        return selectorCheckerFastPath.matches();
149    }
150
151    SelectorChecker selectorChecker(element->document(), SelectorChecker::QueryingRules);
152    SelectorChecker::SelectorCheckingContext selectorCheckingContext(selectorData.selector, element, SelectorChecker::VisitedMatchDisabled);
153    selectorCheckingContext.behaviorAtBoundary = SelectorChecker::StaysWithinTreeScope;
154    selectorCheckingContext.scope = !rootNode->isDocumentNode() && rootNode->isContainerNode() ? toContainerNode(rootNode) : 0;
155    PseudoId ignoreDynamicPseudo = NOPSEUDO;
156    return selectorChecker.match(selectorCheckingContext, ignoreDynamicPseudo, DOMSiblingTraversalStrategy()) == SelectorChecker::SelectorMatches;
157}
158
159bool SelectorDataList::matches(Element* targetElement) const
160{
161    ASSERT(targetElement);
162
163    unsigned selectorCount = m_selectors.size();
164    for (unsigned i = 0; i < selectorCount; ++i) {
165        if (selectorMatches(m_selectors[i], targetElement, targetElement))
166            return true;
167    }
168
169    return false;
170}
171
172PassRefPtr<NodeList> SelectorDataList::queryAll(Node* rootNode) const
173{
174    Vector<RefPtr<Node> > result;
175    executeQueryAll(rootNode, result);
176    return StaticNodeList::adopt(result);
177}
178
179PassRefPtr<Element> SelectorDataList::queryFirst(Node* rootNode) const
180{
181    return executeQueryFirst(rootNode);
182}
183
184static inline bool isTreeScopeRoot(Node* node)
185{
186    ASSERT(node);
187    return node->isDocumentNode() || node->isShadowRoot();
188}
189
190void SelectorDataList::collectElementsByClassName(Node* rootNode, const AtomicString& className, Vector<RefPtr<Node> >& traversalRoots) const
191{
192    for (Element* element = ElementTraversal::firstWithin(rootNode); element; element = ElementTraversal::next(element, rootNode)) {
193        if (element->hasClass() && element->classNames().contains(className))
194            traversalRoots.append(element);
195    }
196}
197
198void SelectorDataList::collectElementsByTagName(Node* rootNode, const QualifiedName& tagName, Vector<RefPtr<Node> >& traversalRoots) const
199{
200    for (Element* element = ElementTraversal::firstWithin(rootNode); element; element = ElementTraversal::next(element, rootNode)) {
201        if (SelectorChecker::tagMatches(element, tagName))
202            traversalRoots.append(element);
203    }
204}
205
206Element* SelectorDataList::findElementByClassName(Node* rootNode, const AtomicString& className) const
207{
208    for (Element* element = ElementTraversal::firstWithin(rootNode); element; element = ElementTraversal::next(element, rootNode)) {
209        if (element->hasClass() && element->classNames().contains(className))
210            return element;
211    }
212    return 0;
213}
214
215Element* SelectorDataList::findElementByTagName(Node* rootNode, const QualifiedName& tagName) const
216{
217    for (Element* element = ElementTraversal::firstWithin(rootNode); element; element = ElementTraversal::next(element, rootNode)) {
218        if (SelectorChecker::tagMatches(element, tagName))
219            return element;
220    }
221    return 0;
222}
223
224inline bool SelectorDataList::canUseFastQuery(Node* rootNode) const
225{
226    return m_selectors.size() == 1 && rootNode->inDocument() && !rootNode->document()->inQuirksMode();
227}
228
229// If returns true, traversalRoots has the elements that may match the selector query.
230//
231// If returns false, traversalRoots has the rootNode parameter or descendants of rootNode representing
232// the subtree for which we can limit the querySelector traversal.
233//
234// The travseralRoots may be empty, regardless of the returned bool value, if this method finds that the selectors won't
235// match any element.
236PassOwnPtr<SimpleNodeList> SelectorDataList::findTraverseRoots(Node* rootNode, bool& matchTraverseRoots) const
237{
238    // We need to return the matches in document order. To use id lookup while there is possiblity of multiple matches
239    // we would need to sort the results. For now, just traverse the document in that case.
240    ASSERT(rootNode);
241    ASSERT(m_selectors.size() == 1);
242    ASSERT(m_selectors[0].selector);
243
244    bool isRightmostSelector = true;
245    bool startFromParent = false;
246
247    for (const CSSSelector* selector = m_selectors[0].selector; selector; selector = selector->tagHistory()) {
248        if (selector->m_match == CSSSelector::Id && !rootNode->document()->containsMultipleElementsWithId(selector->value())) {
249            Element* element = rootNode->treeScope()->getElementById(selector->value());
250            if (element && (isTreeScopeRoot(rootNode) || element->isDescendantOf(rootNode)))
251                rootNode = element;
252            else if (!element || isRightmostSelector)
253                rootNode = 0;
254            if (isRightmostSelector) {
255                matchTraverseRoots = true;
256                return adoptPtr(new SingleNodeList(rootNode));
257            }
258            if (startFromParent && rootNode)
259                rootNode = rootNode->parentNode();
260
261            matchTraverseRoots = false;
262            return adoptPtr(new SingleNodeList(rootNode));
263        }
264
265        // If we have both CSSSelector::Id and CSSSelector::Class at the same time, we should use Id
266        // to find traverse root.
267        if (!startFromParent && selector->m_match == CSSSelector::Class) {
268            if (isRightmostSelector) {
269                matchTraverseRoots = true;
270                return adoptPtr(new ClassElementList(rootNode, selector->value()));
271            }
272            matchTraverseRoots = false;
273            return adoptPtr(new ClassRootNodeList(rootNode, selector->value()));
274        }
275
276        if (selector->relation() == CSSSelector::SubSelector)
277            continue;
278        isRightmostSelector = false;
279        if (selector->relation() == CSSSelector::DirectAdjacent || selector->relation() == CSSSelector::IndirectAdjacent)
280            startFromParent = true;
281        else
282            startFromParent = false;
283    }
284
285    matchTraverseRoots = false;
286    return adoptPtr(new SingleNodeList(rootNode));
287}
288
289void SelectorDataList::executeSlowQueryAll(Node* rootNode, Vector<RefPtr<Node> >& matchedElements) const
290{
291    for (Element* element = ElementTraversal::firstWithin(rootNode); element; element = ElementTraversal::next(element, rootNode)) {
292        for (unsigned i = 0; i < m_selectors.size(); ++i) {
293            if (selectorMatches(m_selectors[i], element, rootNode)) {
294                matchedElements.append(element);
295                break;
296            }
297        }
298    }
299}
300
301void SelectorDataList::executeQueryAll(Node* rootNode, Vector<RefPtr<Node> >& matchedElements) const
302{
303    if (!canUseFastQuery(rootNode))
304        return executeSlowQueryAll(rootNode, matchedElements);
305
306    ASSERT(m_selectors.size() == 1);
307    ASSERT(m_selectors[0].selector);
308
309    const CSSSelector* firstSelector = m_selectors[0].selector;
310
311    if (!firstSelector->tagHistory()) {
312        // Fast path for querySelectorAll('#id'), querySelectorAl('.foo'), and querySelectorAll('div').
313        switch (firstSelector->m_match) {
314        case CSSSelector::Id:
315            {
316                if (rootNode->document()->containsMultipleElementsWithId(firstSelector->value()))
317                    break;
318
319                // Just the same as getElementById.
320                Element* element = rootNode->treeScope()->getElementById(firstSelector->value());
321                if (element && (isTreeScopeRoot(rootNode) || element->isDescendantOf(rootNode)))
322                    matchedElements.append(element);
323                return;
324            }
325        case CSSSelector::Class:
326            return collectElementsByClassName(rootNode, firstSelector->value(), matchedElements);
327        case CSSSelector::Tag:
328            return collectElementsByTagName(rootNode, firstSelector->tagQName(), matchedElements);
329        default:
330            break; // If we need another fast path, add here.
331        }
332    }
333
334    bool matchTraverseRoots;
335    OwnPtr<SimpleNodeList> traverseRoots = findTraverseRoots(rootNode, matchTraverseRoots);
336    if (traverseRoots->isEmpty())
337        return;
338
339    const SelectorData& selector = m_selectors[0];
340    if (matchTraverseRoots) {
341        while (!traverseRoots->isEmpty()) {
342            Node* node = traverseRoots->next();
343            Element* element = toElement(node);
344            if (selectorMatches(selector, element, rootNode))
345                matchedElements.append(element);
346        }
347        return;
348    }
349
350    while (!traverseRoots->isEmpty()) {
351        Node* traverseRoot = traverseRoots->next();
352        for (Element* element = ElementTraversal::firstWithin(traverseRoot); element; element = ElementTraversal::next(element, traverseRoot)) {
353            if (selectorMatches(selector, element, rootNode))
354                matchedElements.append(element);
355        }
356    }
357}
358
359// If matchTraverseRoot is true, the returned Node is the single Element that may match the selector query.
360//
361// If matchTraverseRoot is false, the returned Node is the rootNode parameter or a descendant of rootNode representing
362// the subtree for which we can limit the querySelector traversal.
363//
364// The returned Node may be 0, regardless of matchTraverseRoot, if this method finds that the selectors won't
365// match any element.
366Node* SelectorDataList::findTraverseRoot(Node* rootNode, bool& matchTraverseRoot) const
367{
368    // We need to return the matches in document order. To use id lookup while there is possiblity of multiple matches
369    // we would need to sort the results. For now, just traverse the document in that case.
370    ASSERT(rootNode);
371    ASSERT(m_selectors.size() == 1);
372    ASSERT(m_selectors[0].selector);
373
374    bool matchSingleNode = true;
375    bool startFromParent = false;
376    for (const CSSSelector* selector = m_selectors[0].selector; selector; selector = selector->tagHistory()) {
377        if (selector->m_match == CSSSelector::Id && !rootNode->document()->containsMultipleElementsWithId(selector->value())) {
378            Element* element = rootNode->treeScope()->getElementById(selector->value());
379            if (element && (isTreeScopeRoot(rootNode) || element->isDescendantOf(rootNode)))
380                rootNode = element;
381            else if (!element || matchSingleNode)
382                rootNode = 0;
383            if (matchSingleNode) {
384                matchTraverseRoot = true;
385                return rootNode;
386            }
387            if (startFromParent && rootNode)
388                rootNode = rootNode->parentNode();
389            matchTraverseRoot = false;
390            return rootNode;
391        }
392        if (selector->relation() == CSSSelector::SubSelector)
393            continue;
394        matchSingleNode = false;
395        if (selector->relation() == CSSSelector::DirectAdjacent || selector->relation() == CSSSelector::IndirectAdjacent)
396            startFromParent = true;
397        else
398            startFromParent = false;
399    }
400    matchTraverseRoot = false;
401    return rootNode;
402}
403
404Element* SelectorDataList::executeSlowQueryFirst(Node* rootNode) const
405{
406    for (Element* element = ElementTraversal::firstWithin(rootNode); element; element = ElementTraversal::next(element, rootNode)) {
407        for (unsigned i = 0; i < m_selectors.size(); ++i) {
408            if (selectorMatches(m_selectors[i], element, rootNode))
409                return element;
410        }
411    }
412    return 0;
413}
414
415Element* SelectorDataList::executeQueryFirst(Node* rootNode) const
416{
417    if (!canUseFastQuery(rootNode))
418        return executeSlowQueryFirst(rootNode);
419
420
421    const CSSSelector* selector = m_selectors[0].selector;
422    ASSERT(selector);
423
424    if (!selector->tagHistory()) {
425        // Fast path for querySelector('#id'), querySelector('.foo'), and querySelector('div').
426        // Many web developers uses querySelector with these simple selectors.
427        switch (selector->m_match) {
428        case CSSSelector::Id:
429            {
430                if (rootNode->document()->containsMultipleElementsWithId(selector->value()))
431                    break;
432                Element* element = rootNode->treeScope()->getElementById(selector->value());
433                return element && (isTreeScopeRoot(rootNode) || element->isDescendantOf(rootNode)) ? element : 0;
434            }
435        case CSSSelector::Class:
436            return findElementByClassName(rootNode, selector->value());
437        case CSSSelector::Tag:
438            return findElementByTagName(rootNode, selector->tagQName());
439        default:
440            break; // If we need another fast path, add here.
441        }
442    }
443
444    bool matchTraverseRoot;
445    Node* traverseRootNode = findTraverseRoot(rootNode, matchTraverseRoot);
446    if (!traverseRootNode)
447        return 0;
448    if (matchTraverseRoot) {
449        ASSERT(m_selectors.size() == 1);
450        ASSERT(traverseRootNode->isElementNode());
451        Element* element = toElement(traverseRootNode);
452        return selectorMatches(m_selectors[0], element, rootNode) ? element : 0;
453    }
454
455    for (Element* element = ElementTraversal::firstWithin(traverseRootNode); element; element = ElementTraversal::next(element, traverseRootNode)) {
456        if (selectorMatches(m_selectors[0], element, rootNode))
457            return element;
458    }
459    return 0;
460}
461
462SelectorQuery::SelectorQuery(const CSSSelectorList& selectorList)
463    : m_selectorList(selectorList)
464{
465    m_selectors.initialize(m_selectorList);
466}
467
468bool SelectorQuery::matches(Element* element) const
469{
470    return m_selectors.matches(element);
471}
472
473PassRefPtr<NodeList> SelectorQuery::queryAll(Node* rootNode) const
474{
475    return m_selectors.queryAll(rootNode);
476}
477
478PassRefPtr<Element> SelectorQuery::queryFirst(Node* rootNode) const
479{
480    return m_selectors.queryFirst(rootNode);
481}
482
483SelectorQuery* SelectorQueryCache::add(const AtomicString& selectors, Document* document, ExceptionState& es)
484{
485    HashMap<AtomicString, OwnPtr<SelectorQuery> >::iterator it = m_entries.find(selectors);
486    if (it != m_entries.end())
487        return it->value.get();
488
489    CSSParser parser(document);
490    CSSSelectorList selectorList;
491    parser.parseSelector(selectors, selectorList);
492
493    if (!selectorList.first() || selectorList.hasInvalidSelector()) {
494        es.throwDOMException(SyntaxError);
495        return 0;
496    }
497
498    // throw a NamespaceError if the selector includes any namespace prefixes.
499    if (selectorList.selectorsNeedNamespaceResolution()) {
500        es.throwDOMException(NamespaceError);
501        return 0;
502    }
503
504    const int maximumSelectorQueryCacheSize = 256;
505    if (m_entries.size() == maximumSelectorQueryCacheSize)
506        m_entries.remove(m_entries.begin());
507
508    OwnPtr<SelectorQuery> selectorQuery = adoptPtr(new SelectorQuery(selectorList));
509    SelectorQuery* rawSelectorQuery = selectorQuery.get();
510    m_entries.add(selectors, selectorQuery.release());
511    return rawSelectorQuery;
512}
513
514void SelectorQueryCache::invalidate()
515{
516    m_entries.clear();
517}
518
519}
520