diff --git a/corellator.py b/corellator.py deleted file mode 100644 index 009e596..0000000 --- a/corellator.py +++ /dev/null @@ -1,264 +0,0 @@ -#!/usr/bin/env python -########################################################### -# Copyright (c) 2022 Advanced Micro Devices, Inc. -########################################################### - -import sqlite3 -import json -import sys -import itertools -import argparse -from dataclasses import dataclass - -@dataclass -class Interval: - kernel: str - start: int - end: int - - def __gt__(self, other): - return self.start > other.start - -@dataclass -class Node: - earlier: 'Node' - later: 'Node' - interval: Interval - def __init__(self, interval): - self.earlier = None - self.later = None - self.interval = interval - - def __len__(self): - return ((0 if self.earlier is None else len(self.earlier)) + - (0 if self.later is None else len(self.later)) + - (0 if self.interval is None else 1)) - -def insert_interval(interval: Interval, btree: Node, start=True): - if btree.interval is None: - btree.interval = interval - else: - if start: - ts1 = interval.start - ts2 = btree.interval.start - else: - ts1 = interval.end - ts2 = btree.interval.end - - if ts1 > ts2: - if btree.later is None: - btree.later = Node(interval) - else: - insert_interval(interval, btree.later, start) - else: - if btree.earlier is None: - btree.earlier = Node(interval) - else: - insert_interval(interval, btree.earlier, start) - -def insert_intervals(intervals, btree, start=True): - if len(intervals) == 0: - return - elif len(intervals) == 1: - insert_interval(intervals[0], btree, start) - return - - middle = len(intervals) // 2 - insert_interval(intervals[middle], btree, start) - insert_intervals(intervals[:middle], btree, start) - insert_intervals(intervals[middle + 1:], btree, start) - -def find_interval(ts: int, btree: Node): - if btree is None: - return None - if ts > btree.interval.start: - if ts < btree.interval.end: - return btree.interval - else: - return find_interval(ts, btree.later) - else: - return find_interval(ts, btree.earlier) - -def load_json(path, plot, stretch_plot, skew): - print("Loading file", path) - with open(path, 'r') as fp: - data = json.load(fp) - - reg_values = {} - stretch_values = {} - interval_list = [] - - print('Finding data points') - - for row in data['traceEvents']: - if 'args' in row: - if plot in row['args']: - v = row['args'][plot] - reg_values[row['ts'] * 1000 + skew] = v - elif stretch_plot in row['args']: - stretch_values[row['ts'] * 1000 + skew] = row['args'][stretch_plot] - elif 'desc' in row['args'] and row['args']['desc'] == "KernelExecution": - name = row['name'] - start = int(row['ts']) * 1000 - dur = int(row['dur']) * 1000 - interval_list.append(Interval(name, start, start + dur)) - - return reg_values, interval_list, stretch_values - - -def load_rpd(path, plot, stretch_plot, freq, skew): - print("Loading file", path) - con = sqlite3.connect(path) - cur = con.cursor() - - reg_values = {} - stretch_values = {} - freq_values = {} - interval_list = [] - - print('Finding data points') - kerns = cur.execute("SELECT start, end, kernelName FROM kernel").fetchall() - for start, end, name in kerns: - interval_list.append(Interval(name, start, end)) - - plots = cur.execute(f"SELECT start, value FROM rocpd_monitor WHERE monitorType = '{plot}'").fetchall() - for time, value in plots: - value = int(value) - reg_values[time + skew] = value - - plots = cur.execute(f"SELECT start, value FROM rocpd_monitor WHERE monitorType = '{stretch_plot}'").fetchall() - for time, value in plots: - value = float(value) - stretch_values[time + skew] = value - - plots = cur.execute(f"SELECT start, value FROM rocpd_monitor WHERE monitorType = '{freq}'").fetchall() - for time, value in plots: - value = float(value) - freq_values[time + skew] = value - print("freq", len(plots)) - - return reg_values, interval_list, stretch_values, freq_values - -def add_arguments(parser: argparse.ArgumentParser): - parser.add_argument('--plots', default='/etc/corellator.json', help="The plot name that counts events") - parser.add_argument('input', help="A perfetto trace file") - parser.add_argument('--skew', default=0, type=int, help="A time offset applied to events (ns)") - parser.add_argument('-n', default=10, type=int, help="Show the top N results") - -def main(argv): - parser = argparse.ArgumentParser() - add_arguments(parser) - args = parser.parse_args(argv[1:]) - - with open(args.plots, 'r') as fp: - plots = json.load(fp) - event_plot = plots['counter'] - stretch_plot = plots['stretch'] - - if args.input.endswith('json'): - loader = load_json - elif args.input.endswith('rpd'): - loader = load_rpd - else: - print("Unknown input format") - return 1 - reg_values, interval_list, stretch_values, freq_values = loader(args.input, event_plot, stretch_plot, plots['freq'], args.skew) - - reg_total = sum(reg_values.values()) - - kernel_data = {} - kernel_time = {} - kernel_stretch = {} - freq_data = {} - - for interval in interval_list: - name = interval.kernel - if not name in kernel_data: - kernel_data[name] = [] - freq_data[name] = [] - kernel_time[name] = kernel_time.get(name, 0.0) + (interval.end - interval.start) - - print('Num data points', event_plot, len(reg_values)) - print('Num data points', stretch_plot, len(stretch_values)) - print('Num kernels:', len(kernel_data)) - print('Num intervals:', len(interval_list)) - print('Total', event_plot, reg_total) - - if len(reg_values) == 0 or len(kernel_data) == 0: - print('There is a problem with the trace, it is missing critical data.') - return 1 - - interval_btree = Node(None) - insert_intervals(interval_list, interval_btree, True) - - print('Btree assembled') - - no_kern = 0 - no_kern_sum = 0 - - for ts in reg_values: - interval = find_interval(ts, interval_btree) - if interval is None: - no_kern += 1 - no_kern_sum += reg_values[ts] - else: - kernel_data[interval.kernel].append(reg_values[ts]) - - stretch_times = sorted(stretch_values.keys()) - for i in range(1, len(stretch_times)): - growth = stretch_values[stretch_times[i]] - stretch_values[stretch_times[i-1]] - if growth > 0: - interval = find_interval(stretch_times[i], interval_btree) - if interval is not None: - kernel_stretch[interval.kernel] = kernel_stretch.get(interval.kernel, 0.0) + growth - - print('Data points without kernel:', no_kern, "({} {:.3f}%)".format(no_kern_sum, no_kern_sum / reg_total * 100)) - print('Data points with kernel:', len(reg_values) - no_kern, "({} {:.3f}%)".format(reg_total - no_kern_sum, (reg_total - no_kern_sum) / reg_total * 100)) - - kernel_sums = {} - for kern in kernel_data: - kernel_sums[kern] = sum(kernel_data[kern]) - - sorted_sums = sorted(kernel_sums.items(), key=lambda x: x[1], reverse=True) - - top_n = min(args.n, len(sorted_sums)) - print('Top', top_n, 'offenders (total events)') - for kernel, ksum in sorted_sums[:top_n]: - print(kernel, ksum, '{:.3f}%'.format(ksum/reg_total*100.0)) - - sorted_rate = sorted([(x[0], x[1] / kernel_time[x[0]] if kernel_time[x[0]] > 0 else 0.0) - for x in kernel_sums.items()], key=lambda x: x[1], reverse=True) - - print('Top', top_n, 'offenders (events/us)') - for kernel, rate in sorted_rate[:top_n]: - print(kernel, '{:.3f}'.format(rate * 1000.0)) - - sorted_growth = sorted(kernel_stretch.items(), key=lambda x: x[1], reverse=True) - print('Top', top_n, 'stretch growth') - for kernel, growth in sorted_growth[:top_n]: - print(kernel, '{:.3f}'.format(growth)) - - for ts in freq_values: - interval = find_interval(ts, interval_btree) - if interval is None: - pass - else: - freq_data[interval.kernel].append(freq_values[ts]) - - freq_avg = {} - for kern in freq_data: - data = freq_data[kern] - if len(data) > 0: - avg = sum(data)/len(data) - freq_avg[kern] = avg - - - print('Bottom', top_n, 'average frequency') - sorted_freq = sorted(freq_avg.items(), key=lambda x: x[1]) - for kernel, freq in sorted_freq[:top_n]: - print(kernel, '{:.0f}'.format(freq/1000.0)) - - return 0 - -if __name__ == '__main__': - sys.exit(main(sys.argv)) diff --git a/correlator.py b/correlator.py new file mode 100644 index 0000000..e3f47e0 --- /dev/null +++ b/correlator.py @@ -0,0 +1,220 @@ +#!/usr/bin/env python +########################################################### +# Copyright (c) 2022 Advanced Micro Devices, Inc. +########################################################### + +import sqlite3 +import json +import sys +import itertools +import argparse +from dataclasses import dataclass, field + +@dataclass +class Interval: + kernel: str + start: int + end: int + + def __gt__(self, other): + return self.start > other.start + +@dataclass +class Node: + earlier: 'Node' + later: 'Node' + interval: Interval + def __init__(self, interval): + self.earlier = None + self.later = None + self.interval = interval + + def __len__(self): + return ((0 if self.earlier is None else len(self.earlier)) + + (0 if self.later is None else len(self.later)) + + (0 if self.interval is None else 1)) + +def insert_interval(interval: Interval, btree: Node, start=True): + if btree.interval is None: + btree.interval = interval + else: + if start: + ts1 = interval.start + ts2 = btree.interval.start + else: + ts1 = interval.end + ts2 = btree.interval.end + + if ts1 > ts2: + if btree.later is None: + btree.later = Node(interval) + else: + insert_interval(interval, btree.later, start) + else: + if btree.earlier is None: + btree.earlier = Node(interval) + else: + insert_interval(interval, btree.earlier, start) + +def insert_intervals(intervals, btree, start=True): + if len(intervals) == 0: + return + elif len(intervals) == 1: + insert_interval(intervals[0], btree, start) + return + + middle = len(intervals) // 2 + insert_interval(intervals[middle], btree, start) + insert_intervals(intervals[:middle], btree, start) + insert_intervals(intervals[middle + 1:], btree, start) + +def find_interval(ts: int, btree: Node): + if btree is None: + return None + if ts > btree.interval.start: + if ts < btree.interval.end: + return btree.interval + else: + return find_interval(ts, btree.later) + else: + return find_interval(ts, btree.earlier) + +def mean(items, _): + return sum(items)/len(items) + +@dataclass +class Metric: + name: str + trace_name: str + top: bool + dtype: type + cumulative: bool = False + values: dict = field(default_factory=dict) + kern_values: dict = field(default_factory=dict) + no_kern: int = 0 + total = 0 + no_kern_sum = 0 + + def load_intervals(self, intervals): + for ts in self.values: + interval = find_interval(ts, intervals) + if interval is None: + self.no_kern += 1 + if self.cumulative: + self.no_kern_sum += self.values[ts] + else: + if not interval.kernel in self.kern_values: + self.kern_values[interval.kernel] = [] + self.kern_values[interval.kernel].append(self.values[ts]) + + def summary(self): + result = f'{self.name} ({self.trace_name}): {len(self.values)} Samples, ' + if self.cumulative: + result += f' Total: {self.total} Samples w/o kernel: {self.no_kern}' + elif len(self.values) > 0: + result += f' Average: {self.total/len(self.values)}' + return result + + def report(self, note, count, rank, percent=False): + value_ranks = {} + for kern in self.kern_values: + value_ranks[kern] = rank(self.kern_values[kern], kern) + + sorted_rank = sorted(value_ranks.items(), key=lambda x: x[1], reverse=self.top) + + count = min(count, len(sorted_rank)) + + print('Top' if self.top else 'Bottom', count, f'kernels ({self.name})', note) + for kernel, rank in sorted_rank[:count]: + if percent: + print(kernel, rank, '{:.3f}%'.format(rank/self.total*100.0)) + else: + print(kernel, rank) + + return sorted_rank + +def load_rpd(path, metrics, skew): + print("Loading file", path) + con = sqlite3.connect(path) + cur = con.cursor() + + reg_values = {} + stretch_values = {} + freq_values = {} + interval_list = [] + + print('Finding data points') + kerns = cur.execute("SELECT start, end, kernelName FROM kernel").fetchall() + for start, end, name in kerns: + interval_list.append(Interval(name, start, end)) + + for metric in metrics: + plots = cur.execute(f"SELECT start, value FROM rocpd_monitor WHERE monitorType = '{metric.trace_name}'").fetchall() + for time, value in plots: + value = metric.dtype(value) + metric.values[time + skew] = value + + return interval_list + +def prepare_metrics(metrics, intervals): + for metric in metrics: + metric.total = sum(metric.values.values()) + metric.load_intervals(intervals) + +def add_arguments(parser: argparse.ArgumentParser): + parser.add_argument('--plots', default='/etc/correlator.json', help="The plot name that counts events") + parser.add_argument('input', help="A perfetto trace file") + parser.add_argument('--skew', default=0, type=int, help="A time offset applied to events (ns)") + parser.add_argument('-n', default=10, type=int, help="Show the top N results") + +def main(argv): + parser = argparse.ArgumentParser() + add_arguments(parser) + args = parser.parse_args(argv[1:]) + + with open(args.plots, 'r') as fp: + plots = json.load(fp) + + pcc_events = Metric("PCC Events", plots['counter'], True, int, True) + clock_stretch = Metric("Clock Stretch", plots['stretch'], True, float) + gfx_frequency = Metric("GFX Frequency", plots['freq'], False, float) + + metrics = [pcc_events, clock_stretch, gfx_frequency] + + interval_list = load_rpd(args.input, metrics, args.skew) + + print('Assemble btree') + interval_btree = Node(None) + insert_intervals(interval_list, interval_btree, True) + + print('Prepare Metrics') + prepare_metrics(metrics, interval_btree) + + kernel_time = {} + + for interval in interval_list: + name = interval.kernel + kernel_time[name] = kernel_time.get(name, 0.0) + (interval.end - interval.start) + + print('Num kernels:', len(kernel_time)) + print('Num intervals:', len(interval_list)) + + for metric in metrics: + print(metric.summary()) + + if any([len(x.values) == 0 for x in metrics]): + print('There is a problem with the trace, it is missing critical data.') + return 1 + + def event_rate(values, kern): + return 1000.0 * sum(values) / kernel_time[kern] if kernel_time[kern] > 0 else 0.0 + + pcc_events.report('Count', args.n, lambda x, _: sum(x), True) + pcc_events.report('Rate (per us)', args.n, event_rate) + clock_stretch.report('Average', args.n, mean) + gfx_frequency.report('Average', args.n, mean) + + return 0 + +if __name__ == '__main__': + sys.exit(main(sys.argv)) diff --git a/power_trace.sh b/power_trace.sh index 743c6be..2e83008 100755 --- a/power_trace.sh +++ b/power_trace.sh @@ -31,7 +31,7 @@ cat << EOF > "$tmpdir/iree-benchmark-module" set -xe sudo bash "$tmpdir/doas-root" "\$@" ret=$? -python3 "$script_dir/corellator.py" trace.rpd || echo "Returned $?" +python3 "$script_dir/correlator.py" trace.rpd || echo "Returned $?" exit $ret EOF