@ankitkv · Created May 22, 2014 14:54
A simple library for writing crawlers. Inherit from a crawler class and override its 'parse' method; call self.add_url to add URLs to the crawl queue.
"""Base framework for simple crawlers"""
__author__ = "Ankit Vani"
__version__ = "0.1"
import bs4
import csv
import httplib
import json
import sys
import urllib2
import urlparse
USER_AGENT = "LearnerCrawler/0.1 ([email protected])"
class NoUnvisitedError(Exception): pass
class MaxDepthError(Exception): pass
class RecordLimitError(NoUnvisitedError): pass
def _print_debug(s):
print >> sys.stderr, s
class BaseCrawler(object):
    """Crawler base class; subclasses override parse() and is_valid_url()."""

    def __init__(self, seed=None, retries=2, depth=-1, user_agent=USER_AGENT,
                 debug=False, *args, **kwargs):
        if debug:
            self.debug = _print_debug
        else:
            self.debug = lambda s: None
        self.debug_mode = debug
        self.unvisited = set()
        self.visited = set()
        self.failed = set()
        self.depth = depth
        self.current_depth = 0
        self.retries = retries
        if seed:
            if not isinstance(seed, list):
                seed = [seed]
            for url in list(seed):
                self.add_url(url)
        self.user_agent = user_agent

    def fix_url(self, url, src=None):
        """Resolve url against its source page and strip the fragment."""
        #self.debug('Fixing url: %s' % url)
        url = urlparse.urljoin(src, unicode(url))
        tokens = urlparse.urlsplit(url)
        url = urlparse.urlunsplit(tokens[:4] + ('',))
        if url.endswith('/'):
            url = url[:-1]
        #self.debug('Fixed to: %s' % url)
        return url

    def is_valid_url(self, url):
        """Override to filter which URLs get queued for crawling."""
        return True

    def add_url(self, url, src=None, unvisited=()):
        """Queue a URL for crawling unless it was already seen or queued."""
        if url:
            url = self.fix_url(url, src)
            if self.is_valid_url(url) and url not in self.visited \
                    and url not in unvisited:
                self.debug('Adding to unvisited: %s' % url)
                self.unvisited.add(unicode(url))

    def _add_failed(self, url):
        self.debug('Adding to failed: %s' % url)
        self.failed.add(url)

    def parse(self, url, data, unvisited):
        """Override to parse fetched data; return False to retry the URL."""
        self.debug('Got data for: %s' % url)
        return True

    def is_valid_response(self, response):
        """Determine if a response is valid to be parsed by parse()"""
        return response.headers['Content-Type'].find('text/html') != -1

    def visit(self, url, unvisited):
        """Fetch a URL, retrying on failure, and hand the data to parse()."""
        self.debug('Adding to visited: %s' % url)
        self.visited.add(url)
        data = ''
        retries = self.retries
        while retries > 0:
            retries -= 1
            try:
                req = urllib2.Request(
                    url, headers={'User-Agent': self.user_agent})
                resp = urllib2.urlopen(req)
                if self.is_valid_response(resp):
                    data = resp.read()
                else:
                    self.debug('Not a valid response for crawling')
                resp.close()
            except (ValueError, urllib2.URLError, urllib2.HTTPError,
                    httplib.IncompleteRead):
                pass
            if data:
                ret = self.parse(url, data, unvisited)
                # Anything other than an explicit False counts as success.
                if ret is not False:
                    return
            self.debug('Retrying %s ...' % url)
        # All retries exhausted without a successful parse.
        self._add_failed(url)

    def crawl_next(self):
        """Visit every URL queued at the current depth level."""
        if not self.unvisited:
            raise NoUnvisitedError('finished crawling')
        if self.depth != -1 and self.current_depth >= self.depth:
            raise MaxDepthError('maximum depth reached')
        else:
            self.current_depth += 1
        self.debug('Iterating unvisited ...')
        unvisited = self.unvisited.copy()
        self.unvisited.clear()
        for url in unvisited:
            self.visit(url, unvisited)

    def print_failed(self):
        pass

    def print_report(self):
        pass

    def crawl(self):
        """Crawl until the queue is empty or the maximum depth is reached."""
        try:
            while True:
                self.crawl_next()
        except (NoUnvisitedError, MaxDepthError):
            pass
        self.print_failed()
        self.print_report()


class RecordCrawler(BaseCrawler):
    """Crawler that extracts records and writes them via an output function."""

    def __init__(self, seed=None, output='csv', retries=2, depth=-1,
                 record_limit=100000, *args, **kwargs):
        super(RecordCrawler, self).__init__(
            seed, retries, depth, *args, **kwargs)
        self.record_count = 0
        self.record_limit = record_limit
        if callable(output):
            self.output_func = output
        else:
            # 'csv', 'json' or 'readable' select one of the *_output methods;
            # an optional <name>_init method is called with output_arg first.
            if hasattr(self, '%s_init' % output):
                getattr(self, '%s_init' % output)(
                    kwargs.get('output_arg', None))
            self.output_func = getattr(self, '%s_output' % output)

    def add_record(self, record):
        self.debug('Adding record %d: %s' % (self.record_count + 1, record))
        self.output_func(record)
        self.record_count += 1

    def visit(self, url, unvisited):
        if self.record_count >= self.record_limit:
            raise RecordLimitError('record limit reached')
        else:
            super(RecordCrawler, self).visit(url, unvisited)

    def csv_init(self, keys):
        self.csv_writer = csv.DictWriter(sys.stdout, keys)

    def csv_output(self, record):
        self.csv_writer.writerow({k: (v or '').encode('utf-8')
                                  for k, v in record.items()})

    def json_output(self, record):
        print json.dumps(record)

    def readable_output(self, record):
        for k, v in record.items():
            print "%s %s" % (k.ljust(15), v)
        print


class SitemapRecordCrawler(RecordCrawler):
    """Crawl XML sitemaps and hand each urlset over to an HTML record crawler."""

    def __init__(self, html_crawler, *args, **kwargs):
        super(SitemapRecordCrawler, self).__init__(*args, **kwargs)
        self.html_crawler = html_crawler
        self.html_visited = set()

    def is_valid_response(self, response):
        """Determine if a response is valid to be parsed by parse()"""
        return response.headers['Content-Type'].find('/xml') != -1

    def parse(self, url, data, unvisited):
        super(SitemapRecordCrawler, self).parse(url, data, unvisited)
        soup = bs4.BeautifulSoup(data, 'xml')
        try:
            # Dispatch on the root element: <sitemapindex> or <urlset>.
            top = soup.contents[0].name
            getattr(self, 'parse_%s' % top)(url, data, unvisited, soup)
        except (IndexError, TypeError):
            return False
        return True

    def parse_sitemapindex(self, url, data, unvisited, soup):
        # A sitemap index just lists further sitemaps to queue.
        for link in soup.find_all('loc'):
            self.add_url(link.string.strip(), url, unvisited)

    def parse_urlset(self, url, data, unvisited, soup):
        # A urlset lists page URLs; crawl them with the HTML crawler, sharing
        # the visited/failed sets and the record count with this crawler.
        seed = [link.string.strip() for link in soup.find_all('loc')]
        crawler = self.html_crawler(seed=seed, output=self.output_func,
                                    retries=self.retries,
                                    record_limit=self.record_limit,
                                    debug=self.debug_mode)
        crawler.visited = self.html_visited
        crawler.failed = self.failed
        crawler.record_count = self.record_count
        crawler.crawl()
        # Carry the child's count back so the limit and report stay accurate.
        self.record_count = crawler.record_count

    def print_failed(self):
        if self.failed:
            print >> sys.stderr, "Failed URLs:"
            for url in self.failed:
                print >> sys.stderr, " %s" % url

    def print_report(self):
        print >> sys.stderr, 'Added %d results.' % self.record_count
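
A minimal usage sketch of the pattern the description mentions, assuming the module above is saved as crawlerlib.py; the LinkCrawler class, the module name, and the seed URL are made-up examples, not part of the gist:

# Example: a toy crawler that just follows links on each page it visits.
import bs4

from crawlerlib import BaseCrawler  # hypothetical module name for the gist above


class LinkCrawler(BaseCrawler):

    def parse(self, url, data, unvisited):
        soup = bs4.BeautifulSoup(data, 'html.parser')
        for a in soup.find_all('a', href=True):
            # add_url() resolves the link against the current page and skips
            # URLs that are already visited or queued.
            self.add_url(a['href'], src=url, unvisited=unvisited)
        return True  # returning False would make visit() retry this URL


if __name__ == '__main__':
    crawler = LinkCrawler(seed='http://example.com', depth=2, debug=True)
    crawler.crawl()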