# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
class Visitor(ast.NodeVisitor):
    """Collect line numbers of AST nodes that start at column 0.

    After ``visit()`` has walked a tree:
      - ``line_numbers_with_nodes``: set of 1-based line numbers of every
        node whose ``col_offset`` is 0 (i.e. flush against the left margin).
      - ``line_numbers_with_statements``: ordered list of 1-based line
        numbers of every *statement* (``ast.stmt``) starting at column 0,
        in source order.
    """

    def __init__(self, lines):
        # The raw source lines; kept for reference only — the traversal
        # itself works purely on the AST.
        self._lines = lines
        self.line_numbers_with_nodes = set()
        self.line_numbers_with_statements = []

    def generic_visit(self, node):
        # Nodes without position info (e.g. Module) are skipped by the
        # hasattr guards; only left-margin nodes mark global scope.
        if hasattr(node, 'col_offset') and hasattr(node, 'lineno') and node.col_offset == 0:
            self.line_numbers_with_nodes.add(node.lineno)
            if isinstance(node, ast.stmt):
                self.line_numbers_with_statements.append(node.lineno)

        ast.NodeVisitor.generic_visit(self, node)
28 def _tokenize(source):
29 """Tokenize Python source code."""
30 # Using an undocumented API as the documented one in Python 2.7 does not work as needed
32 if sys.version_info < (3,) and isinstance(source, str):
33 source = source.decode()
34 return tokenize.generate_tokens(io.StringIO(source).readline)
37 def _indent_size(line):
38 for index, char in enumerate(line):
39 if not char.isspace():
def _get_global_statement_blocks(source, lines):
    """Return a list of all global statement blocks.

    The list comprises of 3-item tuples that contain the starting line
    number, ending line number and whether the statement is a single line.

    :param source: the full source text (must parse as Python).
    :param lines: the same source split into lines; only its length is used
        to bound the final block.
    :raises SyntaxError: if ``source`` is not valid Python.
    """
    tree = ast.parse(source)
    visitor = Visitor(lines)
    # Restored: without this visit the collected line numbers stay empty.
    visitor.visit(tree)

    statement_ranges = []
    for index, line_number in enumerate(visitor.line_numbers_with_statements):
        # A block extends up to (but not including) the next global statement,
        # or to the end of the source for the last block.
        remaining_line_numbers = visitor.line_numbers_with_statements[index+1:]
        end_line_number = len(lines) if len(remaining_line_numbers) == 0 else min(remaining_line_numbers) - 1
        current_statement_is_oneline = line_number == end_line_number

        if len(statement_ranges) == 0:
            statement_ranges.append((line_number, end_line_number, current_statement_is_oneline))
        else:
            previous_statement = statement_ranges[-1]
            previous_statement_is_oneline = previous_statement[2]
            # Merge runs of consecutive one-line statements into a single
            # block so callers do not over-separate them.
            if previous_statement_is_oneline and current_statement_is_oneline:
                statement_ranges[-1] = previous_statement[0], end_line_number, True
            else:
                statement_ranges.append((line_number, end_line_number, current_statement_is_oneline))

    return statement_ranges
def normalize_lines(source):
    """Normalize blank lines for sending to the terminal.

    Blank lines within a statement block are removed to prevent the REPL
    from thinking the block is finished. Newlines are added to separate
    top-level statements so that the REPL does not think there is a syntax
    error.

    Writes the normalized source to stdout; returns nothing.
    """
    lines = source.splitlines(False)
    # If we have two blank lines, then add two blank lines.
    # Do not trim the spaces: if we have blank lines with spaces, it's
    # possible we have indented code.
    if (len(lines) > 1 and len(''.join(lines[-2:])) == 0) \
            or source.endswith(('\n\n', '\r\n\r\n')):
        trailing_newline = '\n' * 2
    # Find out if we have any trailing blank lines.
    elif len(lines[-1].strip()) == 0 or source.endswith(('\n', '\r\n')):
        trailing_newline = '\n'
    else:
        # Restored branch: without it, sources with no trailing blank line
        # raised NameError on trailing_newline below.
        trailing_newline = ''

    # Step 1: Remove empty lines.
    # An 'NL' token is a non-logical newline; a blank line produces one whose
    # start and end rows coincide.
    tokens = _tokenize(source)
    newlines_indexes_to_remove = (spos[0] for (toknum, tokval, spos, epos, line) in tokens
                                  if len(line.strip()) == 0
                                  and token.tok_name[toknum] == 'NL'
                                  and spos[0] == epos[0])

    # Delete from the bottom up so earlier indexes stay valid.
    for line_number in reversed(list(newlines_indexes_to_remove)):
        del lines[line_number-1]

    # Step 2: Add blank lines between each global statement block.
    # Consecutive single-line blocks of code will be treated as a single
    # statement, just to ensure we do not unnecessarily add too many blank
    # lines. (A dead computation of DEDENT indexes was removed here; its
    # generator was never consumed.)
    source = '\n'.join(lines)
    global_statement_ranges = _get_global_statement_blocks(source, lines)
    # Insert separators from the bottom up so earlier positions stay valid;
    # never insert before the very first line.
    start_positions = map(operator.itemgetter(0), reversed(global_statement_ranges))
    for line_number in filter(lambda x: x > 1, start_positions):
        lines.insert(line_number-1, '')

    sys.stdout.write('\n'.join(lines) + trailing_newline)
if __name__ == '__main__':
    # The source to normalize arrives as the first CLI argument.
    contents = sys.argv[1]
    # Restored try/except: the bare except clause was orphaned, which is a
    # syntax error. Best effort round-trip through the default encoding to
    # repair surrogate escapes; ignore failures and use the argument as-is.
    try:
        default_encoding = sys.getdefaultencoding()
        encoded_contents = contents.encode(default_encoding, 'surrogateescape')
        contents = encoded_contents.decode(default_encoding, 'replace')
    except (UnicodeError, LookupError):
        pass
    if isinstance(contents, bytes):
        contents = contents.decode('utf8')
    normalize_lines(contents)