diff --git a/surya/common/surya/processor/tokenizer.py b/surya/common/surya/processor/tokenizer.py index f102de91..9b98fa8c 100644 --- a/surya/common/surya/processor/tokenizer.py +++ b/surya/common/surya/processor/tokenizer.py @@ -25,225 +25,6 @@ def create_token_regex(tokens): return regex -class InnerOCRTokenizer: - def __init__( - self, - special_tokens: Dict[str, list] | None = None, - qwen_tokenizer: Qwen2OriginalTokenizer | None = None, - **kwargs, - ): - self.qwen_tokenizer = qwen_tokenizer - self.qwen_token_offset = len(qwen_tokenizer) - - all_special_tokens = special_tokens.get("all", []) - self.SPECIAL_TOKEN_MAPPING = {} - - idx = 0 - for tag in all_special_tokens: - if tag in self.SPECIAL_TOKEN_MAPPING: - continue - self.SPECIAL_TOKEN_MAPPING[tag] = ( - idx + self.qwen_token_offset - ) # Assign token ID - idx += 1 - - self.REVERSE_SPECIAL_TOKEN_MAPPING = { - v: k for k, v in self.SPECIAL_TOKEN_MAPPING.items() - } - self.SPECIAL_TOKEN_OFFSET = idx - self.FORMAT_TAG_PATTERN = create_token_regex(special_tokens["formatting"]) - self.MATH_TAG_PATTERN = create_token_regex(special_tokens["math_external"]) - self.LAYOUT_TAG_PATTERN = create_token_regex(special_tokens["layout"]) - self.TABLE_STRUCTURE_TAG_PATTERN = create_token_regex( - special_tokens["table_structure"] - ) - self.SYSTEM_TAG_PATTERN = create_token_regex(special_tokens.get("system", [])) - if not special_tokens.get("system", []): - logger.warning("Warning: No system tokens found in special_tokens") - - self.MATH_TAG_START = " List[int]: - tokens = [] - in_math = False - text = html.unescape(text) # Unescape html entities like < in equations - while text: - # Look for EOS, PAD, etc. tokens - match = self.SYSTEM_TAG_PATTERN.search(text) - if match: - tag = match.group(1) - tokens.append( - self.SPECIAL_TOKEN_MAPPING[tag] - ) # These are already offset - text = text[match.end() :] - continue - - # Look for layout tokens - match = self.LAYOUT_TAG_PATTERN.search(text) - if match: - tag = match.group(1) - tokens.append( - self.SPECIAL_TOKEN_MAPPING[tag] - ) # Layout tokens are already offset - text = text[match.end() :] - continue - - match = self.TABLE_STRUCTURE_TAG_PATTERN.search(text) - if match: - tag = match.group(1) - tokens.append(self.SPECIAL_TOKEN_MAPPING[tag]) - text = text[match.end() :] - continue - - # Check for math tags - match = self.MATH_TAG_PATTERN.search(text) - if match: - # We found a tag - tag = match.group(1) - if tag.startswith(self.MATH_TAG_START): - in_math = True - elif tag == self.MATH_END_TAG: - in_math = False - tokens.append( - self.SPECIAL_TOKEN_MAPPING[tag] # Special tokens are already offset - ) # Use special token ID - text = text[match.end() :] - continue - - # Tokenize math content with qwen2 tokenizer - if in_math: - # If we're in a math block, check to see if we have a special math tag in the text - math_end_position = text.find(self.MATH_END_TAG) - math_str = text[:math_end_position] # Gets the math content - tokens += self.qwen_tokenizer(math_str)["input_ids"] - text = text[math_end_position:] - continue - - # Check for formatting tags - match = self.FORMAT_TAG_PATTERN.search(text) - if match: - # We found a tag - tag = match.group(1) - tokens.append( - self.SPECIAL_TOKEN_MAPPING[tag] # Special tokens are already offset - ) # Use special token ID - text = text[match.end() :] - continue - - # General case, utf-16 tokenization - utf_16_tokens = self.text_to_utf16_numbers(text[0]) - tokens += [ - t + self.SPECIAL_TOKEN_OFFSET + self.qwen_token_offset - for t in utf_16_tokens - ] - text = text[1:] - - return tokens - - def text_to_utf16_numbers(self, text: str): - """Converts text to UTF-16 encoded numbers.""" - utf16_bytes = text.encode( - "utf-16le" - ) # Little-endian to simplify byte order handling - numbers = [] - - for i in range(0, len(utf16_bytes), 2): - # Combine two adjacent bytes into a single number - number = utf16_bytes[i] + (utf16_bytes[i + 1] << 8) - numbers.append(number) - - return numbers - - def utf16_numbers_to_text(self, numbers): - """Converts UTF-16 numbers back to text.""" - byte_array = bytearray() - for number in numbers: - byte_array.append(number & 0xFF) # Lower byte - byte_array.append((number >> 8) & 0xFF) # Upper byte - - try: - text = byte_array.decode("utf-16le", errors="ignore") - except Exception as e: - logger.warning(f"Error decoding utf16: {e}") - text = "" - - return text - - def __call__( - self, texts: Union[str, List[str]], **kwargs - ) -> Dict[str, List[List[int]]]: - """Tokenizes text and returns input IDs.""" - tokenized = [] - - if isinstance(texts, str): - texts = [texts] - - for text in texts: - tokens = self._tokenize(text) - tokenized.append(tokens) - - return {"input_ids": tokenized} - - def decode(self, token_ids, **kwargs): - """Decodes token IDs back to text.""" - if isinstance(token_ids, (np.ndarray, torch.Tensor)): - token_ids = token_ids.tolist() - - decoded_text = "" - token_buffer = [] - decode_qwen = [False] - - def decode_buffer(): - nonlocal decoded_text, token_buffer, decode_qwen - if token_buffer: - if decode_qwen[0]: - decoded_text += self.qwen_tokenizer.decode(token_buffer) - else: - token_buffer = [ - t - self.SPECIAL_TOKEN_OFFSET - self.qwen_token_offset - for t in token_buffer - ] - decoded_text += self.utf16_numbers_to_text(token_buffer) - - token_buffer = [] - decode_qwen[0] = False - - for t in token_ids: - if t < self.qwen_token_offset: - # This is for math tags - if token_buffer and token_buffer[-1] >= self.qwen_token_offset: - decode_buffer() - token_buffer.append(t) - decode_qwen[0] = True - elif t >= self.SPECIAL_TOKEN_OFFSET + self.qwen_token_offset: - if token_buffer and token_buffer[-1] < self.qwen_token_offset: - decode_buffer() - token_buffer.append(t) # We shift this down later on - decode_qwen[0] = False - elif t in self.REVERSE_SPECIAL_TOKEN_MAPPING: - decode_buffer() - decoded_text += self.REVERSE_SPECIAL_TOKEN_MAPPING[t] - decode_qwen[0] = False - else: - raise ValueError( - f'Unexpected token value while decoding, got "{t}" in token_ids {token_ids}' - ) - - # Detokenize remaining tokens - decode_buffer() - - return decoded_text - - class Qwen2Tokenizer(S3DownloaderMixin, Qwen2OriginalTokenizer): pass @@ -301,6 +82,39 @@ def _build_trie( node.id = tid return root + def _build_escape_patterns(self, math_token_to_rawid): + """Build pattern list from vocab commands that start with control characters. + + Scans the math vocab for LaTeX commands that could be corrupted by JSON + escape sequence interpretation (e.g., \\begin becomes egin). + """ + control_chars = { + '\x08': 'b', # backspace + '\t': 't', # tab + '\n': 'n', # newline + '\r': 'r', # carriage return + '\f': 'f', # form feed + '\x07': 'a', # bell + '\x0b': 'v', # vertical tab + } + + patterns = {char: [] for char in control_chars} + + for token in math_token_to_rawid.keys(): + if token.startswith('\\') and len(token) > 1: + letter = token[1:2] # First char after backslash + for ctrl_char, ctrl_letter in control_chars.items(): + if letter == ctrl_letter: + # This token could be corrupted: \token -> oken + suffix = token[2:] # Everything after \X + patterns[ctrl_char].append((suffix, token)) + + # Sort by length (longest first) to avoid partial matches + for char in patterns: + patterns[char].sort(key=lambda x: len(x[0]), reverse=True) + + return patterns + @classmethod def _encode_math_greedy( cls, @@ -467,6 +281,9 @@ def __init__( # Trie for math greedy match self.trie = self._build_trie(self.math_token_to_rawid) + # Build escape fix patterns from vocab + self.latex_escape_patterns = self._build_escape_patterns(self.math_token_to_rawid) + # Tell HF about special tokens (metadata) kwargs.setdefault("bos_token", bos_token) kwargs.setdefault("eos_token", eos_token or "") @@ -526,6 +343,35 @@ def _encode_core(self, text: str) -> List[int]: ids.extend(units) return ids + def _fix_latex_escapes(self, text: str) -> str: + """Fix improperly escaped LaTeX commands in decoded text. + + Operates on the complete decoded string, replacing control character + sequences with their intended LaTeX commands based on vocab patterns. + """ + result = [] + i = 0 + while i < len(text): + char = text[i] + if char in self.latex_escape_patterns: + # Check if any pattern matches + matched = False + for pattern, replacement in self.latex_escape_patterns[char]: + if text[i+1:].startswith(pattern): + result.append(replacement) + i += 1 + len(pattern) + matched = True + break + if not matched: + # Not a LaTeX command, keep the control char as-is + result.append(char) + i += 1 + else: + result.append(char) + i += 1 + + return ''.join(result) + def _decode_core(self, ids: Iterable[int]) -> str: out: List[str] = [] buf: List[int] = [] @@ -545,7 +391,8 @@ def flush(): else: buf.append(int(tid)) flush() - return "".join(out) + decoded = "".join(out) + return self._fix_latex_escapes(decoded) # ---- Tokenizer interface ---- def _tokenize(self, text: str, **kwargs) -> List[str]: