Skip to content

Commit

Permalink
Address PR comments
Browse files Browse the repository at this point in the history
  • Loading branch information
geoffreyangus committed Jul 20, 2023
1 parent 0fdcb98 commit 561cec2
Show file tree
Hide file tree
Showing 3 changed files with 3 additions and 12 deletions.
5 changes: 2 additions & 3 deletions ludwig/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -638,12 +638,13 @@ def on_epoch_end(self, trainer, progress_tracker, save_path):
test_set=test_set,
save_path=model_dir,
)
(self.model, train_trainset_stats, train_valiset_stats, train_testset_stats) = train_stats

# Calibrates output feature probabilities on validation set if calibration is enabled.
# Must be done after training, and before final model parameters are saved.
if self.backend.is_coordinator():
calibrator = Calibrator(
trainer.model,
self.model,
self.backend,
batch_size=trainer.eval_batch_size,
)
Expand All @@ -667,7 +668,6 @@ def on_epoch_end(self, trainer, progress_tracker, save_path):
if not skip_save_model:
# ensure that any changes to the model object held by the
# trainer class are reflected in the model in this class.
self.model = trainer.model
self.model.save(model_dir)

# Evaluation Frequency
Expand All @@ -687,7 +687,6 @@ def on_epoch_end(self, trainer, progress_tracker, save_path):
# List[TrainerMetric], with one entry per training checkpoint, according to steps_per_checkpoint.
# We reduce the dictionary of TrainerMetrics to a simple list of floats for interfacing with Ray
# Tune.
(self.model, train_trainset_stats, train_valiset_stats, train_testset_stats) = train_stats
train_stats = TrainingStats(
metric_utils.reduce_trainer_metrics_dict(train_trainset_stats),
metric_utils.reduce_trainer_metrics_dict(train_valiset_stats),
Expand Down
9 changes: 0 additions & 9 deletions ludwig/backend/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,15 +59,6 @@ def _has_ray():
return False


def is_ray_backend(backend) -> bool:
    """Return True if the given backend specification refers to the Ray backend.

    The spec may be a backend name string (e.g. "ray", "local") or a backend
    config dict whose "type" key names the backend (defaulting to "local" when
    absent). Any other value is treated as non-Ray.
    """
    # Normalize the spec to a backend-type string, then compare once.
    backend_type = None
    if isinstance(backend, str):
        backend_type = backend
    elif isinstance(backend, dict):
        backend_type = backend.get("type", "local")
    return backend_type == "ray"


def get_local_backend(**kwargs):
    """Construct and return a LocalBackend, forwarding all keyword arguments."""
    backend = LocalBackend(**kwargs)
    return backend

Expand Down
1 change: 1 addition & 0 deletions ludwig/models/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,7 @@ def initialize_adapter(self):
logger.info("==================================================")

def prepare_for_training(self):
    """Prepare the model for a training run by initializing its adapter.

    TODO: this implementation will not work if resuming from a previous
    checkpoint. Need to fix this.
    """
    self.initialize_adapter()

def to_device(self, device):
Expand Down

0 comments on commit 561cec2

Please sign in to comment.