Skip to content

Commit

Permalink
Validate only task commands are run by executors
Browse files Browse the repository at this point in the history
  • Loading branch information
ashb committed Jun 8, 2020
1 parent b4b84a1 commit 1dda6fd
Show file tree
Hide file tree
Showing 6 changed files with 44 additions and 3 deletions.
3 changes: 3 additions & 0 deletions airflow/executors/celery_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,9 @@
@app.task
def execute_command(command_to_exec: CommandType) -> None:
"""Executes command."""
if command_to_exec[0:3] != ["airflow", "tasks", "run"]:
raise ValueError('The command must start with ["airflow", "tasks", "run"].')

log.info("Executing command in Celery: %s", command_to_exec)
env = os.environ.copy()
try:
Expand Down
3 changes: 3 additions & 0 deletions airflow/executors/dask_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,9 @@ def execute_async(self,
queue: Optional[str] = None,
executor_config: Optional[Any] = None) -> None:

if command[0:3] != ["airflow", "tasks", "run"]:
raise ValueError('The command must start with ["airflow", "tasks", "run"].')

def airflow_run():
return subprocess.check_call(command, close_fds=True)

Expand Down
3 changes: 3 additions & 0 deletions airflow/executors/kubernetes_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -459,6 +459,9 @@ def run_next(self, next_job: KubernetesJobType) -> None:
if isinstance(command, str):
command = [command]

if command[0] != "airflow":
raise ValueError('The first element of command must be equal to "airflow".')

pod = PodGenerator.construct_pod(
namespace=self.namespace,
worker_uuid=self.worker_uuid,
Expand Down
4 changes: 4 additions & 0 deletions airflow/executors/local_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,10 @@ def execute_async(self, key: TaskInstanceKeyType,
"""Execute asynchronously."""
if not self.impl:
raise AirflowException(NOT_STARTED_MESSAGE)

if command[0:3] != ["airflow", "tasks", "run"]:
raise ValueError('The command must start with ["airflow", "tasks", "run"].')

self.impl.execute_async(key=key, command=command, queue=queue, executor_config=executor_config)

def sync(self) -> None:
Expand Down
4 changes: 4 additions & 0 deletions airflow/executors/sequential_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,10 @@ def execute_async(self,
command: CommandType,
queue: Optional[str] = None,
executor_config: Optional[Any] = None) -> None:

if command[0:3] != ["airflow", "tasks", "run"]:
raise ValueError('The command must start with ["airflow", "tasks", "run"].')

self.commands_to_run.append((key, command))

def sync(self) -> None:
Expand Down
30 changes: 27 additions & 3 deletions tests/executors/test_celery_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
from parameterized import parameterized

from airflow.configuration import conf
from airflow.exceptions import AirflowException
from airflow.executors import celery_executor
from airflow.executors.celery_executor import BulkStateFetcher
from airflow.models import TaskInstance
Expand Down Expand Up @@ -101,13 +102,18 @@ class TestCeleryExecutor(unittest.TestCase):
@pytest.mark.integration("rabbitmq")
@pytest.mark.backend("mysql", "postgres")
def test_celery_integration(self, broker_url):
with _prepare_app(broker_url) as app:
success_command = ['airflow', 'tasks', 'run', 'true', 'some_parameter']
fail_command = ['airflow', 'version']

def fake_execute_command(command):
if command != success_command:
raise AirflowException("fail")

with _prepare_app(broker_url, execute=fake_execute_command) as app:
executor = celery_executor.CeleryExecutor()
executor.start()

with start_worker(app=app, logfile=sys.stdout, loglevel='info'):
success_command = ['true', 'some_parameter']
fail_command = ['false', 'some_parameter']
execute_date = datetime.datetime.now()

cached_celery_backend = celery_executor.execute_command.backend
Expand Down Expand Up @@ -202,6 +208,24 @@ def test_gauge_executor_metrics(self, mock_stats_gauge, mock_trigger_tasks, mock
mock.call('executor.running_tasks', mock.ANY)]
mock_stats_gauge.assert_has_calls(calls)

@parameterized.expand((
    [['true'], ValueError],
    [['airflow', 'version'], ValueError],
    [['airflow', 'tasks', 'run'], None],
))
@mock.patch('subprocess.check_call')
def test_command_validation(self, command, expected_exception, mock_check_call):
    # The validation is exercised by calling execute_command directly, i.e.
    # on the receiving side, rather than only where the command is sent from.
    # subprocess.check_call is patched, so no process is actually spawned.
    if not expected_exception:
        # Accepted command: it must be handed to subprocess exactly once.
        celery_executor.execute_command(command)
        mock_check_call.assert_called_once_with(
            command, stderr=mock.ANY, close_fds=mock.ANY, env=mock.ANY,
        )
        return

    # Rejected command: the call must raise and never reach subprocess.
    with pytest.raises(expected_exception):
        celery_executor.execute_command(command)
    mock_check_call.assert_not_called()


def test_operation_timeout_config():
assert celery_executor.OPERATION_TIMEOUT == 2
Expand Down

0 comments on commit 1dda6fd

Please sign in to comment.