Skip to content

Commit

Permalink
Merge pull request #4 from tc20042008/develop
Browse files Browse the repository at this point in the history
push recent updates
  • Loading branch information
Aurelius84 authored Jul 23, 2024
2 parents 8844a05 + 49fde7a commit 00a7a4f
Show file tree
Hide file tree
Showing 103 changed files with 9,623 additions and 27,066 deletions.
7 changes: 2 additions & 5 deletions athena/_constraint_unittests.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,8 @@
from athena.generators.constraint_unittests_generator import (
ConstraintUnittestsGenerator,
)


from athena.util.ir_program_util import IsBackwardProgram, GetProgramId
import athena.ir.ir_op as ir_op
import athena.ir.ir_type as ir_type
from absl import app
from absl import flags
Expand Down Expand Up @@ -85,9 +84,7 @@ def GetPyVarName(uid_and_op):
uid_and_ops = [
(program_id, op)
for program_id, op in uid_and_ops
if op_example_inputs_meta_getter.HasAllInputs(
program_id, op.op_id, num_inputs=len(op.input_types)
)
if op_example_inputs_meta_getter.HasAllInputs(program_id, op)
if all(
isinstance(input_type, valid_operand_types)
for input_type in op.input_types
Expand Down
6 changes: 2 additions & 4 deletions athena/_primitive_op_unittests.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from athena.util.load_pir_py_classes import GetProgramClasses, GetClasses

from athena.util.op_example_inputs_meta_getter import (
MakeOpExampleInputsMetaGetter,
)
Expand All @@ -8,6 +7,7 @@
PrimitiveOpUnittestsGenerator,
)
from athena.util.ir_program_util import IsBackwardProgram, GetProgramId
import athena.ir.ir_op as ir_op
import athena.ir.ir_type as ir_type
from absl import app
from absl import flags
Expand Down Expand Up @@ -91,9 +91,7 @@ def GetPyVarName(uid_and_op):
uid_and_ops = [
(program_id, op)
for program_id, op in uid_and_ops
if op_example_inputs_meta_getter.HasAllInputs(
program_id, op.op_id, num_inputs=len(op.input_types)
)
if op_example_inputs_meta_getter.HasAllInputs(program_id, op)
if all(
isinstance(input_type, valid_operand_types)
for input_type in op.input_types
Expand Down
69 changes: 40 additions & 29 deletions athena/constraint_unittests.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,44 +14,55 @@


def main(argv):
WithTempDirectory(Main)
WithTempDirectory(Main)


def WithTempDirectory(f):
if FLAGS.tmp_dir == "":
tmp_dir = tempfile.mkdtemp()
return f(tmp_dir)
shutil.rmtree(tmp_dir)
else:
assert os.path.isdir(FLAGS.tmp_dir)
return f(FLAGS.tmp_dir)
if FLAGS.tmp_dir == "":
tmp_dir = tempfile.mkdtemp()
return f(tmp_dir)
shutil.rmtree(tmp_dir)
else:
assert os.path.isdir(FLAGS.tmp_dir)
return f(FLAGS.tmp_dir)


def Main(tmp_dir):
assert os.path.isdir(FLAGS.output_dir), f"directory {FLAGS.output_dir} not existed."
shutil.copyfile(FLAGS.ir_programs, f"{tmp_dir}/original_programs.py")
shutil.copyfile(FLAGS.example_inputs, f"{tmp_dir}/programs_example_input_tensor_meta.py")
file_prefix = "tmp_op_example_input_"
for file in glob.glob(f"{tmp_dir}/{file_prefix}*.py"):
os.remove(file)
for file in glob.glob(f"{FLAGS.output_dir}/test_constraint_*.py"):
os.remove(file)
System(f"{sys.executable} -m athena.op_example_input_meta_script --output_file_prefix={file_prefix} --input_dir={tmp_dir} --output_dir={tmp_dir}")
System(f"{sys.executable} -m athena.op_example_input_meta_result --input_file_prefix={file_prefix} --input_dir={tmp_dir} --output_dir={tmp_dir}")
System(f"{sys.executable} -m athena._constraint_unittests --input_dir={tmp_dir} --output_dir={FLAGS.output_dir}")
sys.exit(exit_code)
assert os.path.isdir(FLAGS.output_dir), f"directory {FLAGS.output_dir} not existed."
shutil.copyfile(FLAGS.ir_programs, f"{tmp_dir}/original_programs.py")
shutil.copyfile(
FLAGS.example_inputs, f"{tmp_dir}/programs_example_input_tensor_meta.py"
)
file_prefix = "tmp_op_example_input_"
for file in glob.glob(f"{tmp_dir}/{file_prefix}*.py"):
os.remove(file)
for file in glob.glob(f"{FLAGS.output_dir}/test_constraint_*.py"):
os.remove(file)
System(
f"{sys.executable} -m athena.op_example_input_meta_script --output_file_prefix={file_prefix} --input_dir={tmp_dir} --output_dir={tmp_dir}"
)
System(
f"{sys.executable} -m athena.op_example_input_meta_result --input_file_prefix={file_prefix} --input_dir={tmp_dir} --output_dir={tmp_dir}"
)
System(
f"{sys.executable} -m athena._constraint_unittests --input_dir={tmp_dir} --output_dir={FLAGS.output_dir}"
)
sys.exit(exit_code)


exit_code = 0


def System(cmd):
print(cmd, file=sys.stderr)
ret = os.system(cmd)
global exit_code
if exit_code != 0:
return
if ret == 0:
return
exit_code = ret
print(cmd, file=sys.stderr)
ret = os.system(cmd)
global exit_code
if exit_code != 0:
return
if ret == 0:
return
exit_code = ret


if __name__ == "__main__":
app.run(main)
app.run(main)
48 changes: 26 additions & 22 deletions athena/fusion_op_unittests.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
from athena.util.load_pir_py_classes import GetProgramClasses
from athena.generators.fusion_op_unittest_generator import (
FusionOpUnittestGenerator
)
from athena.generators.fusion_op_unittest_generator import FusionOpUnittestGenerator
import sys
from absl import app
from absl import flags
Expand All @@ -11,33 +9,39 @@

flags.DEFINE_string("output_dir", "./output-dir", "output directory.")


def main(argv):
for name, unittest in GetOutputUnittests(argv[1]):
sha256sum = GetSha256sum(unittest)
filepath = f"{FLAGS.output_dir}/test_{sha256sum[0:32]}.py"
WriteToFile(filepath, unittest)
PrintToTerminal(name, filepath, unittest)
for name, unittest in GetOutputUnittests(argv[1]):
sha256sum = GetSha256sum(unittest)
filepath = f"{FLAGS.output_dir}/test_{sha256sum[0:32]}.py"
WriteToFile(filepath, unittest)
PrintToTerminal(name, filepath, unittest)


def GetSha256sum(content):
m = hashlib.sha256()
m.update(content.encode())
return m.hexdigest()
m = hashlib.sha256()
m.update(content.encode())
return m.hexdigest()


def PrintToTerminal(name, filepath, unittest):
print("# file-splitter-begin-fusion-op-name: ", name, filepath)
print(unittest)
print("# file-splitter--end--fusion-op-name: ", name, filepath)
print("# file-splitter-begin-fusion-op-name: ", name, filepath)
print(unittest)
print("# file-splitter--end--fusion-op-name: ", name, filepath)


def WriteToFile(filepath, unittest):
with open(filepath, "w") as f:
f.write(unittest)
with open(filepath, "w") as f:
f.write(unittest)


def GetOutputUnittests(input_file_path):
for cls in GetProgramClasses(input_file_path):
ir_program = cls()
generator = FusionOpUnittestGenerator()
op_name2unittest = generator.Generate(ir_program)
yield from op_name2unittest.items()
for cls in GetProgramClasses(input_file_path):
ir_program = cls()
generator = FusionOpUnittestGenerator()
op_name2unittest = generator.Generate(ir_program)
yield from op_name2unittest.items()


if __name__ == "__main__":
app.run(main)
app.run(main)
Loading

0 comments on commit 00a7a4f

Please sign in to comment.