
Share common strings in the engine database to save some space.
TeamSpen210 committed Dec 3, 2024
1 parent acf1857 commit a0a3a05
Showing 3 changed files with 76 additions and 30 deletions.
1 change: 1 addition & 0 deletions docs/source/changelog.rst
@@ -17,6 +17,7 @@ Version (dev)
* Allow entities to delete multiple keyvalues at once.
* Fix silent buttons trying to pack invalid `Buttons.snd0` soundscripts.
* Handle entities being added/removed during iteration of :py:meth:`VMF.search() <srctools.vmf.VMF.search>`.
+ * Share common strings in the engine database to save some space.

-------------
Version 2.4.1
105 changes: 75 additions & 30 deletions src/srctools/_engine_db.py
@@ -5,7 +5,7 @@
"""
from typing import (
IO, TYPE_CHECKING, AbstractSet, Callable, Collection, Dict, Final, FrozenSet,
- Iterable, List, Mapping, Optional, Set, Tuple, Union,
+ Iterable, List, Mapping, Optional, Set, Tuple, Union, Counter
)
from typing_extensions import TypeAlias
from enum import IntFlag
@@ -30,23 +30,25 @@
'serialise', 'unserialise',
]

- _fmt_8bit: Final = Struct('>B')
- _fmt_16bit: Final = Struct('>H')
- _fmt_32bit: Final = Struct('>I')
- _fmt_double: Final = Struct('>d')
- _fmt_header: Final = Struct('>BI')
- _fmt_ent_header: Final = Struct('>BBBBBB')
- _fmt_block_pos: Final = Struct('>IH')
+ _fmt_8bit: Final = Struct('<B')
+ _fmt_16bit: Final = Struct('<H')
+ _fmt_32bit: Final = Struct('<I')
+ _fmt_double: Final = Struct('<d')
+ _fmt_header: Final = Struct('<BI')
+ _fmt_ent_header: Final = Struct('<BBBBBB')
+ _fmt_block_pos: Final = Struct('<IH')


# Version number for the format.
- BIN_FORMAT_VERSION: Final = 8
+ BIN_FORMAT_VERSION: Final = 9
TAG_EMPTY: Final[FrozenSet[str]] = frozenset() # This is a singleton.
# Soft limit on the number of bytes for each block, needs tuning.
MAX_BLOCK_SIZE: Final = 2048
# When writing arrays of strings, it's much more efficient to read the whole thing, decode then
# split by a character rather than read sizes individually.
STRING_SEP: Final = '\x1F' # UNIT SEPARATOR
+ # Number of strings to keep in the shared database.
+ SHARED_STRINGS: Final = 512
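
A note on the two constants above: the writer stores each string table as one separator-joined, LZMA-compressed blob rather than as individually length-prefixed strings. A minimal round-trip sketch of that scheme (`write_table`/`read_table` are hypothetical helper names, not the module's actual functions):

```python
import io
import lzma
import struct

STRING_SEP = '\x1f'  # UNIT SEPARATOR, assumed never to occur in stored strings.
_fmt_16bit = struct.Struct('<H')

def write_table(file, strings):
    # Join everything, compress once, prefix the blob with a 2-byte size.
    data = lzma.compress(STRING_SEP.join(strings).encode('utf8'))
    file.write(_fmt_16bit.pack(len(data)))
    file.write(data)

def read_table(file):
    # One read, one decompress and one split recovers the whole list.
    [size] = _fmt_16bit.unpack(file.read(2))
    return lzma.decompress(file.read(size)).decode('utf8').split(STRING_SEP)

buf = io.BytesIO()
write_table(buf, ['targetname', 'origin', 'angles'])
buf.seek(0)
assert read_table(buf) == ['targetname', 'origin', 'angles']
```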


class EntFlags(IntFlag):
@@ -190,20 +192,26 @@ class BinStrDict:
Each unique string is assigned a 2-byte index into the list.
"""
- def __init__(self, database: Iterable[str]) -> None:
+ def __init__(self, database: Iterable[str], base: Optional['BinStrDict']) -> None:
self._dict: Dict[str, int] = {
name: ind for ind, name
in enumerate(database)
}
- if len(self._dict) >= (1 << 16):
+ # If no base dict, this is for CBaseEntity, so set it to the real dict,
+ # so __call__() won't add SHARED_STRINGS to the index.
+ self.base_dict: Dict[str, int] = base._dict if base is not None else self._dict
+ if len(self._dict) + len(self.base_dict) >= (1 << 16):
raise ValueError("Too many items in dictionary!")

def __call__(self, string: str) -> bytes:
"""Get the index for a string.
The result is the two bytes that represent the string.
"""
- return _fmt_16bit.pack(self._dict[string])
+ if string in self.base_dict:
+     return _fmt_16bit.pack(self.base_dict[string])
+ else:
+     return _fmt_16bit.pack(SHARED_STRINGS + self._dict[string])

def serialise(self, file: IO[bytes]) -> None:
"""Convert this to a stream of bytes."""
Expand All @@ -220,15 +228,17 @@ def serialise(self, file: IO[bytes]) -> None:
file.write(data)

@classmethod
- def unserialise(cls, file: IO[bytes]) -> Callable[[], str]:
+ def unserialise(cls, file: IO[bytes], base: List[str]) -> Tuple[List[str], Callable[[], str]]:
"""Read the dictionary from a file.
- This returns a function which reads
+ This returns the dict, and a function which reads
a string from a file at the current point.
"""
[length] = _fmt_16bit.unpack(file.read(2))
inv_list = lzma.decompress(file.read(length)).decode('utf8').split(STRING_SEP)
- return make_lookup(file, inv_list)
+ # This could branch on the index to avoid the concatenation, but this should be
+ # faster, and the dict will only be around temporarily anyway.
+ return inv_list, make_lookup(file, base + inv_list)
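
To make the split index space these two methods share more concrete: a 2-byte index below SHARED_STRINGS refers to the shared table written once alongside CBaseEntity, while anything at or above it refers to the block-local table. A small standalone sketch of the lookup rule (an illustration, not the real classes):

```python
SHARED_STRINGS = 512  # must match the writer's constant

def encode_index(string: str, shared: dict, local: dict) -> int:
    # Shared strings keep their raw index; block-local ones are offset.
    return shared[string] if string in shared else SHARED_STRINGS + local[string]

def decode_index(index: int, shared_list: list, local_list: list) -> str:
    # Equivalent to indexing shared_list + local_list, which is what the
    # reader does, since the shared list is always exactly SHARED_STRINGS long.
    if index < SHARED_STRINGS:
        return shared_list[index]
    return local_list[index - SHARED_STRINGS]
```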

@staticmethod
def read_tags(file: IO[bytes], from_dict: Callable[[], str]) -> FrozenSet[str]:
@@ -253,9 +263,15 @@ def write_tags(

class EngineDB(_EngineDBProto):
"""Unserialised database, which will be parsed progressively as required."""
- def __init__(self, ent_map: Dict[str, Union[EntityDef, int]], unparsed: List[Tuple[Iterable[str], bytes]]) -> None:
+ def __init__(
+     self,
+     ent_map: Dict[str, Union[EntityDef, int]],
+     base_strings: List[str],
+     unparsed: List[Tuple[Iterable[str], bytes]],
+ ) -> None:
self.ent_map = ent_map
self.unparsed = unparsed
+ self.base_strings = base_strings
self.fgd: Optional[FGD] = None

def get_classnames(self) -> AbstractSet[str]:
Expand Down Expand Up @@ -296,7 +312,7 @@ def _parse_block(self, index: int) -> None:
apply_bases = []

file = io.BytesIO(data)
- from_dict = BinStrDict.unserialise(file)
+ _, from_dict = BinStrDict.unserialise(file, self.base_strings)
for classname in classes:
self.ent_map[classname.casefold()] = ent = ent_unserialise(file, classname, from_dict)
if ent.bases:
@@ -307,11 +323,12 @@
self.unparsed[index] = ((), b'')
for ent in apply_bases:
# Apply bases. This should just be for aliases, which are likely also in this block.
+ # Importantly, we've already put those in ent_map, so this won't recurse if they
+ # are in our block.
ent.bases = [
base if isinstance(base, EntityDef) else self.get_ent(base)
for base in ent.bases
]
- ent.bases.append(cbase_entity)
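
_parse_block() is the materialising half of a lazy-loading map: ent_map values are either a parsed EntityDef or the integer index of a not-yet-parsed block, and looking up one classname parses every entity in its block. A toy sketch of that pattern, with hypothetical names and string payloads standing in for entities:

```python
from typing import Dict, List, Union

class LazyDB:
    def __init__(self, keys: Dict[str, Union[str, int]], blocks: List[bytes]) -> None:
        self.map = keys        # str = parsed value, int = index of unparsed block
        self.blocks = blocks

    def get(self, key: str) -> str:
        value = self.map[key]
        if isinstance(value, int):
            # Parsing one block materialises every key stored in it.
            for part in self.blocks[value].decode('utf8').split(','):
                name, _, val = part.partition('=')
                self.map[name] = val
            self.blocks[value] = b''  # drop the raw bytes once parsed
            value = self.map[key]
        return value

db = LazyDB({'func_door': 0, 'func_button': 0}, [b'func_door=a,func_button=b'])
assert db.get('func_door') == 'a'    # first access parses block 0 for both keys
assert db.get('func_button') == 'b'  # already parsed, no second decode
```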

def get_fgd(self) -> FGD:
"""Parse all the blocks and make an FGD."""
@@ -580,6 +597,25 @@ def record_strings(string: str) -> bytes:
ent_to_string[ent] = ent_strings = set()
ent_serialise(ent, dummy_file, record_strings)
ent_to_size[ent] = dummy_file.tell()
+
+ assert ent.classname.casefold() == '_cbaseentity_'
+ base_strings = ent_to_string[ent]
+ print(f'{SHARED_STRINGS-len(base_strings)}/{SHARED_STRINGS} shared strings used.')
+
+ # Find common strings, move them to the base set.
+ string_counts = Counter[str]()
+ for strings in ent_to_string.values():
+     string_counts.update(strings)
+ # Shared strings might already be in base, so break early once we hit the quota.
+ # At most we'll need to add SHARED_STRINGS different items.
+ for st, count in string_counts.most_common(SHARED_STRINGS):
+     if len(base_strings) >= SHARED_STRINGS:
+         break
+     base_strings.add(st)
+ for strings in ent_to_string.values():
+     if strings is not base_strings:
+         strings -= base_strings
+
return ent_to_string, ent_to_size
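
The selection added above is a greedy fill: count every string across all entities, promote the most frequent into the shared set until the SHARED_STRINGS quota is full, then subtract the shared set from each entity's own strings. A toy run of the same steps, with a quota of 2 standing in for SHARED_STRINGS:

```python
from collections import Counter

QUOTA = 2  # stands in for SHARED_STRINGS
ent_strings = {
    'door': {'targetname', 'speed'},
    'button': {'targetname', 'sounds'},
    'train': {'targetname', 'speed'},
}
counts = Counter()
for strings in ent_strings.values():
    counts.update(strings)

shared = set()
for string, _count in counts.most_common(QUOTA):
    if len(shared) >= QUOTA:
        break
    shared.add(string)
assert shared == {'targetname', 'speed'}  # 3 and 2 uses respectively

# Every per-entity set shrinks, so every block's local table shrinks too.
for strings in ent_strings.values():
    strings -= shared
```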


@@ -655,8 +691,15 @@ def add_ent(self, ent: EntityDef) -> None:
all_blocks.sort(key=lambda block: len(block.ents))

for block in all_blocks:
- efficency = len(block.stringdb) / sum(map(len, map(ent_to_string.__getitem__, block.ents)))
- print(f'{block.bytesize} bytes = {len(block.ents)} = {1/efficency:.02%}')
+ if block.stringdb:
+     efficency = format(
+         sum(map(len, map(ent_to_string.__getitem__, block.ents)))
+         / len(block.stringdb),
+         '.02%'
+     )
+ else:
+     efficency = 'All shared'
+ print(f'{block.bytesize} bytes = {len(block.ents)} = {efficency}')
print(len(all_blocks), 'blocks')
return [
(block.ents, block.stringdb)
@@ -676,7 +719,7 @@ def serialise(fgd: FGD, file: IO[bytes]) -> None:
print('Computing string sizes...')
# We need the database for CBaseEntity, but not to include it with anything else.
ent_to_string, ent_to_size = compute_ent_strings(itertools.chain(all_ents, [CBaseEntity]))
- CBaseEntity_strings = ent_to_string[CBaseEntity]
+ base_strings = ent_to_string[CBaseEntity]

# For every pair of entities (!), compute the number of overlapping ents.
print('Computing overlaps...')
@@ -708,20 +751,22 @@ def serialise(fgd: FGD, file: IO[bytes]) -> None:
block_ents.sort(key=operator.attrgetter('classname'))
for ent in block_ents:
assert '\x1b' not in ent.classname, ent
- classnames = lzma.compress(STRING_SEP.join(ent.classname for ent in block_ents).encode('utf8'))
+ # Not worth it to compress these.
+ classnames = STRING_SEP.join(ent.classname for ent in block_ents).encode('utf8')
file.write(_fmt_16bit.pack(len(classnames)))
file.write(classnames)
deferred.defer(('block', id(block_ents)), _fmt_block_pos, write=True)
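
Dropping lzma.compress for the classname list (above) follows from LZMA's fixed container overhead: the .xz framing adds tens of bytes regardless of input, which dwarfs any saving on a list this short. A quick way to check that assumption:

```python
import lzma

names = '\x1f'.join(['func_door', 'func_button', 'func_train']).encode('utf8')
packed = lzma.compress(names)
# For a few dozen bytes of input, the "compressed" output is typically
# larger than the raw payload, so storing it uncompressed wins.
print(f'{len(names)} bytes raw vs {len(packed)} bytes compressed')
```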

- # First, write CBaseEntity specially.
- dictionary = BinStrDict(CBaseEntity_strings)
- dictionary.serialise(file)
- ent_serialise(CBaseEntity, file, dictionary)
+ # First, write the base strings and CBaseEntity specially.
+ assert len(base_strings) == SHARED_STRINGS, len(base_strings)
+ base_dict = BinStrDict(base_strings, None)
+ base_dict.serialise(file)
+ ent_serialise(CBaseEntity, file, base_dict)

# Then write each block and then each entity.
for block_ents, block_stringdb in blocks:
block_off = file.tell()
- dictionary = BinStrDict(block_stringdb)
+ dictionary = BinStrDict(block_stringdb, base_dict)
dictionary.serialise(file)
for ent in block_ents:
ent_serialise(ent, file, dictionary)
@@ -748,19 +793,19 @@ def unserialise(file: IO[bytes]) -> _EngineDBProto:

for block_id in range(block_count):
[cls_size] = _fmt_16bit.unpack(file.read(2))
- classnames = lzma.decompress(file.read(cls_size)).decode('utf8').split(STRING_SEP)
+ classnames = file.read(cls_size).decode('utf8').split(STRING_SEP)
block_classnames.append(classnames)
for name in classnames:
ent_map[name.casefold()] = block_id
off, size = _fmt_block_pos.unpack(file.read(_fmt_block_pos.size))
positions.append((classnames, off, size))

# Read CBaseEntity.
- from_dict = BinStrDict.unserialise(file)
+ base_strings, from_dict = BinStrDict.unserialise(file, [])
ent_map['_cbaseentity_'] = ent_unserialise(file, '_CBaseEntity_', from_dict)

for classnames, off, size in positions:
file.seek(off)
unparsed.append((classnames, file.read(size)))

- return EngineDB(ent_map, unparsed)
+ return EngineDB(ent_map, base_strings, unparsed)
Binary file modified src/srctools/fgd.lzma
