mirror of https://github.com/pallets/flask.git
Add support for function call fixing, add tests
Addresses #1135, some code cleanup and refactoring. Changes wrapper function which handles testing, further modularized code, added test to cover function call fixing, and fixed duplicate test function name.
This commit is contained in:
parent
9cbe83ef0d
commit
4cb311b945
|
|
@ -70,18 +70,15 @@ def fix_standard_imports(red):
|
|||
"""
|
||||
Handles import modification in the form:
|
||||
import flask.ext.foo" --> import flask_foo
|
||||
|
||||
Does not modify function calls elsewhere in the source outside of the
|
||||
original import statement.
|
||||
"""
|
||||
imports = red.find_all("ImportNode")
|
||||
for x, node in enumerate(imports):
|
||||
try:
|
||||
if (node.value[0].value == 'flask' and
|
||||
node.value[1].value == 'ext'):
|
||||
package = node.value[2].value
|
||||
name = node.names()[0]
|
||||
imports[x].replace("import flask_%s as %s" % (package, name))
|
||||
if (node.value[0].value[0].value == 'flask' and
|
||||
node.value[0].value[1].value == 'ext'):
|
||||
package = node.value[0].value[2]
|
||||
name = node.names()[0].split('.')[-1]
|
||||
node.replace("import flask_%s as %s" % (package, name))
|
||||
except IndexError:
|
||||
pass
|
||||
|
||||
|
|
@ -90,7 +87,7 @@ def fix_standard_imports(red):
|
|||
|
||||
def _get_modules(module):
|
||||
"""
|
||||
Takes a list of modules and converts into a string
|
||||
Takes a list of modules and converts into a string.
|
||||
|
||||
The module list can include parens, this function checks each element in
|
||||
the list, if there is a paren then it does not add a comma before the next
|
||||
|
|
@ -105,20 +102,46 @@ def _get_modules(module):
|
|||
return ''.join(modules_string)
|
||||
|
||||
|
||||
def fix_function_calls(red):
|
||||
"""
|
||||
Modifies function calls in the source to reflect import changes.
|
||||
|
||||
Searches the AST for AtomtrailerNodes and replaces them.
|
||||
"""
|
||||
atoms = red.find_all("Atomtrailers")
|
||||
for x, node in enumerate(atoms):
|
||||
try:
|
||||
if (node.value[0].value == 'flask' and
|
||||
node.value[1].value == 'ext'):
|
||||
node.replace("flask_foo%s" % node.value[3])
|
||||
except IndexError:
|
||||
pass
|
||||
|
||||
return red
|
||||
|
||||
|
||||
def check_user_input():
|
||||
"""Exits and gives error message if no argument is passed in the shell."""
|
||||
if len(sys.argv) < 2:
|
||||
sys.exit("No filename was included, please try again.")
|
||||
|
||||
|
||||
def fix(ast):
|
||||
def fix_tester(ast):
|
||||
"""Wrapper which allows for testing when not running from shell."""
|
||||
return fix_imports(ast).dumps()
|
||||
ast = fix_imports(ast)
|
||||
ast = fix_function_calls(ast)
|
||||
return ast.dumps()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
def fix():
|
||||
"""Wrapper for user argument checking and import fixing."""
|
||||
check_user_input()
|
||||
input_file = sys.argv[1]
|
||||
ast = read_source(input_file)
|
||||
ast = fix_imports(ast)
|
||||
ast = fix_function_calls(ast)
|
||||
write_source(ast, input_file)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
fix()
|
||||
|
|
|
|||
|
|
@ -8,19 +8,19 @@ import flaskext_migrate as migrate
|
|||
|
||||
def test_simple_from_import():
|
||||
red = RedBaron("from flask.ext import foo")
|
||||
output = migrate.fix(red)
|
||||
output = migrate.fix_tester(red)
|
||||
assert output == "import flask_foo as foo"
|
||||
|
||||
|
||||
def test_from_to_from_import():
|
||||
red = RedBaron("from flask.ext.foo import bar")
|
||||
output = migrate.fix(red)
|
||||
output = migrate.fix_tester(red)
|
||||
assert output == "from flask_foo import bar as bar"
|
||||
|
||||
|
||||
def test_multiple_import():
|
||||
red = RedBaron("from flask.ext.foo import bar, foobar, something")
|
||||
output = migrate.fix(red)
|
||||
output = migrate.fix_tester(red)
|
||||
assert output == "from flask_foo import bar, foobar, something"
|
||||
|
||||
|
||||
|
|
@ -29,23 +29,35 @@ def test_multiline_import():
|
|||
bar,\
|
||||
foobar,\
|
||||
something")
|
||||
output = migrate.fix(red)
|
||||
output = migrate.fix_tester(red)
|
||||
assert output == "from flask_foo import bar, foobar, something"
|
||||
|
||||
|
||||
def test_module_import():
|
||||
red = RedBaron("import flask.ext.foo")
|
||||
output = migrate.fix(red)
|
||||
assert output == "import flask_foo"
|
||||
output = migrate.fix_tester(red)
|
||||
assert output == "import flask_foo as foo"
|
||||
|
||||
|
||||
def test_module_import():
|
||||
def test_named_module_import():
|
||||
red = RedBaron("import flask.ext.foo as foobar")
|
||||
output = migrate.fix_tester(red)
|
||||
assert output == "import flask_foo as foobar"
|
||||
|
||||
|
||||
def test__named_from_import():
|
||||
red = RedBaron("from flask.ext.foo import bar as baz")
|
||||
output = migrate.fix(red)
|
||||
output = migrate.fix_tester(red)
|
||||
assert output == "from flask_foo import bar as baz"
|
||||
|
||||
|
||||
def test_parens_import():
|
||||
red = RedBaron("from flask.ext.foo import (bar, foo, foobar)")
|
||||
output = migrate.fix(red)
|
||||
output = migrate.fix_tester(red)
|
||||
assert output == "from flask_foo import (bar, foo, foobar)"
|
||||
|
||||
|
||||
def test_function_call_migration():
|
||||
red = RedBaron("flask.ext.foo(var)")
|
||||
output = migrate.fix_tester(red)
|
||||
assert output == "flask_foo(var)"
|
||||
|
|
|
|||
Loading…
Reference in New Issue