-
Notifications
You must be signed in to change notification settings - Fork 1.2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Abstracted Replay Data Parsers #119
base: master
Are you sure you want to change the base?
Changes from 9 commits
bffd088
1a78352
3eb8869
1bdf502
d31a9e9
ad28d78
09e3274
6c6caf4
bf4ad86
e642f72
a25dff6
5dd050b
6fa4d1e
9969d86
b65ed0d
677ec46
e9e4924
9b62b47
ea30f55
be6cda9
c72f8a9
6aa04aa
14d732d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -42,97 +42,42 @@ | |
from s2clientprotocol import common_pb2 as sc_common | ||
from s2clientprotocol import sc2api_pb2 as sc_pb | ||
|
||
import importlib | ||
import json | ||
import sys | ||
|
||
FLAGS = flags.FLAGS | ||
flags.DEFINE_integer("parallel", 1, "How many instances to run in parallel.") | ||
flags.DEFINE_integer("step_mul", 8, "How many game steps per observation.") | ||
flags.DEFINE_string("replays", None, "Path to a directory of replays.") | ||
flags.DEFINE_string("parser", "pysc2.replay_parsers.base_parser.BaseParser", | ||
"Which parser to use in scrapping replay data") | ||
flags.DEFINE_string("data_dir", None, | ||
"Path to directory to save replay data from replay parser") | ||
flags.DEFINE_integer("screen_resolution", 16, | ||
"Resolution for screen feature layers.") | ||
flags.DEFINE_integer("minimap_resolution", 16, | ||
"Resolution for minimap feature layers.") | ||
flags.mark_flag_as_required("replays") | ||
|
||
|
||
size = point.Point(16, 16) | ||
interface = sc_pb.InterfaceOptions( | ||
raw=True, score=False, | ||
feature_layer=sc_pb.SpatialCameraSetup(width=24)) | ||
size.assign_to(interface.feature_layer.resolution) | ||
size.assign_to(interface.feature_layer.minimap_resolution) | ||
|
||
|
||
def sorted_dict_str(d): | ||
return "{%s}" % ", ".join("%s: %s" % (k, d[k]) | ||
for k in sorted(d, key=d.get, reverse=True)) | ||
|
||
|
||
class ReplayStats(object): | ||
"""Summary stats of the replays seen so far.""" | ||
|
||
def __init__(self): | ||
self.replays = 0 | ||
self.steps = 0 | ||
self.camera_move = 0 | ||
self.select_pt = 0 | ||
self.select_rect = 0 | ||
self.control_group = 0 | ||
self.maps = collections.defaultdict(int) | ||
self.races = collections.defaultdict(int) | ||
self.unit_ids = collections.defaultdict(int) | ||
self.valid_abilities = collections.defaultdict(int) | ||
self.made_abilities = collections.defaultdict(int) | ||
self.valid_actions = collections.defaultdict(int) | ||
self.made_actions = collections.defaultdict(int) | ||
self.crashing_replays = set() | ||
self.invalid_replays = set() | ||
|
||
def merge(self, other): | ||
"""Merge another ReplayStats into this one.""" | ||
def merge_dict(a, b): | ||
for k, v in six.iteritems(b): | ||
a[k] += v | ||
|
||
self.replays += other.replays | ||
self.steps += other.steps | ||
self.camera_move += other.camera_move | ||
self.select_pt += other.select_pt | ||
self.select_rect += other.select_rect | ||
self.control_group += other.control_group | ||
merge_dict(self.maps, other.maps) | ||
merge_dict(self.races, other.races) | ||
merge_dict(self.unit_ids, other.unit_ids) | ||
merge_dict(self.valid_abilities, other.valid_abilities) | ||
merge_dict(self.made_abilities, other.made_abilities) | ||
merge_dict(self.valid_actions, other.valid_actions) | ||
merge_dict(self.made_actions, other.made_actions) | ||
self.crashing_replays |= other.crashing_replays | ||
self.invalid_replays |= other.invalid_replays | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. merge moved to parser class |
||
def __str__(self): | ||
len_sorted_dict = lambda s: (len(s), sorted_dict_str(s)) | ||
len_sorted_list = lambda s: (len(s), sorted(s)) | ||
return "\n\n".join(( | ||
"Replays: %s, Steps total: %s" % (self.replays, self.steps), | ||
"Camera move: %s, Select pt: %s, Select rect: %s, Control group: %s" % ( | ||
self.camera_move, self.select_pt, self.select_rect, | ||
self.control_group), | ||
"Maps: %s\n%s" % len_sorted_dict(self.maps), | ||
"Races: %s\n%s" % len_sorted_dict(self.races), | ||
"Unit ids: %s\n%s" % len_sorted_dict(self.unit_ids), | ||
"Valid abilities: %s\n%s" % len_sorted_dict(self.valid_abilities), | ||
"Made abilities: %s\n%s" % len_sorted_dict(self.made_abilities), | ||
"Valid actions: %s\n%s" % len_sorted_dict(self.valid_actions), | ||
"Made actions: %s\n%s" % len_sorted_dict(self.made_actions), | ||
"Crashing replays: %s\n%s" % len_sorted_list(self.crashing_replays), | ||
"Invalid replays: %s\n%s" % len_sorted_list(self.invalid_replays), | ||
)) | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. str moved to parser class |
||
interface = sc_pb.InterfaceOptions() | ||
interface.raw = True | ||
interface.score = False | ||
interface.feature_layer.width = 24 | ||
interface.feature_layer.resolution.x = FLAGS.screen_resolution | ||
interface.feature_layer.resolution.y = FLAGS.screen_resolution | ||
interface.feature_layer.minimap_resolution.x = FLAGS.minimap_resolution | ||
interface.feature_layer.minimap_resolution.y = FLAGS.minimap_resolution | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Copied from play.py to be consistent |
||
class ProcessStats(object): | ||
"""Stats for a worker process.""" | ||
|
||
def __init__(self, proc_id): | ||
def __init__(self, proc_id, parser_cls): | ||
self.proc_id = proc_id | ||
self.time = time.time() | ||
self.stage = "" | ||
self.replay = "" | ||
self.replay_stats = ReplayStats() | ||
self.parser = parser_cls() | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. renamed replay_stats to parser, updated throughout code |
||
|
||
def update(self, stage): | ||
self.time = time.time() | ||
|
@@ -141,34 +86,18 @@ def update(self, stage): | |
def __str__(self): | ||
return ("[%2d] replay: %10s, replays: %5d, steps: %7d, game loops: %7s, " | ||
"last: %12s, %3d s ago" % ( | ||
self.proc_id, self.replay, self.replay_stats.replays, | ||
self.replay_stats.steps, | ||
self.replay_stats.steps * FLAGS.step_mul, self.stage, | ||
self.proc_id, self.replay, self.parser.replays, | ||
self.parser.steps, | ||
self.parser.steps * FLAGS.step_mul, self.stage, | ||
time.time() - self.time)) | ||
|
||
|
||
def valid_replay(info, ping): | ||
"""Make sure the replay isn't corrupt, and is worth looking at.""" | ||
if (info.HasField("error") or | ||
info.base_build != ping.base_build or # different game version | ||
info.game_duration_loops < 1000 or | ||
len(info.player_info) != 2): | ||
# Probably corrupt, or just not interesting. | ||
return False | ||
for p in info.player_info: | ||
if p.player_apm < 10 or p.player_mmr < 1000: | ||
# Low APM = player just standing around. | ||
# Low MMR = corrupt replay or player who is weak. | ||
return False | ||
return True | ||
|
||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. valid_replay moved to parser class |
||
class ReplayProcessor(multiprocessing.Process): | ||
"""A Process that pulls replays and processes them.""" | ||
|
||
def __init__(self, proc_id, run_config, replay_queue, stats_queue): | ||
def __init__(self, proc_id, run_config, replay_queue, stats_queue, parser_cls): | ||
super(ReplayProcessor, self).__init__() | ||
self.stats = ProcessStats(proc_id) | ||
self.stats = ProcessStats(proc_id, parser_cls) | ||
self.run_config = run_config | ||
self.replay_queue = replay_queue | ||
self.stats_queue = stats_queue | ||
|
@@ -192,7 +121,7 @@ def run(self): | |
self._print("Empty queue, returning") | ||
return | ||
try: | ||
replay_name = os.path.basename(replay_path)[:10] | ||
replay_name = os.path.basename(replay_path) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. changed 10 character truncating to full replay name There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The reason it was just the first 10 is that I was running mainly over replays that were sha1 named, so very long and the prefix gave enough uniqueness. I'm fine with showing the full as long as it's readable. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Okay, I've left the full replay name but I'll leave it up to you to decide on readability. The only catch is that data files are uniquely named from their replay name and will be overwritten if there are colliding names. |
||
self.stats.replay = replay_name | ||
self._print("Got replay: %s" % replay_path) | ||
self._update_stage("open replay file") | ||
|
@@ -202,12 +131,12 @@ def run(self): | |
self._print((" Replay Info %s " % replay_name).center(60, "-")) | ||
self._print(info) | ||
self._print("-" * 60) | ||
if valid_replay(info, ping): | ||
self.stats.replay_stats.maps[info.map_name] += 1 | ||
if self.stats.parser.valid_replay(info, ping): | ||
self.stats.parser.maps[info.map_name] += 1 | ||
for player_info in info.player_info: | ||
race_name = sc_common.Race.Name( | ||
player_info.player_info.race_actual) | ||
self.stats.replay_stats.races[race_name] += 1 | ||
self.stats.parser.races[race_name] += 1 | ||
map_data = None | ||
if info.local_map_path: | ||
self._update_stage("open map file") | ||
|
@@ -216,16 +145,16 @@ def run(self): | |
self._print("Starting %s from player %s's perspective" % ( | ||
replay_name, player_id)) | ||
self.process_replay(controller, replay_data, map_data, | ||
player_id) | ||
player_id, info, replay_name) | ||
else: | ||
self._print("Replay is invalid.") | ||
self.stats.replay_stats.invalid_replays.add(replay_name) | ||
self.stats.parser.invalid_replays.add(replay_name) | ||
finally: | ||
self.replay_queue.task_done() | ||
self._update_stage("shutdown") | ||
except (protocol.ConnectionError, protocol.ProtocolError, | ||
remote_controller.RequestError): | ||
self.stats.replay_stats.crashing_replays.add(replay_name) | ||
self.stats.parser.crashing_replays.add(replay_name) | ||
except KeyboardInterrupt: | ||
return | ||
|
||
|
@@ -237,7 +166,8 @@ def _update_stage(self, stage): | |
self.stats.update(stage) | ||
self.stats_queue.put(self.stats) | ||
|
||
def process_replay(self, controller, replay_data, map_data, player_id): | ||
def process_replay(self, controller, replay_data, map_data, player_id, info, replay_name): | ||
print(replay_name) | ||
"""Process a single replay, updating the stats.""" | ||
self._update_stage("start_replay") | ||
controller.start_replay(sc_pb.RequestStartReplay( | ||
|
@@ -248,53 +178,42 @@ def process_replay(self, controller, replay_data, map_data, player_id): | |
|
||
feat = features.Features(controller.game_info()) | ||
|
||
self.stats.replay_stats.replays += 1 | ||
self.stats.parser.replays += 1 | ||
self._update_stage("step") | ||
controller.step() | ||
data = [] | ||
while True: | ||
self.stats.replay_stats.steps += 1 | ||
self.stats.parser.steps += 1 | ||
self._update_stage("observe") | ||
obs = controller.observe() | ||
|
||
for action in obs.actions: | ||
act_fl = action.action_feature_layer | ||
if act_fl.HasField("unit_command"): | ||
self.stats.replay_stats.made_abilities[ | ||
act_fl.unit_command.ability_id] += 1 | ||
if act_fl.HasField("camera_move"): | ||
self.stats.replay_stats.camera_move += 1 | ||
if act_fl.HasField("unit_selection_point"): | ||
self.stats.replay_stats.select_pt += 1 | ||
if act_fl.HasField("unit_selection_rect"): | ||
self.stats.replay_stats.select_rect += 1 | ||
if action.action_ui.HasField("control_group"): | ||
self.stats.replay_stats.control_group += 1 | ||
|
||
try: | ||
func = feat.reverse_action(action).function | ||
except ValueError: | ||
func = -1 | ||
self.stats.replay_stats.made_actions[func] += 1 | ||
|
||
for valid in obs.observation.abilities: | ||
self.stats.replay_stats.valid_abilities[valid.ability_id] += 1 | ||
|
||
for u in obs.observation.raw_data.units: | ||
self.stats.replay_stats.unit_ids[u.unit_type] += 1 | ||
|
||
for ability_id in feat.available_actions(obs.observation): | ||
self.stats.replay_stats.valid_actions[ability_id] += 1 | ||
|
||
if obs.player_result: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. action stats parsing moved to custom ActionParser class |
||
# If parser.parse_step returns, whatever is returned is appended | ||
# to a data list, and this data list is saved to a json file | ||
# in the data_dir directory with filename = replay_name_player_id.json | ||
parsed_data = self.stats.parser.parse_step(obs,feat,info) | ||
if parsed_data: | ||
data.append(parsed_data) | ||
|
||
if obs.player_result: | ||
# Save scraped replay data to file at end of replay if parser returns | ||
# and data_dir provided | ||
if data: | ||
if FLAGS.data_dir: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Data only saved to file if data_dir provided from user |
||
stripped_replay_name = replay_name.split(".")[0] | ||
data_file = os.path.join(FLAGS.data_dir, | ||
stripped_replay_name + "_" + str(player_id) + '.json') | ||
with open(data_file,'w') as outfile: | ||
json.dump(data,outfile) | ||
else: | ||
print("Please provide a directory as data_dir to save scrapped data files") | ||
break | ||
|
||
self._update_stage("step") | ||
controller.step(FLAGS.step_mul) | ||
|
||
|
||
def stats_printer(stats_queue): | ||
def stats_printer(stats_queue, parser_cls): | ||
"""A thread that consumes stats_queue and prints them every 10 seconds.""" | ||
proc_stats = [ProcessStats(i) for i in range(FLAGS.parallel)] | ||
proc_stats = [ProcessStats(i,parser_cls) for i in range(FLAGS.parallel)] | ||
print_time = start_time = time.time() | ||
width = 107 | ||
|
||
|
@@ -312,12 +231,12 @@ def stats_printer(stats_queue): | |
except queue.Empty: | ||
pass | ||
|
||
replay_stats = ReplayStats() | ||
parser = parser_cls() | ||
for s in proc_stats: | ||
replay_stats.merge(s.replay_stats) | ||
parser.merge(s.parser) | ||
|
||
print((" Summary %0d secs " % (print_time - start_time)).center(width, "=")) | ||
print(replay_stats) | ||
print(parser) | ||
print(" Process stats ".center(width, "-")) | ||
print("\n".join(str(s) for s in proc_stats)) | ||
print("=" * width) | ||
|
@@ -333,11 +252,14 @@ def main(unused_argv): | |
"""Dump stats about all the actions that are in use in a set of replays.""" | ||
run_config = run_configs.get() | ||
|
||
parser_module, parser_name = FLAGS.parser.rsplit(".", 1) | ||
parser_cls = getattr(importlib.import_module(parser_module), parser_name) | ||
|
||
if not gfile.Exists(FLAGS.replays): | ||
sys.exit("{} doesn't exist.".format(FLAGS.replays)) | ||
|
||
stats_queue = multiprocessing.Queue() | ||
stats_thread = threading.Thread(target=stats_printer, args=(stats_queue,)) | ||
stats_thread = threading.Thread(target=stats_printer, args=(stats_queue,parser_cls)) | ||
stats_thread.start() | ||
try: | ||
# For some reason buffering everything into a JoinableQueue makes the | ||
|
@@ -355,7 +277,7 @@ def main(unused_argv): | |
replay_queue_thread.start() | ||
|
||
for i in range(FLAGS.parallel): | ||
p = ReplayProcessor(i, run_config, replay_queue, stats_queue) | ||
p = ReplayProcessor(i, run_config, replay_queue, stats_queue, parser_cls) | ||
p.daemon = True | ||
p.start() | ||
time.sleep(1) # Stagger startups, otherwise they seem to conflict somehow | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
# Copyright 2017 Google Inc. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS-IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Added resolution args to be consistent with play.py