Skip to content

Commit

Permalink
fixed types
Browse files Browse the repository at this point in the history
  • Loading branch information
RahulDey12 committed Jun 16, 2024
1 parent abe5a03 commit fb8288c
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 9 deletions.
16 changes: 9 additions & 7 deletions src/Encoder.php
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

use Rahul900day\Tiktoken\Contracts\BpeContract;
use Rahul900day\Tiktoken\Enums\SpecialToken;
use Rahul900day\Tiktoken\Exceptions\InvalidPatternException;
use Rahul900day\Tiktoken\Exceptions\RankNotFoundException;
use Rahul900day\Tiktoken\Exceptions\SpecialTokenNotAllowedException;
use Rahul900day\Tiktoken\Exceptions\TiktokenException;
Expand Down Expand Up @@ -59,12 +60,12 @@ public function encodeOrdinaryBatch(array $texts): array

/**
* @param string[]|'all' $allowedSpecial
* @param string[]|'all' $disallowedSpecial
* @return int[]
*
* @throws Exceptions\InvalidPatternException
* @throws SpecialTokenNotAllowedException
* @throws SpecialTokenNotAllowedException|InvalidPatternException|TiktokenException
*/
public function encode(string $text, array|string $allowedSpecial = [], string $disallowedSpecial = 'all'): array
public function encode(string $text, array|string $allowedSpecial = [], string|array $disallowedSpecial = 'all'): array
{
if ($allowedSpecial === 'all') {
$allowedSpecial = $this->getSpecialTokensKeys();
Expand All @@ -74,7 +75,8 @@ public function encode(string $text, array|string $allowedSpecial = [], string $
$disallowedSpecial = array_diff($this->getSpecialTokensKeys(), $allowedSpecial);
}

// @phpstan-ignore argument.type


if (count($disallowedSpecial) > 0) {
$regex = SpecialToken::getRegex($disallowedSpecial);

Expand All @@ -89,12 +91,12 @@ public function encode(string $text, array|string $allowedSpecial = [], string $
/**
* @param array<string> $texts
* @param string[]|'all' $allowedSpecial
* @param string[]|'all' $disallowedSpecial
* @return array<int[]>
*
* @throws Exceptions\InvalidPatternException
* @throws SpecialTokenNotAllowedException
* @throws SpecialTokenNotAllowedException|InvalidPatternException|TiktokenException
*/
public function encodeBatch(array $texts, array|string $allowedSpecial = [], string $disallowedSpecial = 'all'): array
public function encodeBatch(array $texts, array|string $allowedSpecial = [], string|array $disallowedSpecial = 'all'): array
{
$result = [];

Expand Down
8 changes: 6 additions & 2 deletions src/Enums/SpecialToken.php
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,16 @@ enum SpecialToken: string
case ENDOFPROMPT = '<|endofprompt|>';

/**
* @param array<string> $tokens
* @param array<string>|string $tokens
*
* @throws InvalidPatternException
*/
public static function getRegex(array $tokens): string
public static function getRegex(array|string $tokens): string
{
if (is_string($tokens)) {
$tokens = [$tokens];
}

$parts = array_map('preg_quote', $tokens);
$regex = '/'.implode('|', $parts).'/u';

Expand Down

0 comments on commit fb8288c

Please sign in to comment.