class BatchTextIteratorStreamer(TextIteratorStreamer):
    """Streamer that tracks ``batch_size`` sequences side by side.

    One token cache and one "already emitted" character count is kept per
    batch row. Every ``put``/``end`` call places a single list holding one
    text fragment per row on ``self.text_queue``.

    NOTE(review): ``tokenizer``, ``decode_kwargs``, ``skip_prompt``,
    ``next_tokens_are_prompt``, ``text_queue``, ``stop_signal``, ``timeout``
    and ``_is_chinese_char`` are not defined here; they are presumably
    supplied by the base class -- confirm against its implementation.
    """

    def __init__(
        self,
        batch_size: int,
        tokenizer: "AutoTokenizer",
        skip_prompt: bool = False,
        timeout: Optional[float] = None,
        **decode_kwargs,
    ):
        """Set up the base streamer and the per-row bookkeeping.

        Args:
            batch_size: Number of rows tracked by this streamer.
            tokenizer: Forwarded to the base class; used here for ``decode``.
            skip_prompt: Forwarded to the base class.
            timeout: Forwarded to the base class.
            **decode_kwargs: Forwarded to the base class.
        """
        super().__init__(tokenizer, skip_prompt, timeout, **decode_kwargs)
        self.batch_size = batch_size
        # Row-indexed state: pending token ids and number of characters of
        # the decoded cache that have already been handed out.
        self.token_cache = [[] for _ in range(batch_size)]
        self.print_len = [0] * batch_size
        # Never read inside this class; presumably filled in by the code
        # that drives generation -- TODO confirm.
        self.generate_exception = None

    def put(self, value):
        """Consume new token ids and queue one text fragment per row.

        Args:
            value: Tensor-like object of token ids. Anything that is not
                2-D is reshaped to ``(batch_size, value.shape[0] // batch_size)``.
        """
        if len(value.shape) != 2:
            tokens_per_row = value.shape[0] // self.batch_size
            value = torch.reshape(value, (self.batch_size, tokens_per_row))

        # The first call after construction / ``end`` is treated as the
        # prompt and dropped when ``skip_prompt`` is set.
        if self.skip_prompt and self.next_tokens_are_prompt:
            self.next_tokens_are_prompt = False
            return

        fragments = []
        for row in range(self.batch_size):
            self.token_cache[row].extend(value[row].tolist())
            decoded = self.tokenizer.decode(self.token_cache[row], **self.decode_kwargs)
            emitted = self.print_len[row]

            if decoded.endswith("\n"):
                # A trailing newline flushes everything and resets the row.
                fragment = decoded[emitted:]
                self.token_cache[row] = []
                self.print_len[row] = 0
            else:
                if decoded and self._is_chinese_char(ord(decoded[-1])):
                    # If the last token is a CJK character, we print the characters.
                    fragment = decoded[emitted:]
                else:
                    # Emit only up to and including the last space; the
                    # remainder stays cached for a later call (presumably
                    # to avoid emitting a partial word).
                    fragment = decoded[emitted : decoded.rfind(" ") + 1]
                self.print_len[row] = emitted + len(fragment)

            fragments.append(fragment)

        self.on_finalized_text(fragments)

    def end(self):
        """Flush whatever is still cached for each row and close the stream."""
        fragments = []
        for row in range(self.batch_size):
            fragment = ""
            if self.token_cache[row]:
                decoded = self.tokenizer.decode(self.token_cache[row], **self.decode_kwargs)
                fragment = decoded[self.print_len[row] :]
                self.token_cache[row] = []
                self.print_len[row] = 0
            fragments.append(fragment)

        # The next ``put`` is treated as a prompt again.
        self.next_tokens_are_prompt = True
        self.on_finalized_text(fragments, stream_end=True)

    def on_finalized_text(self, texts: List[str], stream_end: bool = False):
        """Queue ``texts``; when ``stream_end`` is set, queue the stop signal too."""
        out_queue = self.text_queue
        out_queue.put(texts, timeout=self.timeout)
        if stream_end:
            out_queue.put(self.stop_signal, timeout=self.timeout)
transformers
在线文档
语言
业务
神经网络