mirror of
https://github.com/shouptech/flask-tutorial.git
synced 2026-02-03 07:29:42 +00:00
Update unit tests for sqlalchemy
This commit is contained in:
parent
c3a3eb30e2
commit
4622b87a6f
6 changed files with 38 additions and 39 deletions
|
|
@ -7,11 +7,15 @@ from flask_sqlalchemy import SQLAlchemy
|
||||||
|
|
||||||
db = SQLAlchemy()
|
db = SQLAlchemy()
|
||||||
|
|
||||||
|
|
||||||
|
def init_db():
|
||||||
|
db.create_all()
|
||||||
|
|
||||||
@click.command('init-db')
|
@click.command('init-db')
|
||||||
@with_appcontext
|
@with_appcontext
|
||||||
def init_db_command():
|
def init_db_command():
|
||||||
"""Clear the existing data and create new tables."""
|
"""Clear the existing data and create new tables."""
|
||||||
db.create_all()
|
init_db()
|
||||||
click.echo('Initialized the database.')
|
click.echo('Initialized the database.')
|
||||||
|
|
||||||
def init_app(app):
|
def init_app(app):
|
||||||
|
|
|
||||||
|
|
@ -1,9 +1,12 @@
|
||||||
import os
|
import os
|
||||||
import tempfile
|
import tempfile
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
from werkzeug.security import generate_password_hash
|
||||||
|
|
||||||
from flaskr import create_app
|
from flaskr import create_app
|
||||||
from flaskr.db import get_db, init_db
|
from flaskr.db import init_db, db, User, Post
|
||||||
|
|
||||||
with open(os.path.join(os.path.dirname(__file__), 'data.sql'), 'rb') as f:
|
with open(os.path.join(os.path.dirname(__file__), 'data.sql'), 'rb') as f:
|
||||||
_data_sql = f.read().decode('utf8')
|
_data_sql = f.read().decode('utf8')
|
||||||
|
|
@ -15,12 +18,19 @@ def app():
|
||||||
|
|
||||||
app = create_app({
|
app = create_app({
|
||||||
'TESTING': True,
|
'TESTING': True,
|
||||||
'DATABASE': db_path,
|
'SECRET_KEY': 'testing',
|
||||||
|
'SQLALCHEMY_DATABASE_URI': 'sqlite:///{}'.format(db_path)
|
||||||
})
|
})
|
||||||
|
|
||||||
with app.app_context():
|
with app.app_context():
|
||||||
init_db()
|
init_db()
|
||||||
get_db().executescript(_data_sql)
|
db.session.add(User(username='test',
|
||||||
|
password=generate_password_hash('test')))
|
||||||
|
db.session.add(Post(title='test title',
|
||||||
|
body='test\nbody',
|
||||||
|
user_id=1,
|
||||||
|
created=datetime(year=2018,month=1,day=1)))
|
||||||
|
db.session.commit()
|
||||||
|
|
||||||
yield app
|
yield app
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,8 +0,0 @@
|
||||||
INSERT INTO user (username, password)
|
|
||||||
VALUES
|
|
||||||
('test', 'pbkdf2:sha256:50000$TCI4GzcX$0de171a4f4dac32e3364c7ddc7c14f3e2fa61f2d17574483f7ffbb431b4acb2f'),
|
|
||||||
('other', 'pbkdf2:sha256:50000$kJPKsz6N$d2d4784f1b030a9761f5ccaeeaca413f27f2ecb76d6168407af962ddce849f79');
|
|
||||||
|
|
||||||
INSERT INTO post (title, body, author_id, created)
|
|
||||||
VALUES
|
|
||||||
('test title', 'test' || x'0a' || 'body', 1, '2018-01-01 00:00:00');
|
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
import pytest
|
import pytest
|
||||||
from flask import g, session
|
from flask import g, session
|
||||||
from flaskr.db import get_db
|
from flaskr.db import User
|
||||||
|
|
||||||
|
|
||||||
def test_register(client, app):
|
def test_register(client, app):
|
||||||
|
|
@ -11,9 +11,7 @@ def test_register(client, app):
|
||||||
assert 'http://localhost/auth/login' == response.headers['Location']
|
assert 'http://localhost/auth/login' == response.headers['Location']
|
||||||
|
|
||||||
with app.app_context():
|
with app.app_context():
|
||||||
assert get_db().execute(
|
assert User.query.all() is not None
|
||||||
"select * from user where username = 'a'",
|
|
||||||
).fetchone() is not None
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(('username', 'password', 'message'), (
|
@pytest.mark.parametrize(('username', 'password', 'message'), (
|
||||||
|
|
@ -36,7 +34,7 @@ def test_login(client, auth):
|
||||||
with client:
|
with client:
|
||||||
client.get('/')
|
client.get('/')
|
||||||
assert session['user_id'] == 1
|
assert session['user_id'] == 1
|
||||||
assert g.user['username'] == 'test'
|
assert g.user.username == 'test'
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(('username', 'password', 'message'), (
|
@pytest.mark.parametrize(('username', 'password', 'message'), (
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,5 @@
|
||||||
import pytest
|
import pytest
|
||||||
from flaskr.db import get_db
|
from flaskr.db import db, Post
|
||||||
|
|
||||||
|
|
||||||
def test_index(client, auth):
|
def test_index(client, auth):
|
||||||
|
|
@ -28,9 +28,10 @@ def test_login_required(client, path):
|
||||||
def test_author_required(app, client, auth):
|
def test_author_required(app, client, auth):
|
||||||
# change the post author to another user
|
# change the post author to another user
|
||||||
with app.app_context():
|
with app.app_context():
|
||||||
db = get_db()
|
post = Post.query.filter_by(user_id=1).first()
|
||||||
db.execute('UPDATE post SET author_id = 2 WHERE id = 1')
|
post.user_id = 2
|
||||||
db.commit()
|
db.session.add(post)
|
||||||
|
db.session.commit()
|
||||||
|
|
||||||
auth.login()
|
auth.login()
|
||||||
# current user can't modify other user's post
|
# current user can't modify other user's post
|
||||||
|
|
@ -54,8 +55,7 @@ def test_create(client, auth, app):
|
||||||
client.post('/create', data={'title': 'created', 'body': ''})
|
client.post('/create', data={'title': 'created', 'body': ''})
|
||||||
|
|
||||||
with app.app_context():
|
with app.app_context():
|
||||||
db = get_db()
|
count = Post.query.count()
|
||||||
count = db.execute('SELECT COUNT(id) FROM post').fetchone()[0]
|
|
||||||
assert count == 2
|
assert count == 2
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -65,9 +65,8 @@ def test_update(client, auth, app):
|
||||||
client.post('/1/update', data={'title': 'updated', 'body': ''})
|
client.post('/1/update', data={'title': 'updated', 'body': ''})
|
||||||
|
|
||||||
with app.app_context():
|
with app.app_context():
|
||||||
db = get_db()
|
post = Post.query.filter_by(id=1).first()
|
||||||
post = db.execute('SELECT * FROM post WHERE id = 1').fetchone()
|
assert post.title == 'updated'
|
||||||
assert post['title'] == 'updated'
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize('path', (
|
@pytest.mark.parametrize('path', (
|
||||||
|
|
@ -85,6 +84,5 @@ def test_delete(client, auth, app):
|
||||||
assert response.headers['Location'] == 'http://localhost/'
|
assert response.headers['Location'] == 'http://localhost/'
|
||||||
|
|
||||||
with app.app_context():
|
with app.app_context():
|
||||||
db = get_db()
|
post = Post.query.filter_by(id=1).first()
|
||||||
post = db.execute('SELECT * FROM post WHERE id = 1').fetchone()
|
|
||||||
assert post is None
|
assert post is None
|
||||||
|
|
|
||||||
|
|
@ -1,18 +1,8 @@
|
||||||
import sqlite3
|
import sqlite3
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from flaskr.db import get_db
|
|
||||||
|
|
||||||
|
from flaskr.db import Post, User
|
||||||
def test_get_close_db(app):
|
|
||||||
with app.app_context():
|
|
||||||
db = get_db()
|
|
||||||
assert db is get_db()
|
|
||||||
|
|
||||||
with pytest.raises(sqlite3.ProgrammingError) as e:
|
|
||||||
db.execute('SELECT 1')
|
|
||||||
|
|
||||||
assert 'closed' in str(e)
|
|
||||||
|
|
||||||
def test_init_db_command(runner, monkeypatch):
|
def test_init_db_command(runner, monkeypatch):
|
||||||
class Recorder(object):
|
class Recorder(object):
|
||||||
|
|
@ -25,3 +15,10 @@ def test_init_db_command(runner, monkeypatch):
|
||||||
result = runner.invoke(args=['init-db'])
|
result = runner.invoke(args=['init-db'])
|
||||||
assert 'Initialized' in result.output
|
assert 'Initialized' in result.output
|
||||||
assert Recorder.called
|
assert Recorder.called
|
||||||
|
|
||||||
|
def test_repr():
|
||||||
|
user = User(username='test')
|
||||||
|
assert 'test' in repr(user)
|
||||||
|
|
||||||
|
post = Post(title='test')
|
||||||
|
assert 'test' in repr(post)
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue