Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
287 changes: 67 additions & 220 deletions surya/common/surya/processor/tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,225 +25,6 @@ def create_token_regex(tokens):
return regex


class InnerOCRTokenizer:
    """Tokenizer mixing special tags, Qwen2 ids for math, and UTF-16 code units.

    Token id layout, as produced by ``_tokenize`` and consumed by ``decode``:

    * ``[0, qwen_token_offset)``: ids emitted by the wrapped Qwen2 tokenizer
      (used for the content of math blocks).
    * ``[qwen_token_offset, qwen_token_offset + SPECIAL_TOKEN_OFFSET)``:
      special tags, numbered in first-occurrence order of
      ``special_tokens["all"]``.
    * ``>= qwen_token_offset + SPECIAL_TOKEN_OFFSET``: UTF-16 code units of
      ordinary text, shifted up by that amount.
    """

    def __init__(
        self,
        special_tokens: Dict[str, list] | None = None,
        qwen_tokenizer: Qwen2OriginalTokenizer | None = None,
        **kwargs,
    ):
        """Build the special-token tables and tag-matching regexes.

        Args:
            special_tokens: Mapping of token groups. The keys ``"formatting"``,
                ``"math_external"``, ``"layout"`` and ``"table_structure"`` are
                indexed directly and must exist; ``"all"`` and ``"system"``
                fall back to an empty list.
            qwen_tokenizer: Tokenizer used for math content; its ``len()``
                becomes the offset of the first special token id.
            **kwargs: Forwarded unchanged to ``super().__init__``.

        Note:
            Both arguments default to ``None`` but are dereferenced
            unconditionally, so in practice both must be supplied.
        """
        self.qwen_tokenizer = qwen_tokenizer
        self.qwen_token_offset = len(qwen_tokenizer)

        # dict.fromkeys drops duplicate tags while keeping first-seen order,
        # so every distinct tag gets one consecutive id above the Qwen range.
        unique_tags = dict.fromkeys(special_tokens.get("all", []))
        self.SPECIAL_TOKEN_MAPPING = {
            tag: idx + self.qwen_token_offset
            for idx, tag in enumerate(unique_tags)
        }
        self.REVERSE_SPECIAL_TOKEN_MAPPING = {
            v: k for k, v in self.SPECIAL_TOKEN_MAPPING.items()
        }
        self.SPECIAL_TOKEN_OFFSET = len(self.SPECIAL_TOKEN_MAPPING)

        self.FORMAT_TAG_PATTERN = create_token_regex(special_tokens["formatting"])
        self.MATH_TAG_PATTERN = create_token_regex(special_tokens["math_external"])
        self.LAYOUT_TAG_PATTERN = create_token_regex(special_tokens["layout"])
        self.TABLE_STRUCTURE_TAG_PATTERN = create_token_regex(
            special_tokens["table_structure"]
        )
        system_tokens = special_tokens.get("system", [])
        self.SYSTEM_TAG_PATTERN = create_token_regex(system_tokens)
        if not system_tokens:
            logger.warning("Warning: No system tokens found in special_tokens")

        self.MATH_TAG_START = "<math"
        self.MATH_END_TAG = "</math>"

        super().__init__(**kwargs)

    @property
    def vocab_size(self):
        """Number of UTF-16 code units (65536) plus the number of special tokens."""
        # NOTE(review): emitted UTF-16 ids are additionally shifted by
        # qwen_token_offset, which is not counted here - confirm this is the
        # intended definition of vocab_size.
        return (
            65536 + self.SPECIAL_TOKEN_OFFSET
        )  # The highest codepoint is 65535, but we add 1 to account for the 0-indexing

    def _tokenize(self, text: str) -> List[int]:
        """Convert ``text`` into a flat list of token ids.

        At each step the remaining text is tested, in order, against the
        system, layout, table-structure and math tag patterns, then (outside
        math blocks) the formatting pattern. Content inside a math block is
        delegated to the Qwen2 tokenizer; anything else is consumed one
        character at a time as shifted UTF-16 code units.

        Args:
            text: Input string; HTML entities are unescaped first.

        Returns:
            The token ids for ``text``.
        """
        tokens = []
        in_math = False
        text = html.unescape(text)  # Unescape html entities like &lt; in equations
        utf16_shift = self.SPECIAL_TOKEN_OFFSET + self.qwen_token_offset
        # EOS/PAD-style system tags, layout tags, table tags: identical handling.
        plain_tag_patterns = (
            self.SYSTEM_TAG_PATTERN,
            self.LAYOUT_TAG_PATTERN,
            self.TABLE_STRUCTURE_TAG_PATTERN,
        )

        while text:
            match = None
            for pattern in plain_tag_patterns:
                match = pattern.search(text)
                if match:
                    break
            if match:
                # Special token ids are already offset past the Qwen range.
                tokens.append(self.SPECIAL_TOKEN_MAPPING[match.group(1)])
                text = text[match.end():]
                continue

            # Math open/close tags toggle the math state.
            match = self.MATH_TAG_PATTERN.search(text)
            if match:
                tag = match.group(1)
                if tag.startswith(self.MATH_TAG_START):
                    in_math = True
                elif tag == self.MATH_END_TAG:
                    in_math = False
                tokens.append(self.SPECIAL_TOKEN_MAPPING[tag])
                text = text[match.end():]
                continue

            # Tokenize math content with the qwen2 tokenizer.
            if in_math:
                math_end_position = text.find(self.MATH_END_TAG)
                if math_end_position == -1:
                    # Unterminated math block: treat the rest of the text as
                    # math. Slicing with -1 would leave a one-character
                    # remainder forever and the loop would never finish.
                    math_end_position = len(text)
                if math_end_position > 0:
                    math_str = text[:math_end_position]
                    tokens += self.qwen_tokenizer(math_str)["input_ids"]
                    text = text[math_end_position:]
                    continue
                # Position 0 means the end tag is next but the math pattern
                # did not claim it; fall through so a character is consumed.

            # Formatting tags.
            match = self.FORMAT_TAG_PATTERN.search(text)
            if match:
                tokens.append(self.SPECIAL_TOKEN_MAPPING[match.group(1)])
                text = text[match.end():]
                continue

            # General case, utf-16 tokenization of a single character.
            tokens += [
                unit + utf16_shift for unit in self.text_to_utf16_numbers(text[0])
            ]
            text = text[1:]

        return tokens

    def text_to_utf16_numbers(self, text: str):
        """Converts text to UTF-16 encoded numbers (one int per 16-bit code unit)."""
        # Little-endian to simplify byte order handling; no BOM is emitted.
        utf16_bytes = text.encode("utf-16le")
        return [
            int.from_bytes(utf16_bytes[i:i + 2], "little")
            for i in range(0, len(utf16_bytes), 2)
        ]

    def utf16_numbers_to_text(self, numbers):
        """Converts UTF-16 numbers back to text.

        Only the low 16 bits of each number are used. Undecodable sequences
        are dropped (``errors="ignore"``); any other decoding failure is
        logged and yields an empty string.
        """
        byte_array = bytearray()
        for number in numbers:
            byte_array.append(number & 0xFF)  # Lower byte
            byte_array.append((number >> 8) & 0xFF)  # Upper byte

        try:
            text = byte_array.decode("utf-16le", errors="ignore")
        except Exception as e:
            logger.warning(f"Error decoding utf16: {e}")
            text = ""

        return text

    def __call__(
        self, texts: Union[str, List[str]], **kwargs
    ) -> Dict[str, List[List[int]]]:
        """Tokenizes text and returns input IDs.

        Args:
            texts: A single string or a list of strings.
            **kwargs: Accepted and ignored.

        Returns:
            ``{"input_ids": [...]}`` with one id list per input string.
        """
        if isinstance(texts, str):
            texts = [texts]

        return {"input_ids": [self._tokenize(text) for text in texts]}

    def decode(self, token_ids, **kwargs):
        """Decodes token IDs back to text.

        Consecutive Qwen ids and consecutive UTF-16 ids are buffered and
        decoded as runs; special token ids are emitted as their tag string.

        Args:
            token_ids: Sequence of ids; numpy arrays and torch tensors are
                converted with ``tolist()`` first.
            **kwargs: Accepted and ignored.

        Returns:
            The decoded string.

        Raises:
            ValueError: If an id falls in the special-token range but is not
                a known special token.
        """
        if isinstance(token_ids, (np.ndarray, torch.Tensor)):
            token_ids = token_ids.tolist()

        utf16_shift = self.SPECIAL_TOKEN_OFFSET + self.qwen_token_offset
        pieces = []
        token_buffer = []
        decode_qwen = False

        def decode_buffer():
            # Flush the pending run using whichever decoder it belongs to.
            nonlocal token_buffer, decode_qwen
            if token_buffer:
                if decode_qwen:
                    pieces.append(self.qwen_tokenizer.decode(token_buffer))
                else:
                    pieces.append(
                        self.utf16_numbers_to_text(
                            [t - utf16_shift for t in token_buffer]
                        )
                    )
            token_buffer = []
            decode_qwen = False

        for t in token_ids:
            if t < self.qwen_token_offset:
                # Qwen id (math content): flush a pending UTF-16 run first.
                if token_buffer and token_buffer[-1] >= self.qwen_token_offset:
                    decode_buffer()
                token_buffer.append(t)
                decode_qwen = True
            elif t >= utf16_shift:
                # UTF-16 id: flush a pending Qwen run first.
                if token_buffer and token_buffer[-1] < self.qwen_token_offset:
                    decode_buffer()
                token_buffer.append(t)  # We shift this down when flushing
                decode_qwen = False
            elif t in self.REVERSE_SPECIAL_TOKEN_MAPPING:
                decode_buffer()
                pieces.append(self.REVERSE_SPECIAL_TOKEN_MAPPING[t])
                decode_qwen = False
            else:
                raise ValueError(
                    f'Unexpected token value while decoding, got "{t}" in token_ids {token_ids}'
                )

        # Detokenize remaining tokens
        decode_buffer()

        return "".join(pieces)


class Qwen2Tokenizer(S3DownloaderMixin, Qwen2OriginalTokenizer):
    # Combines Qwen2OriginalTokenizer with S3DownloaderMixin; this class adds
    # and overrides nothing itself - all behavior comes from the two bases.
    pass

Expand Down Expand Up @@ -301,6 +82,39 @@ def _build_trie(
node.id = tid
return root

def _build_escape_patterns(self, math_token_to_rawid):
"""Build pattern list from vocab commands that start with control characters.

Scans the math vocab for LaTeX commands that could be corrupted by JSON
escape sequence interpretation (e.g., \\begin becomes <backspace>egin).
"""
control_chars = {
'\x08': 'b', # backspace
'\t': 't', # tab
'\n': 'n', # newline
'\r': 'r', # carriage return
'\f': 'f', # form feed
'\x07': 'a', # bell
'\x0b': 'v', # vertical tab
}

patterns = {char: [] for char in control_chars}

for token in math_token_to_rawid.keys():
if token.startswith('\\') and len(token) > 1:
letter = token[1:2] # First char after backslash
for ctrl_char, ctrl_letter in control_chars.items():
if letter == ctrl_letter:
# This token could be corrupted: \token -> <ctrl>oken
suffix = token[2:] # Everything after \X
patterns[ctrl_char].append((suffix, token))

# Sort by length (longest first) to avoid partial matches
for char in patterns:
patterns[char].sort(key=lambda x: len(x[0]), reverse=True)

return patterns

@classmethod
def _encode_math_greedy(
cls,
Expand Down Expand Up @@ -467,6 +281,9 @@ def __init__(
# Trie for math greedy match
self.trie = self._build_trie(self.math_token_to_rawid)

# Build escape fix patterns from vocab
self.latex_escape_patterns = self._build_escape_patterns(self.math_token_to_rawid)

# Tell HF about special tokens (metadata)
kwargs.setdefault("bos_token", bos_token)
kwargs.setdefault("eos_token", eos_token or "</S>")
Expand Down Expand Up @@ -526,6 +343,35 @@ def _encode_core(self, text: str) -> List[int]:
ids.extend(units)
return ids

def _fix_latex_escapes(self, text: str) -> str:
"""Fix improperly escaped LaTeX commands in decoded text.

Operates on the complete decoded string, replacing control character
sequences with their intended LaTeX commands based on vocab patterns.
"""
result = []
i = 0
while i < len(text):
char = text[i]
if char in self.latex_escape_patterns:
# Check if any pattern matches
matched = False
for pattern, replacement in self.latex_escape_patterns[char]:
if text[i+1:].startswith(pattern):
result.append(replacement)
i += 1 + len(pattern)
matched = True
break
if not matched:
# Not a LaTeX command, keep the control char as-is
result.append(char)
i += 1
else:
result.append(char)
i += 1

return ''.join(result)

def _decode_core(self, ids: Iterable[int]) -> str:
out: List[str] = []
buf: List[int] = []
Expand All @@ -545,7 +391,8 @@ def flush():
else:
buf.append(int(tid))
flush()
return "".join(out)
decoded = "".join(out)
return self._fix_latex_escapes(decoded)

# ---- Tokenizer interface ----
def _tokenize(self, text: str, **kwargs) -> List[str]:
Expand Down