Skip to content

Instantly share code, notes, and snippets.

@JonathanRaiman
Last active April 29, 2018 07:27
Show Gist options
  • Save JonathanRaiman/032d0854e7b88011c0836b2be9ed60b5 to your computer and use it in GitHub Desktop.
Save JonathanRaiman/032d0854e7b88011c0836b2be9ed60b5 to your computer and use it in GitHub Desktop.
Access Pattern inference
"""
Access Pattern Search
---------------------
Code for simulating the effect of searching for the right access pattern in
a CUDA Kernel computation directed acyclic graph.
The key idea is to have every node in the computation graph return an object
representing "for loops" that can be optionally parallelized using blocks
or threads (followed by syncs).
In the case of blocks we must guarantee
that data being passed from one side of the graph to another is available
in the desired block.
In the case of threads we must ensure that concurrent pieces of work do
not request control of the threads, or else one will be masked by the
thread conditionals of the other.
Our goal is then to give to each "for loop" a type of parallelism (threadIdx,
blockIdx, null (no parallelism), or threadIdxTimesBlockIdx (divide up the work
of the loop among all blocks and threads)). We can pose this as a graph search
problem where each node is a valid set of assignments, and each edge is
an assignment of parallelism to a "for loop".
We are currently using a scoring function/heuristic that is admissible for A* (assumes
the best possible parallelization of the remaining compute pieces).
In this script we show how to generate the right assignments for a fused softmax kernel,
some cumulative sums, and several nested reductions and scans.
@author Jonathan Raiman
"""
import copy
import heapq
import time
import numpy as np
class Node(object):
    """Labelled tree node used only for ASCII rendering.

    Passing ``parent`` registers the new node in ``parent.children``.
    """

    def __init__(self, name, parent=None):
        self.name = name
        self.parent = parent
        self.children = []
        if parent:
            parent.children.append(self)
def nb_children(current_node):
    """Return the number of nodes in the subtree rooted at ``current_node`` (itself included)."""
    return 1 + sum(nb_children(child) for child in current_node.children)
def print_tree(current_node, indent='', last='updown'):
    """Render the subtree under ``current_node`` as a sideways ASCII tree.

    Children are split into an "up" half (printed above the node) and a
    "down" half (printed below) so both halves hold a similar number of
    descendants.

    Args:
        current_node: object with ``name`` (str) and ``children`` (list).
        indent: prefix accumulated from the ancestors.
        last: connector selector -- ``'up'`` for the first child of an up
            half, ``'down'`` for the last child of a down half, ``'updown'``
            for the root and ``''`` otherwise.

    Returns:
        The rendered text, one newline-terminated line per node.
    """
    size_branch = {child: nb_children(child) for child in current_node.children}

    # Balance the halves: move the largest subtrees to "down" until it holds
    # at least as many nodes as "up".  Running totals avoid re-summing both
    # lists on every iteration.
    up = sorted(current_node.children, key=lambda node: size_branch[node])
    down = []
    up_size = sum(size_branch[node] for node in up)
    down_size = 0
    while up and down_size < up_size:
        node = up.pop()
        up_size -= size_branch[node]
        down_size += size_branch[node]
        down.append(node)

    out = ""
    child_pad = ' ' * len(current_node.name)

    # "Up" half.  Positions come from enumerate and are compared with ``==``:
    # the former ``list.index(...) is 0`` relied on int identity (SyntaxWarning
    # on 3.8+, wrong for large ints) and mis-handled repeated children.
    for position, child in enumerate(up):
        next_last = 'up' if position == 0 else ''
        next_indent = '{0}{1}{2}'.format(indent, ' ' if 'up' in last else '│', child_pad)
        out += print_tree(child, next_indent, next_last)

    # The current node itself.
    if last == 'up':
        start_shape = '┌'
    elif last == 'down':
        start_shape = '└'
    elif last == 'updown':
        start_shape = ' '
    else:
        start_shape = '├'
    if up:
        end_shape = '┤'
    elif down:
        end_shape = '┐'
    else:
        end_shape = ''
    out += '{0}{1}{2}{3}\n'.format(indent, start_shape, current_node.name, end_shape)

    # "Down" half: the last entry gets the bottom-corner connector.
    for position, child in enumerate(down):
        next_last = 'down' if position == len(down) - 1 else ''
        next_indent = '{0}{1}{2}'.format(indent, ' ' if 'down' in last else '│', child_pad)
        out += print_tree(child, next_indent, next_last)
    return out
class Index(object):
    """Iteration variable contributed by a ``ForLoop``."""

    def __init__(self, loop=None):
        # Owning ForLoop (None for a free-standing index).
        self._loop = loop

    def __str__(self):
        return "Index()"

    def __repr__(self):
        return self.__str__()
class ForLoopFusionGroup(object):
    """Identity token shared by loops that may be fused together.

    Carries no data: two loops are in the same group exactly when they
    reference the same instance.
    """
# Parallelisation choices a ForLoop can receive.  Each value is also the
# position of its printable name in ``choice2name``.
NULL = 0                    # no parallelism
THREADIDX = 1               # split iterations over threadIdx
BLOCKIDX = 2                # split iterations over blockIdx
BLOCKIDXTIMESTHREADIDX = 3  # split the work over all blocks and threads
choice2name = [
    "null",
    "threadIdx",
    "blockIdx",
    "blockIdxTimesThreadIdx",
]
class ForLoop(object):
    """A loop in the access-pattern tree whose parallelisation is to be chosen.

    Attributes:
        _choices: parallelisation options (NULL, THREADIDX, ...) allowed here.
        _children: loops nested directly inside this one.
        _arg: Expression that created the loop (None for a root loop).
        fusion_group: ForLoopFusionGroup shared with loops it may fuse with.
    """

    def __init__(self, choices):
        self._choices = choices
        self._children = []
        self._arg = None
        self.fusion_group = None

    def __str__(self):
        names = [choice2name[choice] for choice in self._choices]
        return " / ".join(names)

    def __repr__(self):
        return str(self)

    def add_child(self, child):
        """Nest ``child`` directly inside this loop."""
        self._children.append(child)

    def iterator(self):
        """Return the indices this loop contributes (a single Index)."""
        return [Index(loop=self)]

    @staticmethod
    def any(ndim=1):
        """Build a loop allowing every parallelisation choice.

        ``ndim`` is accepted for compatibility but is not used.
        """
        return ForLoop((NULL, THREADIDX, BLOCKIDX, BLOCKIDXTIMESTHREADIDX))

    def name_loop(self, namer):
        """Label: owning expression, allowed choices and fusion-group name."""
        owner = "Root" if self._arg is None else self._arg.name()
        group = "n/a" if self.fusion_group is None else namer(self.fusion_group)
        return "({}) {} [g={}]".format(owner, str(self), group)

    def pretty_print(self):
        """Render this loop and every loop nested under it as ASCII text."""
        namer = CountNamer()
        return pretty_print_tree([self], lambda loop: loop.name_loop(namer))
def connect_loop_chain(loops):
    """Make each loop in ``loops`` a child of the loop before it.

    Links that already exist are not added again, so overlapping chains can
    be connected repeatedly without duplicating edges.
    """
    for outer, inner in zip(loops, loops[1:]):
        if inner not in outer._children:
            outer.add_child(inner)
class SymbolTable(object):
    """Numbers input arrays and records temporaries for repeated subexpressions."""

    def __init__(self):
        self._arrays = []   # distinct arrays; list position is the array id
        self._temps = {}    # node_hash -> Buffer reserved for that node

    def array_index(self, obj):
        """Return a stable integer id for ``obj``, registering it on first use."""
        try:
            return self._arrays.index(obj)
        except ValueError:
            self._arrays.append(obj)
            return len(self._arrays) - 1

    def store_into_temporary(self, node):
        """Reserve a temporary Buffer shaped like ``node``; Buffers are skipped."""
        if isinstance(node, Buffer):
            return
        key = node_hash(node, self)
        if key not in self._temps:
            self._temps[key] = Buffer(node.shape)

    def has_temporary(self, node):
        """True when a temporary was reserved for ``node`` (never for a Buffer)."""
        return not isinstance(node, Buffer) and node_hash(node, self) in self._temps

    def get_temporary(self, node):
        """Return the Buffer reserved for ``node``; KeyError if there is none."""
        return self._temps[node_hash(node, self)]
class Expression(object):
    """Base class for nodes of the lazy array-expression graph.

    Subclasses supply ``child_for_loop(parent_iterator, constraints)``,
    returning one list of new ForLoops and one iterator per argument.
    """

    def __init__(self, shape, arguments):
        self._shape = shape
        self._arguments = arguments

    @property
    def shape(self):
        return self._shape

    @property
    def ndim(self):
        return len(self._shape)

    def arguments(self):
        return self._arguments

    def set_arguments(self, arguments):
        self._arguments = arguments

    def name(self):
        return self.__class__.__name__

    def __repr__(self):
        return self.name()

    def chainable(self):
        """False means the node is first materialised into its own buffer."""
        return True

    def pretty_print(self):
        """Render the expression graph (via ``_arguments``) as ASCII text."""
        return pretty_print_tree([self], str, "_arguments")

    def access_patterns(self, parent_loop, parent_iterator, constraints):
        """Hang this node's loops under ``parent_loop`` and recurse into arguments.

        The loop tree is built by mutation; ``constraints`` is passed through
        so descendants can append index couplings.  Returns ``parent_loop``.
        """
        per_arg_loops, per_arg_iterators = self.child_for_loop(parent_iterator, constraints)
        if per_arg_loops and per_arg_loops[0]:
            print(self.name())  # trace: this node opens new loops
        for argument, new_loops, iterator in zip(self._arguments, per_arg_loops, per_arg_iterators):
            for loop in new_loops:
                loop._arg = self
            chain = [parent_loop] + new_loops
            connect_loop_chain(chain)
            argument.access_patterns(parent_loop=chain[-1], parent_iterator=iterator, constraints=constraints)
        return parent_loop

    def node_data(self, symbol_table):
        """Structural fingerprint of the arrays read by this subtree."""
        return tuple(arg.node_data(symbol_table) for arg in self.arguments())

    def node_type(self):
        """Structural fingerprint of the node types in this subtree."""
        return (type(self), tuple(arg.node_type() for arg in self.arguments()))

    def update_symbol_table(self, symbol_table):
        """Hook for subclasses that reserve temporaries; does nothing here."""
class Buffer(Expression):
    """Leaf of the expression graph: a concrete array of a given shape."""

    def __init__(self, shape):
        # A buffer has no arguments.
        Expression.__init__(self, shape, [])

    def child_for_loop(self, parent_iterator, constraints):
        """A buffer opens no loops and has no arguments to iterate."""
        no_loops, no_iterators = [], []
        return no_loops, no_iterators

    def node_data(self, symbol_table):
        """Identify the buffer by its registration index in ``symbol_table``."""
        return symbol_table.array_index(self)
class AxisReduce(Expression):
    """Reduce the last axis of ``array`` with the named ``functor`` (e.g. "max")."""

    def __init__(self, functor, array):
        self._functor = functor
        self._shape = array.shape[:-1]
        self._arguments = [array]

    def name(self):
        return self._functor

    def child_for_loop(self, parent_iterator, constraints):
        """Open one loop over the reduced axis (sequential or threadIdx only)."""
        reduction = ForLoop((NULL, THREADIDX))
        iterator = parent_iterator + reduction.iterator()
        return [[reduction]], [iterator]
class ExpandDims(Expression):
    """Append a trailing axis of length 1 to ``array`` (shape stored as a list)."""

    def __init__(self, array):
        self._shape = list(array.shape) + [1]
        self._arguments = [array]

    def child_for_loop(self, parent_iterator, constraints):
        """Open no loops; the argument is read without the last index."""
        return [[]], [parent_iterator[:-1]]

    def update_symbol_table(self, symbol_table):
        """Ask for the wrapped argument to be stored in a temporary."""
        symbol_table.store_into_temporary(self._arguments[0])
class ElementWise(Expression):
    """Apply ``functor`` element by element across ``arrays``.

    The output shape is taken from the first operand (presumably all
    operands share it -- no broadcasting logic is visible here).
    """

    def __init__(self, functor, arrays):
        # list(...) rather than arrays.copy(): accepts tuples and other
        # sequences while still taking a private, mutable copy that
        # set_arguments / _wrap_if_not_chainable may replace.
        operands = list(arrays)
        self._functor = functor
        self._shape = operands[0].shape
        self._arguments = operands

    def name(self):
        return self._functor

    def child_for_loop(self, parent_iterator, constraints):
        """Open no loops; every operand is read with the parent's iterator."""
        count = len(self._arguments)
        return [[] for _ in range(count)], [parent_iterator for _ in range(count)]
class AxisScan(Expression):
    """Scan along the last axis of ``array``; the shape is unchanged."""

    def __init__(self, array):
        self._shape = array.shape
        self._arguments = [array]

    def child_for_loop(self, parent_iterator, constraints):
        """Replace the innermost index with a scan loop (sequential or threadIdx)."""
        scan = ForLoop((NULL, THREADIDX))
        child_iterator = parent_iterator[:-1] + scan.iterator()
        return [[scan]], [child_iterator]

    def chainable(self):
        """Scans are not fused into their consumer; they get their own buffer."""
        return False
def collect_loops(arg, axes):
    """Append ``arg`` and every loop nested under it to ``axes`` (DFS pre-order)."""
    axes.append(arg)
    for child in arg._children:
        collect_loops(child, axes)
def is_valid(*, choice, assignment, depth, constraints, verbose=False):
    """Return whether slot ``depth`` may take parallelisation ``choice``.

    Checks, in order:

    1. For each constraint whose child is ``depth``: reject BLOCKIDX if the
       tied parent slot holds THREADIDX or BLOCKIDXTIMESTHREADIDX, and reject
       BLOCKIDXTIMESTHREADIDX if it holds BLOCKIDX or THREADIDX.
    2. Only when ``depth`` is the last slot: for every constraint whose child
       fusion group has a block-based type (BLOCKIDX / BLOCKIDXTIMESTHREADIDX),
       parent and child groups must match in size and type (the pending
       ``choice`` is counted for the child when the child is ``depth``).
    3. Unless ``depth`` is 0 or ``choice`` is NULL: accept if the slot's
       fusion group already uses exactly this one value; otherwise accept
       only choices that do not clash with values on the slot's
       ancestor/descendant chain.

    Args:
        choice: NULL, THREADIDX, BLOCKIDX or BLOCKIDXTIMESTHREADIDX.
        assignment: the ``TreeSlots`` being extended.
        depth: position of the slot being assigned.
        constraints: ``IndexConstraint`` objects holding slot positions.
        verbose: print the reason when the choice is rejected.

    Returns:
        True if the choice is allowed, False otherwise.
    """
    # (1) direct constraints on this slot
    for idx, constraint in enumerate(constraints):
        if constraint.child == depth:
            mismatches = ((choice == BLOCKIDX and
                           assignment._slots[constraint.parent]._value in (THREADIDX, BLOCKIDXTIMESTHREADIDX)) or
                          (choice == BLOCKIDXTIMESTHREADIDX and
                           assignment._slots[constraint.parent]._value in (BLOCKIDX, THREADIDX)))
            if mismatches:
                if verbose:
                    print("Cannot use {} because of a constraint #{} tying this slot to slot with value {}".format(
                        choice, idx, assignment._slots[constraint.parent]._value))
                return False
    # do general cleanup on tied fusions:
    # (2) only on the final slot, when every other slot has a value.
    if depth == len(assignment._slots) - 1:
        for idx, constraint in enumerate(constraints):
            parent_fusion_quantity, parent_fusion_type, parent_remainder = assignment.get_fusion_group(
                constraint.parent)
            if constraint.child == depth:
                # Slot ``depth`` is still unset: count the pending choice.
                child_fusion_quantity, child_fusion_type, child_remainder = assignment.get_fusion_group(
                    constraint.child, choice=choice)
                child_fusion_quantity += 1
                child_remainder = 0
            else:
                child_fusion_quantity, child_fusion_type, child_remainder = assignment.get_fusion_group(
                    constraint.child, choice=parent_fusion_type)
            if child_fusion_type in (BLOCKIDX, BLOCKIDXTIMESTHREADIDX):
                if parent_fusion_quantity != child_fusion_quantity:
                    if verbose:
                        print(("Cannot use {} because of a mismatched fusion size between child with value {} and "
                               "parent with value {} under constraint #{}").format(
                            choice, child_fusion_type, parent_fusion_type, idx))
                    return False
                if parent_fusion_type != child_fusion_type:
                    if verbose:
                        print(("Cannot use {} because of a mismatched fusion type child with value {} and "
                               "parent with value {} under constraint #{}").format(
                            choice, child_fusion_type, parent_fusion_type, idx))
                    return False
    # (3) clashes along the loop chain
    if depth == 0 or choice == NULL:
        return True
    all_parents = assignment.get_parent_assignments(depth=depth)
    all_assignments = assignment.get_assignments(depth=depth)
    group_assignments = assignment.get_group_assignments(depth)
    # Reusing the single value already chosen by the fusion group is allowed.
    if len(group_assignments) == 1 and choice in group_assignments:
        return True
    # blockIdx x threadIdx excludes any other block/thread use on the chain.
    blockIdxTimesthreadIdx_ok = (choice == BLOCKIDXTIMESTHREADIDX and
                                 choice not in all_assignments and
                                 THREADIDX not in all_assignments and
                                 BLOCKIDX not in all_assignments)
    if blockIdxTimesthreadIdx_ok:
        return True
    elif choice == BLOCKIDX and choice not in all_assignments and BLOCKIDXTIMESTHREADIDX not in all_assignments:
        return True
    # threadIdx is only checked against ancestors, not descendants.
    elif choice == THREADIDX and choice not in all_parents and BLOCKIDXTIMESTHREADIDX not in all_assignments:
        return True
    if verbose:
        print(("Cannot use {} because this choice already shows up in a parent (all_assignments={})").format(
            choice, all_assignments))
    return False
class ThreadAssignment(object):
    """Node of the assignment search tree.

    Wraps a ``TreeSlots``; ``_depth`` counts how many slots have been decided.
    """

    def __init__(self, solution):
        self._solution = solution
        self._children = []   # branches recorded by search_branch
        self._depth = 0

    def create_branch(self, choice, depth):
        """Return a new node equal to this one with slot ``depth`` set to ``choice``.

        Only the modified Slot is copied; the others are shared with this node.
        """
        slots = list(self._solution._slots)
        slots[depth] = slots[depth].copy()
        slots[depth]._value = choice
        branch = ThreadAssignment(TreeSlots(slots))
        branch._depth = self._depth + 1
        return branch

    def search_branch(self, choice, depth):
        """Create a branch and record it among this node's children."""
        branch = self.create_branch(choice, depth)
        self._children.append(branch)

    def pretty_print(self):
        """Render the search tree, labelling each node with its latest choice."""
        def label(node):
            if node._depth > 0:
                return node._solution._slots[node._depth - 1]._value
            return ""
        return pretty_print_tree([self], label)

    def test_solution(self, choices, constraints, axes):
        """Replay ``choices`` slot by slot, printing a trace of every step.

        Stops at the first rejected choice; prints the estimated duration
        when every choice is accepted.  Returns None.
        """
        branch = self
        for depth, choice in enumerate(choices):
            if choice not in axes[depth]._choices:
                print("Choice {} is not available at depth = {}".format(choice, depth))
            valid = is_valid(choice=choice,
                             assignment=branch._solution,
                             depth=depth,
                             constraints=constraints,
                             verbose=True)
            if not valid:
                print("could not assign {} with {} to {}".format(choice, choices[:depth], axes[depth]._arg))
                return
            branch = branch.create_branch(choice, depth=depth)
            print([slot._value for slot in branch._solution._slots])
            print("assigned {}".format(choices[:depth + 1]))
        print(estimate_solution_duration(branch._solution))

    def complete(self):
        """True once every slot has been decided."""
        return len(self._solution._slots) == self._depth

    # heapq compares nodes only to break ties between equal priorities;
    # every comparison answers "less than", so no real order is implied.
    def __lt__(self, other):
        return True

    def __gt__(self, other):
        return False
def estimate_parallel_time(*, blockIdx_used, threadIdx_used, leaf_parallel,
                           blockIdx_usage, threadIdx_usage, blockIdxTimesthreadIdx_usage):
    """Return the (log2) amount of work that parallelism is credited with hiding.

    Starts at 1; using blocks or threads adds 8 each.  When both are used,
    every parallel loop adds 0.2, and a parallel leaf loop adds 0.5.
    """
    # 2 ^ parallel_time
    parallel_time = 1
    for used in (blockIdx_used, threadIdx_used):
        if used:
            parallel_time += 8
    if blockIdx_used and threadIdx_used:
        # Added one term at a time to keep the same float rounding.
        for usage in (blockIdx_usage, threadIdx_usage, blockIdxTimesthreadIdx_usage):
            parallel_time += usage * 0.2
    if leaf_parallel:
        parallel_time += 0.5
    return parallel_time
def estimate_solution_duration(solution):
    """Estimate the cost of a (possibly partial) assignment on a log2 scale.

    For every disjoint group of slots (``get_disjoint_assignments``), each
    slot with a value contributes 16 units of work.  Parallelised work is
    reduced by ``estimate_parallel_time`` (floored at 0); the rest is charged
    in full.  Returns the sum over all groups.

    NOTE(review): an unassigned slot (``_value is None``) satisfies
    ``el != NULL``, so it counts as a parallel dimension (and can mark the
    leaf as parallel) without adding to ``total_work``.  A* calls this on
    partial solutions, so this affects the search order -- confirm it is
    intended before changing.
    """
    groups = solution.get_disjoint_assignments()
    total = 0.0
    for group in groups:
        parallel_dimensions = 0
        threadIdx_used = False
        blockIdx_used = False
        blockIdx_usage = 0
        threadIdx_usage = 0
        blockIdxTimesthreadIdx_usage = 0
        leaf_parallel = False
        total_work = 0
        for slot in group:
            el = slot._value
            if el != NULL:
                parallel_dimensions += 1
                if el == THREADIDX or el == BLOCKIDXTIMESTHREADIDX:
                    threadIdx_used = True
                if el == BLOCKIDX or el == BLOCKIDXTIMESTHREADIDX:
                    blockIdx_used = True
                if el == THREADIDX:
                    threadIdx_usage += 1
                elif el == BLOCKIDX:
                    blockIdx_usage += 1
                elif el == BLOCKIDXTIMESTHREADIDX:
                    blockIdxTimesthreadIdx_usage += 1
            if el is not None:
                # logarithm base-2 of total work
                total_work += 16
        parallelized_work = parallel_dimensions * 16
        non_parallelized_work = total_work - parallelized_work
        # The innermost slot of the group is its last entry.
        if group[-1]._value != NULL:
            leaf_parallel = True
        parallel_time = estimate_parallel_time(blockIdx_usage=blockIdx_usage,
                                               threadIdx_usage=threadIdx_usage,
                                               threadIdx_used=threadIdx_used,
                                               blockIdx_used=blockIdx_used,
                                               blockIdxTimesthreadIdx_usage=blockIdxTimesthreadIdx_usage,
                                               leaf_parallel=leaf_parallel)
        # you can effectively divide by threads and blocks the expected work:
        parallelized_work_visible = max(0, parallelized_work - parallel_time)
        total += non_parallelized_work + parallelized_work_visible
    return total
def heuristic_cost_estimate(solution):
    """Optimistic estimate of the cost still owed by unassigned loops.

    Every unassigned slot (``_value is None``) is assumed to receive the best
    possible parallelisation (blockIdx x threadIdx).  This keeps the
    estimate optimistic, as A* requires (see the module docstring).

    Args:
        solution: a ``TreeSlots``, possibly partially assigned.

    Returns:
        float: summed per-group work not hidden by parallelism (never negative).
    """
    total = 0.0
    for group in solution.get_disjoint_assignments():
        unassigned = sum(1 for slot in group if slot._value is None)
        parallelized_work = 16 * unassigned
        parallel_time = estimate_parallel_time(
            blockIdx_usage=0,
            threadIdx_usage=0,
            threadIdx_used=True,
            blockIdx_used=True,
            blockIdxTimesthreadIdx_usage=unassigned,
            # Compare the leaf slot's *value*: the Slot object itself is never
            # None, so ``group[-1] is None`` silently disabled the leaf bonus.
            leaf_parallel=group[-1]._value is None,
        )
        total += max(0, parallelized_work - parallel_time)
    return total
class IndexCouple(object):
    """Ties an index of a parent iterator to the matching child index."""

    def __init__(self, *, parent, child):
        self.parent = parent
        self.child = child

    def __str__(self):
        parent_name = str(self.parent._loop._arg.name())
        child_name = str(self.child._loop._arg.name())
        return "IndexCouple({} >= {})".format(parent_name, child_name)

    def __repr__(self):
        return str(self)
def couple_iterators(*, parent, child):
    """Pair two iterators index by index (extra indices on either side are dropped)."""
    couples = []
    for parent_index, child_index in zip(parent, child):
        couples.append(IndexCouple(parent=parent_index, child=child_index))
    return couples
def allow_loop_fusion(loops):
    """Put every loop in ``loops`` into one newly created, shared fusion group."""
    shared_group = ForLoopFusionGroup()
    for each_loop in loops:
        each_loop.fusion_group = shared_group
def build_ndim_loop_iterator(ndim):
    """Create ``ndim`` fusable loops together with their flattened indices.

    Returns:
        (loops, iterator): ``ndim`` ForLoops allowing every choice and sharing
        one fusion group, plus the concatenation of their ``iterator()`` lists.
    """
    loops = []
    for _ in range(ndim):
        loops.append(ForLoop.any())
    allow_loop_fusion(loops)
    iterator = []
    for each_loop in loops:
        iterator.extend(each_loop.iterator())
    return loops, iterator
class Assignment(Expression):
    """Write the value of ``right`` into the buffer ``left``.

    As the graph root it builds the outermost loop nest.  When nested, it
    opens a fresh loop nest whose indices are tied to the enclosing iterator
    through ``IndexCouple`` constraints.
    """

    def __init__(self, left, right):
        self._arguments = [left, right]
        self._shape = left.shape

    def child_for_loop(self, parent_iterator, constraints):
        target_ndim = len(self._arguments[0].shape)
        loops, loops_iterator = build_ndim_loop_iterator(ndim=target_ndim)
        constraints.extend(couple_iterators(parent=parent_iterator, child=loops_iterator))
        return [loops, loops], [parent_iterator, loops_iterator]

    def name(self):
        return "Assignment({})".format(self._arguments[1].name())

    def access_patterns(self, constraints, parent_loop=None, parent_iterator=None):
        """Build the loop tree for this assignment; returns its outermost loop.

        ``parent_loop`` and ``parent_iterator`` must be both None (root) or
        both given (nested assignment).
        """
        if parent_loop is not None:
            assert parent_iterator is not None, \
                "must have parent_loop and parent_iterator be both None, or both present."
            return Expression.access_patterns(self,
                                              parent_loop=parent_loop,
                                              parent_iterator=parent_iterator,
                                              constraints=constraints)
        # Root: the loop nest follows the rank of the target buffer.
        loops, loops_iterator = build_ndim_loop_iterator(ndim=len(self._arguments[0].shape))
        for each_loop in loops:
            each_loop._arg = self
        connect_loop_chain(loops)
        innermost = loops[-1]
        left, right = self._arguments[0], self._arguments[1]
        left.access_patterns(parent_loop=innermost, parent_iterator=loops_iterator, constraints=constraints)
        right.access_patterns(parent_loop=innermost, parent_iterator=loops_iterator, constraints=constraints)
        return loops[0]
class Slot(object):
    """Assignment state of one loop inside a ``TreeSlots``.

    ``_value`` is the chosen parallelisation (None while undecided).
    ``_fusion_group`` is a group label.  ``_children`` and ``_parents`` hold
    the integer positions of linked slots.
    """

    def __init__(self, value, fusion_group):
        self._value = value
        self._fusion_group = fusion_group
        self._children = []
        self._parents = []

    def add_child(self, child):
        self._children.append(child)

    def add_parent(self, parent):
        self._parents.append(parent)

    def copy(self):
        """Shallow copy: value and group are copied; link lists are shared."""
        clone = Slot(self._value, self._fusion_group)
        clone._children, clone._parents = self._children, self._parents
        return clone

    def parents(self):
        return self._parents

    def children(self):
        return self._children
class TreeSlots(object):
    """Assignment state for a whole loop tree: one ``Slot`` per loop.

    Slot positions follow the ``axes`` list the slots were built from; each
    Slot's ``_children``/``_parents`` hold integer positions into
    ``self._slots``.
    """
    def __init__(self, slots):
        self._slots = slots
    def copy(self):
        # Copies every Slot; Slot.copy still shares the _children/_parents lists.
        return TreeSlots([slot.copy() for slot in self._slots])
    def get_child_assignments(self, depth):
        """Return the non-None values assigned to ``depth`` and all its descendants."""
        values = set()
        if self._slots[depth]._value is not None:
            values.add(self._slots[depth]._value)
        # NOTE(review): unlike get_parent_assignments there is no visited set,
        # so this assumes the child links contain no cycles.
        for child in self._slots[depth]._children:
            for val in self.get_child_assignments(child):
                values.add(val)
        return values
    def get_parent_assignments(self, depth, visited=None):
        """Return the non-None values assigned to ``depth`` and all its ancestors.

        ``visited`` holds positions already explored, so shared ancestors are
        walked only once.
        """
        values = set()
        if visited is None:
            visited = {depth}
        if self._slots[depth]._value is not None:
            values.add(self._slots[depth]._value)
        for parent in self._slots[depth]._parents:
            if parent not in visited:
                visited.add(parent)
                for val in self.get_parent_assignments(parent, visited=visited):
                    values.add(val)
        return values
    def get_parents(self, depth, visited=None):
        # NOTE(review): Slot (as defined in this file) has no ``_arg`` attribute,
        # so this raises AttributeError at the first assigned slot it reaches.
        # No caller is visible in this file -- confirm whether it is dead code.
        values = set()
        if visited is None:
            visited = {depth}
        if self._slots[depth]._value is not None:
            values.add(self._slots[depth]._arg)
        for parent in self._slots[depth]._parents:
            if parent not in visited:
                visited.add(parent)
                for val in self.get_parents(parent, visited=visited):
                    values.add(val)
        return values
    def get_assignments(self, depth):
        """Union of the values assigned to ``depth``, its ancestors and its descendants."""
        return self.get_parent_assignments(depth=depth) | self.get_child_assignments(depth=depth)
    def get_group_assignments(self, depth):
        """Set of non-None values assigned anywhere in ``depth``'s fusion group."""
        fusion_group = self._slots[depth]._fusion_group
        return set([v._value for v in self._slots if v._value is not None and v._fusion_group == fusion_group])
    def get_fusion_group(self, depth, choice=None):
        """Summarise the fusion group that contains slot ``depth``.

        The group's type is the value stored at ``depth``, or ``choice`` when
        that slot is still unassigned.

        Returns:
            (count, fusion_type, remainder): the number of slots in the group
            whose value equals ``fusion_type``, that type, and the number of
            unassigned slots in the group.
        """
        fusion_group = self._slots[depth]._fusion_group
        fusion_type = choice if self._slots[depth]._value is None else self._slots[depth]._value
        remainder = 0
        count = 0
        for v in self._slots:
            if v._fusion_group == fusion_group:
                if v._value is None:
                    remainder += 1
                elif v._value == fusion_type:
                    count += 1
        return count, fusion_type, remainder
    def get_disjoint_assignments(self):
        """Partition the slots into groups connected through child links.

        A slot with no group yet starts a new one, and its children inherit
        that group.  This appears to rely on parents preceding their children
        in ``self._slots`` (pre-order) -- TODO confirm.

        Returns:
            list of groups, each a list of Slot objects in position order.
        """
        groups = []
        group_idx = [None for _ in range(len(self._slots))]
        for depth in range(len(self._slots)):
            if group_idx[depth] is None:
                group_idx[depth] = len(groups)
                groups.append([])
            groups[group_idx[depth]].append(depth)
            for child in self._slots[depth]._children:
                group_idx[child] = group_idx[depth]
        return [[self._slots[idx] for idx in group] for group in groups]
def _convert_to_tree(root, objects, name, childkey):
    """Mirror ``objects`` (and, recursively, their ``childkey`` lists) as Nodes under ``root``.

    All labels at one level are created first; the recursion happens afterwards.
    """
    root.children.extend(Node(name(obj)) for obj in objects)
    for obj, node in zip(objects, root.children):
        _convert_to_tree(node, getattr(obj, childkey), name=name, childkey=childkey)
def pretty_print_tree(objects, name, childkey="_children"):
    """Render the forest spanned by ``objects`` as ASCII text.

    Roots are the members of ``objects`` that are not listed as a child
    (via ``childkey``) of any other member.  ``name`` maps an object to its
    label.
    """
    children = {child for obj in objects for child in getattr(obj, childkey)}
    roots = [obj for obj in objects if obj not in children]
    assert roots, "no root to the objects..."
    root = Node("root")
    _convert_to_tree(root, roots, name=name, childkey=childkey)
    return print_tree(root)
def empty_solution(axes):
    """Build a fully unassigned ``TreeSlots`` mirroring the loop list ``axes``.

    Side effect: every loop without a fusion group receives its own new
    ``ForLoopFusionGroup``.

    Args:
        axes: loops in DFS pre-order, as produced by ``collect_loops``.

    Returns:
        TreeSlots with one ``Slot(None, group_label)`` per loop.  Slot links
        mirror ``ForLoop._children``, except for children created by a
        different Assignment.
    """
    for axis in axes:
        if axis.fusion_group is None:
            axis.fusion_group = ForLoopFusionGroup()

    # Label fusion groups in order of first appearance.  The old
    # ``list(set(...)).index(...)`` produced labels that changed from run to
    # run (identity-hashed objects) and cost O(n) per lookup.  Labels are only
    # compared for equality, so the grouping itself is unchanged.
    group_label = {}
    for axis in axes:
        group_label.setdefault(axis.fusion_group, len(group_label))
    slots = [Slot(None, group_label[axis.fusion_group]) for axis in axes]

    # O(1) replacement for axes.index(child): first occurrence wins, like list.index.
    position = {}
    for idx, axis in enumerate(axes):
        position.setdefault(axis, idx)

    for axis_idx, (axis, slot) in enumerate(zip(axes, slots)):
        for child in axis._children:
            # disconnect assignments in the graph: a child created by a
            # different Assignment is not linked to this loop.
            if child._arg == axis._arg or not isinstance(child._arg, Assignment):
                index = position[child]
                slot.add_child(index)
                slots[index].add_parent(axis_idx)
    return TreeSlots(slots)
class CountNamer(object):
    """Callable that names objects "0", "1", ... in order of first sight."""

    def __init__(self):
        self._visited = {}

    def __call__(self, obj):
        if obj not in self._visited:
            self._visited[obj] = str(len(self._visited))
        return self._visited[obj]
class IndexConstraint(object):
    """Constraint between two slots, identified by their positions in ``axes``."""

    def __init__(self, *, parent, child):
        self.parent = parent  # position of the constraining slot
        self.child = child    # position of the constrained slot
def convert_constraints(constraints, axes):
    """Translate index couplings into constraints between slot positions.

    Args:
        constraints: ``IndexCouple`` objects whose ``parent``/``child`` are
            Index objects carrying a ``_loop``.
        axes: loops in slot order (as produced by ``collect_loops``).

    Returns:
        list of ``IndexConstraint(parent=i, child=j)`` with positions in ``axes``.

    Raises:
        ValueError: if a constrained loop is not present in ``axes``.
    """
    # First occurrence wins, matching list.index semantics, but in O(1).
    position = {}
    for idx, axis in enumerate(axes):
        position.setdefault(axis, idx)

    def locate(loop):
        # list.index raises rather than returning -1, so ``assert i != -1``
        # could never fire; raise a descriptive ValueError instead.
        try:
            return position[loop]
        except KeyError:
            raise ValueError("loop {!r} is not among the collected axes".format(loop)) from None

    return [IndexConstraint(parent=locate(c.parent._loop), child=locate(c.child._loop))
            for c in constraints]
def astar_search_pattern(pattern, constraints):
    """A* search for the cheapest parallelisation of every loop under ``pattern``.

    Slots are decided one at a time in ``collect_loops`` order.  A node's
    priority is ``estimate_solution_duration`` (cost of the partial
    assignment) plus ``heuristic_cost_estimate`` (optimistic remaining cost).

    Args:
        pattern: root ForLoop returned by ``access_patterns``.
        constraints: ``IndexCouple`` list collected while building the loops.

    Returns:
        The first complete ``ThreadAssignment`` popped from the queue, or None
        if no valid assignment exists.
    """
    axes = []
    collect_loops(pattern, axes)
    # Name all loops up front so printed labels follow collect_loops order.
    namer = CountNamer()
    list(map(namer, axes))
    print("len(constraints) = ", len(constraints))
    print(pretty_print_tree(axes, name=namer))
    root = ThreadAssignment(empty_solution(axes))
    slot_constraints = convert_constraints(constraints, axes)
    # Heap of (priority, node); equal priorities fall back to
    # ThreadAssignment.__lt__, which always answers True.
    open_set = [(heuristic_cost_estimate(root._solution), root)]
    # Only its length is used (number of expanded nodes in the report).
    closed_set = []
    while len(open_set) > 0:
        h, current = heapq.heappop(open_set)
        if current.complete():
            msg = "Top solution (explored {} nodes):".format(len(closed_set))
            print(msg)
            print("=" * len(msg))
            print([choice2name[sol._value] for sol in current._solution._slots],
                  estimate_solution_duration(current._solution))
            # for val, el in open_set:
            #     print(val, [choice2name[sol._value] for sol in el._solution._slots if sol._value is not None])
            return current
        closed_set.append(current)
        # for each edge of this solution:
        depth = current._depth
        for choice in axes[depth]._choices:
            if is_valid(choice=choice,
                        assignment=current._solution,
                        depth=depth,
                        constraints=slot_constraints):
                new_branch = current.create_branch(choice, depth)
                # f = g (partial-assignment cost) + h (optimistic remainder)
                heapq.heappush(open_set, (estimate_solution_duration(new_branch._solution) +
                                          heuristic_cost_estimate(new_branch._solution),
                                          new_branch))
    return None
def search_pattern(pattern, constraints):
    """Exhaustively enumerate every valid assignment and return the cheapest.

    Serves as the reference result for ``astar_search_pattern``.  Partial
    solutions are kept on a stack (depth-first) and expanded with every
    valid choice for the next slot.  Equal costs are resolved by ``min``,
    which keeps the first minimum in discovery order.

    Returns:
        The best complete ``ThreadAssignment``.  ``min`` raises ValueError if
        no valid assignment exists.
    """
    axes = []
    collect_loops(pattern, axes)
    print("len(axes) = {}".format(len(axes)))
    namer = CountNamer()
    list(map(namer, axes))
    print(pretty_print_tree(axes, name=namer))
    root = ThreadAssignment(empty_solution(axes))
    slot_constraints = convert_constraints(constraints, axes)
    partial_solutions = [root]
    solutions = []
    while len(partial_solutions) > 0:
        # pop() from the end -> depth-first traversal
        search_tree = partial_solutions.pop()
        depth = search_tree._depth
        if depth < len(axes):
            for choice in axes[depth]._choices:
                if is_valid(choice=choice,
                            assignment=search_tree._solution,
                            depth=depth,
                            constraints=slot_constraints):
                    search_tree.search_branch(choice, depth=depth)
            partial_solutions.extend(search_tree._children)
        else:
            solutions.append(search_tree)
    # print("Solution Space:")
    # print("===============")
    # print(root.pretty_print())
    solution = min(solutions, key=lambda x: estimate_solution_duration(x._solution))
    msg = "Top solution among {}:".format(len(solutions))
    print(msg)
    print("=" * len(msg))
    print([choice2name[sol._value] for sol in solution._solution._slots],
          estimate_solution_duration(solution._solution))
    return solution
def exp(x):
    """Element-wise exponential of ``x``."""
    operands = [x]
    return ElementWise("functor::exp", operands)
def sub(a, b):
    """Element-wise ``a - b``."""
    operands = [a, b]
    return ElementWise("functor::subtract", operands)
def div(a, b):
    """Element-wise ``a / b``."""
    operands = [a, b]
    return ElementWise("functor::div", operands)
def axis_max(x, keepdims=False):
    """Max over the last axis of ``x``; re-append a unit axis when ``keepdims``."""
    reduced = AxisReduce("max", x)
    return ExpandDims(reduced) if keepdims else reduced
def axis_sum(x, keepdims=False):
    """Sum over the last axis of ``x``; re-append a unit axis when ``keepdims``."""
    reduced = AxisReduce("sum", x)
    return ExpandDims(reduced) if keepdims else reduced
def cast(x, dtype):
    """Element-wise conversion of ``x`` to ``dtype`` (functor named after ``dtype.__name__``)."""
    functor = "functor::cast_" + dtype.__name__
    return ElementWise(functor, [x])
def softmax(x):
    """Softmax over the last axis, built as a fused expression graph.

    The per-row maximum is subtracted before exponentiating.
    """
    shifted = sub(x, axis_max(x, True))
    numerator = exp(shifted)
    return div(numerator, axis_sum(numerator, True))
def _wrap_if_not_chainable(x):
    """Recursively wrap every non-chainable argument as ``Assignment(Buffer, arg)``.

    Assignments themselves are never wrapped; only their two operands are
    visited.  Mutates the graph in place and returns ``x``.
    """
    if isinstance(x, Assignment):
        left, right = x._arguments[0], x._arguments[1]
        _wrap_if_not_chainable(left)
        _wrap_if_not_chainable(right)
        return x
    for child in x._arguments:
        _wrap_if_not_chainable(child)
    wrapped = []
    for child in x._arguments:
        if child.chainable():
            wrapped.append(child)
        else:
            wrapped.append(Assignment(Buffer(child.shape), child))
    x._arguments = wrapped
    return x
def wrap_assign(x):
    """Return ``Assignment(Buffer(x.shape), x)`` with non-chainable nodes materialised."""
    return _wrap_if_not_chainable(Assignment(Buffer(x.shape), x))
def node_hash(x, symbol_table):
    """Structural key of ``x``: (data fingerprint, type fingerprint)."""
    data_key = x.node_data(symbol_table)
    type_key = x.node_type()
    return data_key, type_key
def subexpression_elimination(node, symbol_table, node_occurences=None):
    """Count structurally identical subexpressions and reserve temporaries for repeats.

    Walks the whole graph in pre-order.  From its second occurrence on, a
    node is passed to ``symbol_table.store_into_temporary``.
    """
    if node_occurences is None:
        node_occurences = {}
    key = node_hash(node, symbol_table)
    node_occurences[key] = node_occurences.get(key, 0) + 1
    # make this assign somewhere...
    if node_occurences[key] > 1:
        symbol_table.store_into_temporary(node)
    for arg in node.arguments():
        subexpression_elimination(arg, symbol_table, node_occurences)
def update_symbol_table(node, symbol_table):
    """Run each node's ``update_symbol_table`` hook in pre-order (node before its arguments)."""
    node.update_symbol_table(symbol_table)
    for child in node.arguments():
        update_symbol_table(child, symbol_table)
def replace_with_temporaries(node, symbol_table, visited=None, computed_temp=None):
    """Replace arguments that have a reserved temporary by an Assignment into it.

    Post-order walk.  Each argument that has a temporary in ``symbol_table``
    becomes ``Assignment(temp_buffer, arg)``; the same Assignment object is
    reused for every argument with an equal ``node_hash``.

    Args:
        node: expression to rewrite (mutated via ``set_arguments``).
        symbol_table: SymbolTable populated by ``subexpression_elimination``.
        visited: nodes already processed; they are skipped (and printed).
        computed_temp: node_hash -> Assignment cache shared across the walk.

    Returns:
        ``node`` itself.
    """
    if computed_temp is None:
        computed_temp = {}
    if visited is None:
        visited = set()
    visited.add(node)
    # Rewrite the arguments' own subtrees first.
    for arg in node.arguments():
        if arg in visited:
            print("visited", arg)
            continue
        replace_with_temporaries(arg, symbol_table, visited, computed_temp)
    new_args = []
    for arg in node.arguments():
        if symbol_table.has_temporary(arg):
            nhash = node_hash(arg, symbol_table)
            if nhash not in computed_temp:
                computed_temp[nhash] = Assignment(symbol_table.get_temporary(arg), arg)
            new_args.append(computed_temp[nhash])
        else:
            new_args.append(arg)
    node.set_arguments(new_args)
    return node
def present_search(node):
    """Demo driver: optimise ``node`` and print every intermediate stage.

    Steps: register temporaries, eliminate common subexpressions, wrap the
    graph in Assignments, and build the loop tree.  The tree is then solved
    with both A* and exhaustive search (asserting they agree).  Finally the
    tree is printed with each loop's ``_choices`` restricted to its chosen
    value (mutating the loops).
    """
    st = SymbolTable()
    update_symbol_table(node, st)
    subexpression_elimination(node, st)
    node = replace_with_temporaries(node, st)
    node = wrap_assign(node)
    print(node.pretty_print())
    # Filled in by access_patterns with IndexCouple objects.
    constraints = []
    patterns = node.access_patterns(constraints=constraints)
    print(patterns.pretty_print())
    t0 = time.time()
    best_pattern1 = astar_search_pattern(patterns, constraints)
    t1 = time.time()
    t2 = time.time()
    best_pattern2 = search_pattern(patterns, constraints)
    t3 = time.time()
    # Exhaustive search is the reference: A* must pick the same values.
    assert([v._value for v in best_pattern2._solution._slots] == [v._value for v in best_pattern1._solution._slots])
    print("A* took {}s, Exhaustive took {}s".format(t1 - t0, t3 - t2))
    best_pattern = best_pattern1
    axes = []
    collect_loops(patterns, axes)
    # Slots are in collect_loops order, so zip pairs each loop with its choice.
    for axis, slot in zip(axes, best_pattern._solution._slots):
        axis._choices = [slot._value]
    print(patterns.pretty_print())
def show_add():
    """Demo: add a row-max of a (20, 20) buffer (kept as (20, 1)) to a (20, 1) buffer."""
    row_max = ExpandDims(AxisReduce("max", Buffer((20, 20))))
    expression = ElementWise("add", [row_max, Buffer((20, 1))])
    present_search(expression)
def show_nested_reduce():
    """Demo: max-reduce the result of the ``show_add`` expression."""
    row_max = ExpandDims(AxisReduce("max", Buffer((20, 20))))
    summed = ElementWise("add", [row_max, Buffer((20, 1))])
    present_search(AxisReduce("max", summed))
def show_scan_reduce():
    """Demo: max-reduce a scan over (scan(tanh(a)) + relu(b)) on 4-D buffers."""
    tanh_scan = AxisScan(ElementWise("tanh", [Buffer((20, 20, 2, 2))]))
    relu = ElementWise("relu", [Buffer((20, 20, 2, 2))])
    combined = ElementWise("add", [tanh_scan, relu])
    present_search(AxisReduce("max", AxisScan(combined)))
def show_softmax():
    """Demo: fused softmax over a (2, 5) buffer cast to float32."""
    source = cast(Buffer((2, 5)), np.float32)
    present_search(softmax(source))
def show_scan_scan():
    """Demo: two nested scans over a (2, 5) buffer."""
    present_search(AxisScan(AxisScan(Buffer((2, 5)))))
if __name__ == "__main__":
    # Other available demos: show_scan_scan(), show_add(),
    # show_nested_reduce(), show_scan_reduce().
    show_softmax()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment