Skip to content

Commit

Permalink
finished first half of codegen test macro
Browse files Browse the repository at this point in the history
  • Loading branch information
vcanumalla committed Jul 27, 2021
1 parent c293d83 commit 0251c27
Showing 1 changed file with 175 additions and 144 deletions.
319 changes: 175 additions & 144 deletions src/codegen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1939,7 +1939,7 @@ mod tests {
relay_outputs
}
macro_rules! codegen_test {
($env: expr, $code_expr: expr, $ground_truth : expr, $c_code : expr) => {
($env : expr, $code_expr : expr, $hw_map : expr, $ground_truth : expr, $c_code : expr) => {
let mut map: HashMap<String, Vec<usize>> = HashMap::default();
let env = $env;
let mut names = Vec::new();
Expand All @@ -1962,7 +1962,7 @@ mod tests {
let code = codegen(
&egraph,
id,
&HashMap::default(),
&$hw_map,
"compiled_function",
"",
&names,
Expand Down Expand Up @@ -1998,7 +1998,7 @@ mod tests {
format!(
"{}{}{}",
acc,
c_assignment_string("", arg.0, DType::Fp32, &arg.1.view()),
c_assignment_string("", arg.0, DType::Fp32, &arg.1.clone().into_dyn().view()),
"\n"
)
}),
Expand Down Expand Up @@ -2059,6 +2059,132 @@ mod tests {
.expect("Could not convert stderr to UTF8")
);
};
($env : expr, $code_expr : expr, $ground_truth : expr, $c_code : expr) => {
let mut map: HashMap<String, Vec<usize>> = HashMap::default();
let env = $env;
let mut names = Vec::new();
let expr: RecExpr<Language> = RecExpr::from_str($code_expr.as_str()).unwrap();
if (env.len() == 1) {
let name = env[0].0;
map.insert(name.to_string(), env[0].1.shape().to_vec());
names.push(name);
} else {
for n in 0..(env.len()) {
let name = env[n].0.clone();
let shape = env[n].1.shape().to_vec().clone();

map.insert(name.to_string(), shape);
names.push(name);
}
}
let mut egraph = EGraph::new(MyAnalysis { name_to_shape: map });
let id = egraph.add_expr(&expr);
let mut hw_map = HashMap::default();
hw_map.insert(id, 0);
let code = codegen(
&egraph,
id,
&hw_map,
"compiled_function",
"",
&names,
&generate_worklist_for_codegen(&egraph, id),
true,
);
let ground_truth = $ground_truth;
let c_func = format!(
"
int main() {{
{}(out{});
for (int i = 0; i < {}; i++) {{
assert(((float*)ground_truth)[i] == ((float*)out)[i]);
}}
}}
",
"compiled_function",
names
.iter()
.fold(String::new(), |acc, &arg| acc + "," + arg),
ground_truth.shape().iter().product::<usize>()
);
let declarations = format!(
"
#include <assert.h>
#include \"{}\"
{}
{}
{}
{}",
$c_code,
env.iter().fold(String::new(), |acc, arg| {
format!(
"{}{}{}",
acc,
c_assignment_string("", arg.0, DType::Fp32, &arg.1.clone().into_dyn().view()),
"\n"
)
}),
c_assignment_string("", "ground_truth", DType::Fp32, &ground_truth.view()),
c_assignment_string(
"",
"out",
DType::Fp32,
&ndarray::ArrayD::<f32>::zeros(ground_truth.shape()).view()
),
code
);

let main_code = format!(
"
{}
{}",
declarations, c_func
);
let main_c_filepath = std::env::temp_dir().with_file_name(format!(
"{}-test-{}.c",
"compiled_function",
std::time::SystemTime::now().elapsed().unwrap().as_nanos()
));
let binary_filepath = std::env::temp_dir().with_file_name(format!(
"{}-test-{}",
"compiled_function",
std::time::SystemTime::now().elapsed().unwrap().as_nanos()
));

File::create(&main_c_filepath)
.unwrap()
.write_all(main_code.as_bytes())
.unwrap();

let result = Command::new("gcc")
.arg("-Werror")
.arg("-g")
.arg("-o")
.arg(&binary_filepath)
.arg(&main_c_filepath)
.output()
.unwrap();

assert!(
result.status.success(),
"{}",
std::str::from_utf8(result.stderr.as_slice())
.expect("Could not convert stderr to UTF8")
);

let result = Command::new(&binary_filepath).output().unwrap();

assert!(
result.status.success(),
"{}",
std::str::from_utf8(result.stderr.as_slice())
.expect("Could not convert stderr to UTF8")
);
};
($env: expr, $code_expr: expr, $ground_truth : expr) => {
codegen_test!($env, $code_expr, HashMap::default(), $ground_truth, "");
};

}
#[test]
fn tranpose() {
Expand All @@ -2078,8 +2204,7 @@ mod tests {
"(access-transpose (access-tensor t) (list {}))",
permutation.iter().map(|x| x.to_string()).join(" ")
),
input_transposed,
""
input_transposed
);
}

Expand Down Expand Up @@ -2118,147 +2243,55 @@ mod tests {
"(access-concatenate (access-tensor t0) (access-tensor t1) {})",
concat_axis
),
concatted,
""
concatted
);
}

#[test]
fn systolic_array() {
let shape0 = vec![2, 10];
let shape1 = vec![10, 15];

let input0 = ndarray::ArrayD::from_shape_vec(
shape0.clone(),
(0..shape0.iter().product::<usize>()).collect(),
)
.unwrap()
.into_dimensionality::<ndarray::Ix2>()
.unwrap();
let input1 = ndarray::ArrayD::from_shape_vec(
shape1.clone(),
(0..shape1.iter().product::<usize>()).collect(),
)
.unwrap()
.into_dimensionality::<ndarray::Ix2>()
.unwrap();
let multiplied = input0.dot(&input1).into_dyn();

let expr = RecExpr::from_str(
"
(systolic-array 10 15
(access (access-tensor t0) 1)
(access (access-tensor t1) 0)
)",
)
.unwrap();

let mut map = HashMap::default();
map.insert("t0".to_string(), shape0.clone());
map.insert("t1".to_string(), shape1.clone());

let mut egraph = EGraph::new(MyAnalysis { name_to_shape: map });
let id = egraph.add_expr(&expr);

let mut hw_map = HashMap::default();
hw_map.insert(id, 0);

let code = codegen(
&egraph,
id,
&hw_map,
"systolic_array",
"",
&vec!["t0", "t1"],
&generate_worklist_for_codegen(&egraph, id),
true,
);

let main_code = format!(
"
#include <assert.h>
#include \"{}\"
{}
{}
{}
{}
{}
int main() {{
systolic_array(out, t0, t1);
for (int i = 0; i < {}; i++) {{
assert(((float*)result)[i] == ((float*)out)[i]);
}}
}}
",
PathBuf::from_str(
format!(
"{}/{}/{}/{}",
env!("CARGO_MANIFEST_DIR"),
"data",
"codegen-mlp",
"rtml_systolic_array_weight_stationary.c"
let input_list = vec![
(
"t0",
ndarray::ArrayD::from_shape_vec(
vec![2, 10].clone(),
(0..vec![2, 10].iter().product::<usize>()).collect(),
)
.as_str()
)
.unwrap()
.to_string_lossy(),
c_assignment_string("", "t0", DType::Fp32, &input0.into_dyn().view()),
c_assignment_string("", "t1", DType::Fp32, &input1.into_dyn().view()),
c_assignment_string("", "result", DType::Fp32, &multiplied.view()),
c_assignment_string(
"",
"out",
DType::Fp32,
&ndarray::ArrayD::<f32>::zeros(multiplied.shape()).view()
.unwrap()
.into_dimensionality::<ndarray::Ix2>()
.unwrap()
),
code,
multiplied.shape().iter().product::<usize>()
);

let main_c_filepath = std::env::temp_dir().with_file_name(format!(
"systolic-array-test-{}.c",
std::time::SystemTime::now().elapsed().unwrap().as_nanos()
));
println!("{}", main_c_filepath.to_string_lossy());

let binary_filepath = std::env::temp_dir().with_file_name(format!(
"systolic-array-test-{}",
std::time::SystemTime::now().elapsed().unwrap().as_nanos()
));
println!("{}", binary_filepath.to_string_lossy());

File::create(&main_c_filepath)
.unwrap()
.write_all(main_code.as_bytes())
.unwrap();

let result = Command::new("gcc")
.arg("-Werror")
.arg("-g")
.arg("-o")
.arg(&binary_filepath)
.arg(&main_c_filepath)
.output()
.unwrap();

assert!(
result.status.success(),
"{}",
std::str::from_utf8(result.stderr.as_slice())
.expect("Could not convert stderr to UTF8")
);

let result = Command::new(&binary_filepath).output().unwrap();

assert!(
result.status.success(),
"{}",
std::str::from_utf8(result.stderr.as_slice())
.expect("Could not convert stderr to UTF8")
);
(
"t1",
ndarray::ArrayD::from_shape_vec(
vec![10, 15].clone(),
(0..vec![10, 15].iter().product::<usize>()).collect(),
)
.unwrap()
.into_dimensionality::<ndarray::Ix2>()
.unwrap()
)];
let multiplied = input_list[0].1.clone().dot(&input_list[1].1.clone()).into_dyn();
codegen_test!(
input_list.clone(),
format!("
(systolic-array 10 15
(access (access-tensor t0) 1)
(access (access-tensor t1) 0)
)"
),
multiplied,
PathBuf::from_str(
format!(
"{}/{}/{}/{}",
env!("CARGO_MANIFEST_DIR"),
"data",
"codegen-mlp",
"rtml_systolic_array_weight_stationary.c"
)
.as_str()
)
.unwrap()
.to_string_lossy()
);
}
#[test]
fn pad() {
Expand Down Expand Up @@ -2296,8 +2329,7 @@ int main() {{
"(access-pad (access-tensor t) {} {} {} {})",
pad_type, pad_axis, pad_before, pad_after
),
padded,
""
padded
);
}
#[test]
Expand Down Expand Up @@ -2340,8 +2372,7 @@ int main() {{
(access-slice (access-tensor t) {} {} {})",
slice_axis, low, high
),
sliced,
""
sliced
);
}

Expand Down

0 comments on commit 0251c27

Please sign in to comment.