Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

improve GGUF metadata handling #6082

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
177 changes: 116 additions & 61 deletions modules/metadata_gguf.py
Original file line number Diff line number Diff line change
@@ -1,92 +1,147 @@
import os
import sys
import struct
from enum import IntEnum

from io import BufferedReader
from typing import Union

class GGUFValueType(IntEnum):
    """Numeric type tags that precede every GGUF metadata value."""

    # Occasionally check to ensure this class is consistent with gguf
    UINT8 = 0
    INT8 = 1
    UINT16 = 2
    INT16 = 3
    UINT32 = 4
    INT32 = 5
    FLOAT32 = 6
    BOOL = 7
    STRING = 8
    ARRAY = 9
    UINT64 = 10
    INT64 = 11
    FLOAT64 = 12


# the GGUF format versions that this module supports
SUPPORTED_GGUF_VERSIONS = [3]

# GGUF only supports execution on little or big endian machines
# NOTE(review): sys.byteorder is documented to be either 'little' or 'big',
# so this guard is not expected to ever fire - kept for parity with the PR.
if sys.byteorder not in ('little', 'big'):
    raise ValueError(
        "host is not little or big endian - GGUF is unsupported"
    )

# arguments for struct.unpack() based on gguf value type
# NOTE(review): "=" means host byte order with standard sizes, so a file
# written with the opposite endianness will be misread - confirm that this
# is the intended behaviour.
value_packing: dict = {
    GGUFValueType.UINT8: "=B",
    GGUFValueType.INT8: "=b",
    GGUFValueType.UINT16: "=H",
    GGUFValueType.INT16: "=h",
    GGUFValueType.UINT32: "=I",
    GGUFValueType.INT32: "=i",
    GGUFValueType.FLOAT32: "=f",
    GGUFValueType.UINT64: "=Q",
    GGUFValueType.INT64: "=q",
    GGUFValueType.FLOAT64: "=d",
    GGUFValueType.BOOL: "?",
}

# length in bytes for each gguf value type; derived from the format strings
# so that the two tables can never disagree
value_lengths: dict = {
    value_type: struct.calcsize(fmt)
    for value_type, fmt in value_packing.items()
}


def _read_exact(file: BufferedReader, length: int) -> bytes:
    """Read exactly `length` bytes or raise ValueError on a short read"""
    data = file.read(length)
    if len(data) != length:
        raise ValueError(
            f"unexpected end of file: wanted {length} bytes, "
            f"got {len(data)}"
        )
    return data


def unpack(value_type: GGUFValueType, file: BufferedReader):
    """
    Read one fixed-size value of the given type from an open file

    Raises ValueError if `value_type` has no fixed-size encoding (STRING
    and ARRAY) or if the file ends before the value is complete
    """
    fmt = value_packing.get(value_type)
    if fmt is None:
        # fail before touching the file: read(None) would otherwise consume
        # everything up to EOF just to produce an opaque TypeError
        raise ValueError(
            f"cannot unpack {value_type!r} as a fixed-size value"
        )
    return struct.unpack(
        fmt,
        _read_exact(file, value_lengths[value_type])
    )[0]


def get_single(
    value_type: GGUFValueType,
    file: BufferedReader
) -> Union[str, bytes, int, float, bool]:
    """
    Read a single value from an open file

    Strings are stored as a UINT64 byte length followed by the raw bytes.
    A string that is not valid utf-8 is returned as the undecoded bytes.
    """
    if value_type != GGUFValueType.STRING:
        return unpack(value_type, file=file)

    string_length = unpack(GGUFValueType.UINT64, file=file)
    raw = _read_exact(file, string_length)
    # officially, strings that cannot be decoded into utf-8 are invalid
    try:
        return raw.decode("utf-8")
    except UnicodeDecodeError:
        return raw


def load_metadata(
    fn: Union[os.PathLike[str], str]
) -> dict[str, Union[str, int, float, bool, list]]:
    """
    Given a path to a GGUF file, peek at its header for metadata

    Return a dictionary where all keys are strings, and values can be
    strings, ints, floats, bools, or lists

    Raises ValueError if the magic number is wrong, the reported version is
    not in SUPPORTED_GGUF_VERSIONS, or the file is truncated
    """
    metadata: dict[str, Union[str, int, float, bool, list]] = {}
    with open(fn, "rb") as file:
        magic = file.read(4)

        if magic != b"GGUF":
            raise ValueError(
                "your model file is not a valid GGUF file "
                f"(magic number mismatch, got {magic}, "
                "expected b'GGUF')"
            )

        version = unpack(GGUFValueType.UINT32, file=file)

        if version not in SUPPORTED_GGUF_VERSIONS:
            raise ValueError(
                f"your model file reports GGUF version {version}, but "
                f"only versions {SUPPORTED_GGUF_VERSIONS} "
                "are supported. re-convert your model or download a newer "
                "version"
            )

        # the tensor count is not needed here, but it has to be consumed
        # to reach the key/value count that follows it
        unpack(GGUFValueType.UINT64, file=file)
        metadata_kv_count = unpack(GGUFValueType.UINT64, file=file)

        for _ in range(metadata_kv_count):
            key_length = unpack(GGUFValueType.UINT64, file=file)
            key = _read_exact(file, key_length).decode()
            value_type = GGUFValueType(
                unpack(GGUFValueType.UINT32, file=file)
            )
            if value_type == GGUFValueType.ARRAY:
                array_value_type = GGUFValueType(
                    unpack(GGUFValueType.UINT32, file=file)
                )
                # array_length is the number of items in the array
                array_length = unpack(GGUFValueType.UINT64, file=file)
                metadata[key] = [
                    get_single(array_value_type, file)
                    for _ in range(array_length)
                ]
            else:
                metadata[key] = get_single(value_type, file)

    return metadata