Skip to content

Commit

Permalink
multiprompt fix
Browse files Browse the repository at this point in the history
  • Loading branch information
eps696 committed Feb 9, 2023
1 parent 7844ce0 commit 6d2faeb
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 25 deletions.
1 change: 1 addition & 0 deletions src/_sdrun.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,7 @@ def generate(z_, c_, uc=uc, cw=None, img=None, mask=None, thresh=True, out_lat=F
extra_args = {**extra_args, 'x_frozen': img, 'mask': mask}
c_count = c_.shape[0]
if cw is None: cw = [1.] * c_count
if not isinstance(cw, list): cw = [cw]
extra_args['cond_weights'] = [c / sum(cw) for c in cw]
extra_args['cond_counts'] = [c_count,]
samples = sampling_fn(model_cfg, z_, sigmas, extra_args=extra_args, disable=False) # [1,4,64,64]
Expand Down
41 changes: 16 additions & 25 deletions src/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,32 +146,23 @@ def read_multitext(in_txt, prefix=None, postfix=None, flat=False):
lines = read_txt(in_txt)
if prefix is not None: prefix = read_txt(prefix)
if postfix is not None: postfix = read_txt(postfix)
if flat:
prompts = []
for i, tt in enumerate(lines):
tt = tt.strip()
if len(tt) == 0:
prompts.append('')
elif tt[0] != '#':
if prefix is not None: tt = prefix[i % len(prefix)] + ' | ' + tt
if postfix is not None: tt = tt + ' | ' + postfix[i % len(postfix)]
prompts.append(tt)
prompts = [parse_line(tt) for tt in lines if tt.strip()[0] != '#']
maxlen = 0
for prompt in prompts:
maxlen = max(maxlen, len(prompt))
for i in range(len(prompts)):
if len(prompts[i]) < maxlen:
prompts[i] += [('', 1e-4)] * (maxlen - len(prompts[i]))
if prefix is not None:
prompts = [parse_line(prefix[i % len(prefix)]) + prompts[i] for i in range(len(prompts))]
if postfix is not None:
prompts = [prompts[i] + parse_line(postfix[i % len(postfix)]) for i in range(len(prompts))]
weights = [[p[1] for p in prompt] for prompt in prompts]
prompts = [[p[0] for p in prompt] for prompt in prompts]
if flat is True:
prompts = [' | '.join([p for p in prompt if len(p)>0]) for prompt in prompts]
weights = [1.] * len(prompts)
else:
prompts = [parse_line(tt) for tt in lines if tt.strip()[0] != '#']
maxlen = 0
for prompt in prompts:
maxlen = max(maxlen, len(prompt))
for i in range(len(prompts)):
if len(prompts[i]) < maxlen:
prompts[i] += [('', 1e-4)] * (maxlen - len(prompts[i]))
if prefix is not None:
prompts = [parse_line(prefix[i % len(prefix)]) + prompts[i] for i in range(len(prompts))]
if postfix is not None:
prompts = [prompts[i] + parse_line(postfix[i % len(postfix)]) for i in range(len(prompts))]
weights = [[p[1] for p in prompt] for prompt in prompts]
prompts = [[p[0] for p in prompt] for prompt in prompts]
return prompts, weights # two lists if flat, or two lists of lists if multi
return prompts, weights # two lists [if flat], or two lists of lists [if multi]

def unique_prefix(out_dir):
dirlist = sorted(os.listdir(out_dir), reverse=True) # sort reverse alphabetically until we find max+1
Expand Down

0 comments on commit 6d2faeb

Please sign in to comment.