mirror of https://github.com/pallets/flask.git
				
				
				
			
		
			
				
	
	
		
			148 lines
		
	
	
		
			4.5 KiB
		
	
	
	
		
			Python
		
	
	
	
			
		
		
	
	
			148 lines
		
	
	
		
			4.5 KiB
		
	
	
	
		
			Python
		
	
	
	
| # Script which modifies source code away from the deprecated "flask.ext"
 | |
| # format. Does not yet fully support imports in the style:
 | |
| #
 | |
| # "import flask.ext.foo"
 | |
| #
 | |
| # these are converted to "import flask_foo" in the
 | |
| # main import statement, but does not handle function calls in the source.
 | |
| #
 | |
| # Run in the terminal by typing: `python flaskext_migrate.py <source_file.py>`
 | |
| #
 | |
| # Author: Keyan Pishdadian 2015
 | |
| 
 | |
| from redbaron import RedBaron
 | |
| import sys
 | |
| 
 | |
| 
 | |
def read_source(input_file):
    """Read input_file and return its contents parsed as a RedBaron FST."""
    with open(input_file, "r") as handle:
        text = handle.read()
    # Parsing happens after the file has been read and closed.
    return RedBaron(text)
 | |
| 
 | |
| 
 | |
def write_source(red, input_file):
    """
    Overwrites the input_file once the FST has been modified.

    :param red: object exposing ``dumps()`` which returns the source text.
    :param input_file: path of the file to overwrite.

    The source text is produced *before* the file is opened for writing.
    Opening with mode "w" truncates the file immediately, so serializing
    first guarantees that a failure inside ``dumps()`` cannot leave the
    user's file empty.
    """
    rendered = red.dumps()
    with open(input_file, "w") as source_code:
        source_code.write(rendered)
 | |
| 
 | |
| 
 | |
def fix_imports(red):
    """Wrapper which fixes "import" style imports and then "from" style."""
    # Order matters: plain imports are rewritten before from-imports.
    for fixer in (fix_standard_imports, fix_from_imports):
        red = fixer(red)
    return red
 | |
| 
 | |
| 
 | |
def fix_from_imports(red):
    """
    Converts "from" style imports to not use "flask.ext".

    Handles:
    Case 1: from flask.ext.foo import bam --> from flask_foo import bam as bam
            from flask.ext.foo import a, b --> from flask_foo import a, b
    Case 2: from flask.ext import foo --> import flask_foo as foo

    In Case 1 a single imported name is re-emitted with an explicit
    ``as <name>`` taken from ``node.names()``. Paths deeper than
    ``flask.ext.foo`` keep their remaining components
    (``flask.ext.foo.bar`` --> ``flask_foo.bar``).

    Imports whose dotted path has fewer than two components (for example
    ``from flask import Flask``) are skipped instead of raising IndexError.

    :param red: RedBaron tree (anything exposing ``find_all``).
    :return: the same ``red`` object, modified in place.
    """
    for node in red.find_all("FromImport"):
        values = node.value
        # Guard first: indexing values[1] on a one-component path would
        # raise IndexError and abort the whole migration.
        if len(values) < 2:
            continue
        if values[0].value != 'flask' or values[1].value != 'ext':
            continue

        if len(values) >= 3:
            # Case 1: everything after "flask.ext" names the extension
            # package (plus any submodules inside it).
            package = '.'.join(values[i].value
                               for i in range(2, len(values)))
            modules = node.modules()
            module_string = _get_modules(modules)
            if len(modules) > 1:
                node.replace("from flask_%s import %s"
                             % (package, module_string))
            else:
                name = node.names()[0]
                node.replace("from flask_%s import %s as %s"
                             % (package, module_string, name))
        else:
            # Case 2: the imported name itself is the extension package.
            module = node.modules()[0]
            node.replace("import flask_%s as %s"
                         % (module, module))
    return red
 | |
| 
 | |
| 
 | |
def fix_standard_imports(red):
    """
    Handles import modification in the form:
    import flask.ext.foo --> import flask_foo
    import flask.ext.foo as bar --> import flask_foo as bar

    Only the first dotted name of each import statement is inspected.
    Statements whose path is shorter than ``flask.ext.<package>`` are left
    untouched.

    :param red: RedBaron tree (anything exposing ``find_all``).
    :return: the same ``red`` object, modified in place.
    """
    for node in red.find_all("ImportNode"):
        try:
            path = node.value[0].value
            if path[0].value != 'flask' or path[1].value != 'ext':
                continue
            # Touch index 2 explicitly so "import flask.ext" is skipped via
            # IndexError, exactly like any other too-short path.
            path[2]
            # Use the name strings, not the nodes themselves, and keep any
            # submodule components after the package.
            package = '.'.join(path[i].value for i in range(2, len(path)))
            imported = node.names()[0]
        except IndexError:
            # Best effort: not a flask.ext.<package> import, leave it alone.
            continue

        if '.' in imported:
            # No alias was given (the bound name is still the dotted path),
            # so emit a plain import; references are rewritten to
            # "flask_<package>" elsewhere.
            node.replace("import flask_%s" % package)
        else:
            # An explicit alias was given; preserve it.
            node.replace("import flask_%s as %s" % (package, imported))

    return red
 | |
| 
 | |
| 
 | |
| def _get_modules(module):
 | |
|     """
 | |
|     Takes a list of modules and converts into a string.
 | |
| 
 | |
|     The module list can include parens, this function checks each element in
 | |
|     the list, if there is a paren then it does not add a comma before the next
 | |
|     element. Otherwise a comma and space is added. This is to preserve module
 | |
|     imports which are multi-line and/or occur within parens. While also not
 | |
|     affecting imports which are not enclosed.
 | |
|     """
 | |
|     modules_string = [cur + ', ' if cur.isalnum() and next.isalnum()
 | |
|                       else cur
 | |
|                       for (cur, next) in zip(module, module[1:]+[''])]
 | |
| 
 | |
|     return ''.join(modules_string)
 | |
| 
 | |
| 
 | |
def fix_function_calls(red):
    """
    Modifies function calls in the source to reflect import changes.

    Searches the AST for AtomtrailerNodes that start with ``flask.ext.<pkg>``
    and replaces them with ``flask_<pkg>`` followed by the remaining
    trailers, e.g.:

    flask.ext.foo(var)      --> flask_foo(var)
    flask.ext.foo.bar.baz() --> flask_foo.bar.baz()

    The package name is taken from the third element of the chain rather
    than being hard-coded, and every trailer after it is kept. A bare
    ``flask.ext.foo`` reference is rewritten to ``flask_foo``.

    :param red: RedBaron tree (anything exposing ``find_all``).
    :return: the same ``red`` object, modified in place.
    """
    for node in red.find_all("Atomtrailers"):
        trailers = node.value
        # Need at least flask, ext and a package name.
        if len(trailers) < 3:
            continue
        if trailers[0].value != 'flask' or trailers[1].value != 'ext':
            continue

        package = trailers[2].dumps()
        # Skip chains whose third element is not a plain name (for example
        # a call or subscript directly on "flask.ext").
        if not package.replace('_', '').isalnum():
            continue

        rebuilt = "flask_%s" % package
        for index in range(3, len(trailers)):
            text = trailers[index].dumps()
            # Calls and subscripts attach directly; attribute names need
            # their separating dot restored.
            if text[:1] in ('(', '['):
                rebuilt += text
            else:
                rebuilt += '.' + text
        node.replace(rebuilt)

    return red
 | |
| 
 | |
| 
 | |
def check_user_input():
    """Abort with an error message when the shell passed no file argument."""
    # sys.argv[0] is the script itself, so a filename needs a second entry.
    filename_given = len(sys.argv) >= 2
    if not filename_given:
        sys.exit("No filename was included, please try again.")
 | |
| 
 | |
| 
 | |
def fix_tester(ast):
    """Run the import and call fixers on ast and return the dumped source.

    Lets the fixers be exercised without going through the shell entry point.
    """
    return fix_function_calls(fix_imports(ast)).dumps()
 | |
| 
 | |
| 
 | |
def fix():
    """Check the command line, then migrate the named file in place."""
    check_user_input()
    target = sys.argv[1]
    # Parse, rewrite imports, rewrite calls, then overwrite the same file.
    tree = fix_function_calls(fix_imports(read_source(target)))
    write_source(tree, target)
 | |
| 
 | |
| 
 | |
# Run the migration only when executed as a script, not when imported.
if __name__ == "__main__":
    fix()
 |