Skip to content

Commit

Permalink
Updated phpstan types
Browse files Browse the repository at this point in the history
  • Loading branch information
RahulDey12 committed Jun 16, 2024
1 parent 199a5c2 commit 6d40c1f
Show file tree
Hide file tree
Showing 4 changed files with 95 additions and 7 deletions.
43 changes: 39 additions & 4 deletions src/Bpe.php
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ public function encode(string $text, array $allowedSpecial): array
while (true) {
// Find the next allowed special rank, if any
if (preg_match($this->specialRegex, mb_substr($text, $start_find), $matches, PREG_OFFSET_CAPTURE)) {
/** @var array $next_special */
/** @var array{0: string, 1: int} $next_special */
$next_special = $matches[0];
$match_str = mb_substr($text, $start_find + $next_special[1], mb_strlen((string) $next_special[0]));

Expand Down Expand Up @@ -122,6 +122,11 @@ public function encodeOrdinary(string $text): array
return $ranks;
}

/**
* @param int[] $bytes
* @return int[]
* @throws Exception
*/
private function bpe(array $bytes): array
{
$bytePairs = $this->initializePairs($bytes);
Expand All @@ -131,12 +136,15 @@ private function bpe(array $bytes): array
$index = $minRank[1];

if ($index > 0) {
// @phpstan-ignore-next-line
ArrayUtil::at($bytePairs, $index - 1)[1] = $this->calculateMergedRank($bytes, $bytePairs, $index - 1);
}

// @phpstan-ignore-next-line
ArrayUtil::at($bytePairs, $index)[1] = $this->calculateMergedRank($bytes, $bytePairs, $index);
ArrayUtil::unsetAt($bytePairs, $index + 1);

// @phpstan-ignore-next-line
$minRank = $this->getMinRankPair(ArrayUtil::getSegment($bytePairs, 0, count($bytePairs) - 1));
}

Expand All @@ -148,6 +156,10 @@ private function bpe(array $bytes): array
}, $this->getAllPairs($bytePairs));
}

/**
* @param int[] $bytes
* @return array<array{0: int, 1: int}>
*/
private function initializePairs(array $bytes): array
{
$parts = [];
Expand All @@ -165,6 +177,10 @@ private function initializePairs(array $bytes): array
return $parts;
}

/**
* @param array<array{0: int, 1: int}> $parts
* @return array{0: int, 1: int}
*/
private function getMinRankPair(array $parts): array
{
$minRank = [self::MAX_INT, self::MAX_INT];
Expand All @@ -174,25 +190,40 @@ private function getMinRankPair(array $parts): array
$minRank = [$rank, $index];
}
}

return $minRank;
}

/**
* @param int[] $bytes
* @param array<array{0: int, 1: int}> $parts
* @param int $startIndex
* @return int
*/
private function calculateMergedRank(array $bytes, array $parts, int $startIndex): int
{
if ($startIndex + 3 >= count($parts)) {
return self::MAX_INT;
}

$start = ArrayUtil::at($parts, $startIndex)[0];
$stop = ArrayUtil::at($parts, $startIndex + 3)[0];
/** @var array{0: int, 1: int} $startPart */
$startPart = ArrayUtil::at($parts, $startIndex);
/** @var array{0: int, 1: int} $stopPart */
$stopPart = ArrayUtil::at($parts, $startIndex + 3); // @phpstan-ignore-line

$start = $startPart[0];
$stop = $stopPart[0];

return $this->getRank(ArrayUtil::getSegment($bytes, $start, $stop)) ?? self::MAX_INT;
}

/**
* @param non-empty-array<int[]> $parts
* @return array<array<int[]>>
*/
private function getAllPairs(array $parts): array
{
$pairs = [];
/** @var int[] $previousPart */
$previousPart = array_shift($parts);

foreach ($parts as $part) {
Expand All @@ -203,6 +234,10 @@ private function getAllPairs(array $parts): array
return $pairs;
}

/**
* @param int[]|string $bytes
* @return int|null
*/
private function getRank(array|string $bytes): ?int
{
if (is_array($bytes)) {
Expand Down
2 changes: 1 addition & 1 deletion src/Contracts/BpeContract.php
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ interface BpeContract
{
/**
* @param string[] $allowedSpecial
* @return array{0: int[], 1, int}
* @return array{0: int[], 1: int}
*/
public function encode(string $text, array $allowedSpecial): array;

Expand Down
54 changes: 53 additions & 1 deletion src/Loaders/DataGymLoader.php
Original file line number Diff line number Diff line change
Expand Up @@ -43,11 +43,20 @@ public function load(
return $bpeRanks;
}

/**
* @param string $startChar
* @param string $endChar
* @return int[]
*/
private function createByteRange(string $startChar, string $endChar): array
{
return range(mb_ord($startChar), mb_ord($endChar));
}

/**
* @param int[] $byteArray
* @return array<string, int>
*/
private function byteToCharMap(array $byteArray): array
{
$byteToCharMap = [];
Expand All @@ -59,6 +68,11 @@ private function byteToCharMap(array $byteArray): array
return $byteToCharMap;
}

/**
* @param int[] $rankToIntByte
* @param array<string, int> $dataGymByteToByteMap
* @return void
*/
private function addBytesNotInRank(array &$rankToIntByte, array &$dataGymByteToByteMap): void
{
$unicodeCounter = 0;
Expand All @@ -75,11 +89,24 @@ private function addBytesNotInRank(array &$rankToIntByte, array &$dataGymByteToB
}
}

/**
* @param string $vocabBpeContents
* @return array<array{0: string, 1: string}>
*/
private function createBpeMerges(string $vocabBpeContents): array
{
return array_map(fn ($mergeStr): array => explode(' ', $mergeStr), array_slice(explode("\n", $vocabBpeContents), 1, -1));
$lines = explode("\n", $vocabBpeContents);
$mergeLines = array_slice($lines, 1, -1);
/** @var array<array{0: string, 1: string}> $merges */
$merges = array_map(fn($mergeStr) => explode(' ', $mergeStr, 2), $mergeLines);

return $merges;
}

/**
* @param int[] $rankToIntByte
* @return array<string, int>
*/
private function createBpeRanks(array $rankToIntByte): array
{
$bpeRanks = [];
Expand All @@ -90,10 +117,17 @@ private function createBpeRanks(array $rankToIntByte): array
return $bpeRanks;
}

/**
* @param array<array{0: string, 1: string}> $bpeMerges
* @param array<string, int> $bpeRanks
* @param array<string, int> $dataGymByteToByteMap
* @return void
*/
private function addMergeRanksToBpe(array $bpeMerges, array &$bpeRanks, array $dataGymByteToByteMap): void
{
foreach ($bpeMerges as [$first, $second]) {

/** @var int[] $tokenBytes */
$tokenBytes = [
...$this->decodeDataGym($first, $dataGymByteToByteMap),
...$this->decodeDataGym($second, $dataGymByteToByteMap),
Expand All @@ -104,6 +138,13 @@ private function addMergeRanksToBpe(array $bpeMerges, array &$bpeRanks, array $d
}
}

/**
* @param string $encoderJsonFile
* @param string|null $encoderJsonHash
* @param array<string, int> $dataGymByteToByteMap
* @return array<string, int>
* @throws \Rahul900day\Tiktoken\Exceptions\InvalidChecksumException
*/
private function loadEncoderJson(string $encoderJsonFile, ?string $encoderJsonHash, array $dataGymByteToByteMap): array
{
/** @var non-empty-array<string, int> $encoderJson */
Expand All @@ -120,13 +161,24 @@ private function loadEncoderJson(string $encoderJsonFile, ?string $encoderJsonHa
return $encoderJsonLoaded;
}

/**
* @param array<string, int> $bpeRanks
* @param array<string, int> $encoderJson
* @return void
* @throws \Exception
*/
private function validateBpeAndEncoderJsonRanks(array $bpeRanks, array $encoderJson): void
{
if ($bpeRanks !== $encoderJson) {
throw new \Exception("BPE Ranks & Encoder JSON Ranks Doesn't Match");
}
}

/**
* @param string|int $value
* @param array<string, int> $dataGymByteToByte
* @return int[]
*/
private function decodeDataGym(string|int $value, array $dataGymByteToByte): array
{
$bytes = [];
Expand Down
3 changes: 2 additions & 1 deletion src/Utils/ArrayUtil.php
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ final class ArrayUtil
* @template TValue
*
* @param array<TKey, TValue> $array
* @return TValue
*/
public static function &at(array &$array, int $at): mixed
{
Expand All @@ -38,7 +39,7 @@ public static function unsetAt(array &$array, int $at): void
* @template TKey
* @template TValue
*
* @param non-empty-array<TKey, TValue> $array
* @param array<TKey, TValue> $array
* @return array<TKey, TValue>
*/
public static function getSegment(array $array, int $start, int $end): array
Expand Down

0 comments on commit 6d40c1f

Please sign in to comment.