[DICP][ascend] bugfix for llama finetune (#631)
tangzhiyi11 and pdx1989 authored Jan 15, 2024
Co-authored-by: Pan Daoxin <[email protected]>
1 parent 9c231e7 · commit 1c5325f
Showing 2 changed files with 10 additions and 2 deletions.
5 changes: 4 additions & 1 deletion dicp/dicp/vendor/AscendGraph/compile_job.py

```diff
@@ -77,8 +77,11 @@ def build_graph(self, output_path, graph_path):
     def get_compile_result(self):
         if (not os.path.exists(self._model_path[0]) and not os.path.exists(self._model_path[1])):
             self.build_graph(self._output_graph_path, self._input_path)
+        origin_graph_path = self._output_graph_path
         if not os.path.exists(self._output_graph_path + '.om'):
-            self._output_graph_path += '_linux_x86_64'
+            self._output_graph_path = origin_graph_path + '_linux_x86_64'
+        if not os.path.exists(self._output_graph_path + '.om'):
+            self._output_graph_path = origin_graph_path + '_linux_aarch64'
         assert (os.path.exists(self._output_graph_path + '.om'))
         from dicp.vendor.AscendGraph.codegen.load_and_run import AscendModel
         return AscendModel(self._local_rank, self._output_graph_path + '.om')
```
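What this hunk does: `get_compile_result` previously appended `_linux_x86_64` to `self._output_graph_path` in place, so there was a single, x86-only fallback. The new code remembers `origin_graph_path` and tries the bare path, then `_linux_x86_64`, then `_linux_aarch64`, so the compiled `.om` model is also found on aarch64 hosts. A minimal standalone sketch of that resolution order; `resolve_om_path` is a hypothetical helper, not a function from the commit:

```python
import os

def resolve_om_path(origin_graph_path: str) -> str:
    """Return the first graph path whose compiled '.om' file exists on disk."""
    # Try the bare path, then the architecture-suffixed variants the Ascend
    # toolchain may produce, in the same order as the diff above.
    for suffix in ("", "_linux_x86_64", "_linux_aarch64"):
        candidate = origin_graph_path + suffix
        if os.path.exists(candidate + ".om"):
            return candidate
    raise FileNotFoundError(f"no compiled .om model found for {origin_graph_path!r}")

# Usage sketch: model_path = resolve_om_path(output_graph_path) + ".om"
```

Raising a descriptive error instead of a bare `assert`, as in this sketch, would also make a missing model easier to diagnose.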
7 changes: 6 additions & 1 deletion dicp/dicp/vendor/AscendGraph/conversion.py

```diff
@@ -577,7 +577,12 @@ def full(self, dims, value, dtype=torch.float32, layout=torch.strided,
         if isinstance(value, torch.fx.proxy.Proxy) and hasattr(value.node, 'meta'):
             value = value.node.meta['val']
         dims = self.get_shape_proxy(dims)
-        value = self.get_proxy(ascend_op.Const, ([value], torch_dtype, []))
+
+        # temporarily split the path for dynamic/static shape cases
+        if len(self.sym_in_args) > 0 or len(self.sym_to_inputs) > 0:
+            value = self.get_proxy(ascend_op.Const, ([value], torch_dtype, []))
+        else:
+            value = self.common_process_scalar(value, torch_dtype)
         return self.get_proxy(ascend_op.Fill, (dims, value))

     @register_conversion(torch.ops.aten.fill.Scalar)
```
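What this hunk does: the `full` conversion now branches on whether the graph has symbolic (dynamic) shapes. Dynamic-shape graphs (`sym_in_args` or `sym_to_inputs` non-empty) keep materializing the fill value as an `ascend_op.Const`, while static-shape graphs route it through `common_process_scalar`; the in-diff comment marks this split as temporary. For orientation, a hedged sketch of the kind of user code that reaches this conversion; the function `f` and the `backend="ascendgraph"` registration are assumptions, not part of the commit:

```python
import torch

def f(x: torch.Tensor) -> torch.Tensor:
    # torch.full lowers to aten.full, which the AscendGraph backend converts
    # to Fill(dims, value); the scalar 0.5 is the value whose handling this
    # commit changes for static-shape graphs.
    return x + torch.full((2, 3), 0.5)

# Backend name assumed from DICP's backend registration convention.
compiled_f = torch.compile(f, backend="ascendgraph")
```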
