Commit 229331c

bar progress and high level logging

Panizghi committed Aug 25, 2024
1 parent 20ab166 commit 229331c
Showing 2 changed files with 84 additions and 42 deletions.
123 changes: 82 additions & 41 deletions src/main/python/safetensors/json_to_bin.py
@@ -3,73 +3,114 @@
import os
import argparse
import gzip
+import logging
from safetensors.torch import save_file, load_file
+from tqdm import tqdm  # Import tqdm for progress bars

+# Set up logging to both console and file: the root logger must allow DEBUG
+# records through, or the file handler would never see them; the console
+# handler stays at INFO for high-level output only.
+log_file_path = 'process_log.log'
+file_handler = logging.FileHandler(log_file_path)  # Log detailed information to a file
+file_handler.setLevel(logging.DEBUG)
+console_handler = logging.StreamHandler()  # Log high-level information to the console
+console_handler.setLevel(logging.INFO)
+logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s',
+                    handlers=[file_handler, console_handler])

# Set up argument parser
-parser = argparse.ArgumentParser(description='Process vectors and docids from JSONL or GZ files.')
-parser.add_argument('--input', required=True, help='Path to the input JSONL or GZ file')
+parser = argparse.ArgumentParser(description='Process vectors and docids from JSON, JSONL, or GZ files.')
+parser.add_argument('--input', required=True, help='Path to the input JSON, JSONL, or GZ file')
parser.add_argument('--output', required=True, help='Path to the output directory')
parser.add_argument('--overwrite', action='store_true', help='Overwrite existing files if they already exist')

args = parser.parse_args()

# Define paths
input_file_path = args.input
output_directory = args.output

-# Ensure the output directory exists
-if not os.path.exists(output_directory):
-    os.makedirs(output_directory)
-
-# Check if the input file is a .gz file and convert it to .jsonl if necessary
-if input_file_path.endswith('.gz'):
-    with gzip.open(input_file_path, 'rt') as gz_file:
-        jsonl_file_path = input_file_path.replace('.gz', '.jsonl')
-        with open(jsonl_file_path, 'w') as jsonl_file:
-            for line in gz_file:
-                jsonl_file.write(line)
-    input_file_path = jsonl_file_path
-
-# Check if the input file is a .jsonl file
-elif not input_file_path.endswith('.jsonl'):
-    raise ValueError("Input file must be a .jsonl or .gz file")
+# Ensure the input file exists
+if not os.path.exists(input_file_path):
+    logging.error(f"Input file '{input_file_path}' not found.")
+    raise FileNotFoundError(f"Input file '{input_file_path}' not found.")
+
+# Ensure the output directory exists or create it
+try:
+    os.makedirs(output_directory, exist_ok=True)
+except OSError as e:
+    logging.error(f"Failed to create output directory '{output_directory}': {e}")
+    raise OSError(f"Failed to create output directory '{output_directory}': {e}")

# Get the base name of the input file for output file names
-base_name = os.path.basename(input_file_path).replace('.jsonl', '')
+base_name = os.path.basename(input_file_path).replace('.jsonl', '').replace('.gz', '').replace('.json', '')

vectors_path = os.path.join(output_directory, f'{base_name}_vectors.safetensors')
docids_path = os.path.join(output_directory, f'{base_name}_docids.safetensors')

+if not args.overwrite:
+    if os.path.exists(vectors_path) or os.path.exists(docids_path):
+        logging.error(f"Output files '{vectors_path}' or '{docids_path}' already exist. Use '--overwrite' to overwrite.")
+        raise FileExistsError(f"Output files '{vectors_path}' or '{docids_path}' already exist. Use '--overwrite' to overwrite.")

# Initialize lists to hold data
vectors = []
docids = []

-# Process the JSONL file to extract vectors and docids
-with open(input_file_path, 'r') as file:
-    for line in file:
-        entry = json.loads(line)
-        # Ensure that the vector starts with a valid number
-        if isinstance(entry['vector'][0], float):
-            vectors.append(entry['vector'])
-            docid = entry['docid']
-            docid_ascii = [ord(char) for char in docid]  # Convert docid to ASCII values
-            docids.append(docid_ascii)
-        else:
-            print(f"Skipped invalid vector entry with docid: {entry['docid']}")
+# Determine file opener based on file extension
+if input_file_path.endswith('.gz'):
+    file_opener = gzip.open
+elif input_file_path.endswith('.jsonl') or input_file_path.endswith('.json'):
+    file_opener = open
+else:
+    logging.error("Input file must be a .json, .jsonl, or .gz file")
+    raise ValueError("Input file must be a .json, .jsonl, or .gz file")
+
+# Get total number of lines for tqdm if possible
+try:
+    with file_opener(input_file_path, 'rt') as f:
+        total_lines = sum(1 for _ in f)  # close the handle instead of leaking it
+except Exception:
+    total_lines = None
+
+# Process the JSON, JSONL, or GZ file to extract vectors and docids
+try:
+    with file_opener(input_file_path, 'rt') as file:
+        for line in tqdm(file, total=total_lines, desc="Processing lines"):
+            try:
+                entry = json.loads(line)
+                if isinstance(entry.get('vector', [None])[0], float):
+                    vectors.append(entry['vector'])
+                    docid = entry['docid']
+                    docid_ascii = [ord(char) for char in docid]  # Convert docid to ASCII values
+                    docids.append(docid_ascii)
+                else:
+                    logging.warning(f"Skipped invalid vector entry with docid: {entry.get('docid', 'N/A')}")
+            except (json.JSONDecodeError, KeyError, IndexError) as e:
+                # IndexError covers an entry whose 'vector' list is empty
+                logging.warning(f"Skipped invalid JSON entry: {e}")
+except IOError as e:
+    logging.error(f"Failed to read the input file '{input_file_path}': {e}")
+    raise IOError(f"Failed to read the input file '{input_file_path}': {e}")

# Convert lists to tensors
vectors_tensor = torch.tensor(vectors, dtype=torch.float64)
docids_tensor = torch.nn.utils.rnn.pad_sequence([torch.tensor(d, dtype=torch.int64) for d in docids], batch_first=True)

# Save the tensors to SafeTensors files
-save_file({'vectors': vectors_tensor}, vectors_path)
-save_file({'docids': docids_tensor}, docids_path)
-
-print(f"Saved vectors to {vectors_path}")
-print(f"Saved docids to {docids_path}")
-
-# Load vectors and docids
-loaded_vectors = load_file(vectors_path)['vectors']
-loaded_docids = load_file(docids_path)['docids']
-
-print(f"Loaded vectors: {loaded_vectors}")
-print(f"Loaded document IDs (ASCII): {loaded_docids}")
+try:
+    save_file({'vectors': vectors_tensor}, vectors_path)
+    save_file({'docids': docids_tensor}, docids_path)
+    logging.info(f"Saved vectors to {vectors_path}")
+    logging.info(f"Saved docids to {docids_path}")
+except IOError as e:
+    logging.error(f"Failed to save tensors: {e}")
+    raise IOError(f"Failed to save tensors: {e}")
+
+# Load vectors and docids for verification
+try:
+    loaded_vectors = load_file(vectors_path)['vectors']
+    loaded_docids = load_file(docids_path)['docids']
+    logging.info(f"Loaded vectors from {vectors_path}")
+    logging.info(f"Loaded document IDs (ASCII) from {docids_path}")
+    # Detailed tensor contents are logged at DEBUG, so they reach the file handler only
+    logging.debug(f"Loaded vectors: {loaded_vectors}")
+    logging.debug(f"Loaded document IDs (ASCII): {loaded_docids}")
+except IOError as e:
+    logging.error(f"Failed to load tensors: {e}")
+    raise IOError(f"Failed to load tensors: {e}")
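
For reference, a minimal sketch of reading the output back and recovering the docid strings; the file name is hypothetical, and it assumes the zero padding that pad_sequence applies by default:

from safetensors.torch import load_file

docids_tensor = load_file('out/example_docids.safetensors')['docids']
# Each row holds one docid as ASCII codes, right-padded with zeros;
# drop the padding and map each code back to its character.
docids = [''.join(chr(c) for c in row if c != 0) for row in docids_tensor.tolist()]
print(docids)
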
3 changes: 2 additions & 1 deletion src/main/python/safetensors/requirements.txt
@@ -23,4 +23,5 @@ six==1.16.0
sympy==1.12
torch==2.3.0
typing_extensions==4.11.0
-urllib3==2.2.1
+urllib3==2.2.1
+tqdm==4.66.5
