Skip to content

Commit

Permalink
Same as previous
Browse files Browse the repository at this point in the history
  • Loading branch information
LHBO committed Apr 16, 2024
1 parent fe9c4a5 commit 4e751dd
Showing 1 changed file with 7 additions and 7 deletions.
14 changes: 7 additions & 7 deletions python/shaprpy/explain.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ def explain(
MSEv_uniform_comb_weights = MSEv_uniform_comb_weights,
timing = timing,
verbose = verbose,
is_python = True,
is_python=True,
**kwargs
)

Expand Down Expand Up @@ -232,7 +232,7 @@ def batch_prepare_vS_MC(S, rinternal, model, predict_model):
dt_vS = shapr.compute_MCint(dt)

if keep_samp_for_vS:
return ro.ListVector({'dt_vS': dt_vS, 'dt_samp_for_vS': dt})
return ro.ListVector({'dt_vS':dt_vS, 'dt_samp_for_vS':dt})
else:
return dt_vS

Expand Down Expand Up @@ -263,7 +263,7 @@ def get_feature_specs(get_model_specs, model):

if not isinstance(feature_specs, dict):
raise ValueError(f'`get_model_specs` returned an object of type `{type(feature_specs)}`, but it should be of type `dict`')
if set(feature_specs.keys()) != set(["labels", "classes", "factor_levels"]):
if set(feature_specs.keys()) != set(["labels","classes","factor_levels"]):
raise ValueError(f'`get_model_specs` should return a `dict` with keys ["labels","classes","factor_levels"], but found keys {list(feature_specs.keys())}')

if feature_specs is None:
Expand All @@ -277,7 +277,7 @@ def strvec_or_na(v):
return strvec
def listvec_or_na(v):
if v is None: return NA
return ro.ListVector({k: list(val) for k,val in v.items()})
return ro.ListVector({k:list(val) for k,val in v.items()})

rfeature_specs = ro.ListVector({
'labels': py2r_or_na(feature_specs['labels']),
Expand All @@ -302,7 +302,7 @@ def get_predict_model(x_test, predict_model, model):
try:
tmp = py2r(predict_model(model, x_test))
except Exception as e:
raise RuntimeError(f'The predict_model function of class `{model_class0}` is invalid.\nA basic function test threw the following error:\n{e}')
raise RuntimeError(f'The predict_model function of class `{model_class0}` is invalid.\nA basic function test threw the following error:\n{e}')
if not all(base.is_numeric(tmp)):
raise RuntimeError('The output of predict_model is expected to be numeric.')
if not (len(tmp) == 2):
Expand Down Expand Up @@ -361,10 +361,9 @@ def prebuilt_predict_model(model):
try:
from keras.models import Model
if isinstance(model, Model):
def predict_fn(m, x):
def predict_fn(m,x):
pred = m.predict(x)
return pred.reshape(pred.shape[0],)

return predict_fn
except:
pass
Expand All @@ -373,6 +372,7 @@ def predict_fn(m, x):


def compute_time(timing_list):

timing_secs = {
f'{key}': (timing_list[key] - timing_list[prev_key]).total_seconds()
for key, prev_key in zip(list(timing_list.keys())[1:], list(timing_list.keys())[:-1])
Expand Down

0 comments on commit 4e751dd

Please sign in to comment.