tsunami-security-scanner/plugin_server/py/plugin_service_test.py

222 lines
7.1 KiB
Python

# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for plugin_service."""
import ipaddress
import time
from absl.testing import absltest
import grpc_testing
from google.protobuf import timestamp_pb2
import plugin_service
import tsunami_plugin
from common.net.http.http_client import HttpClient
from common.net.http.requests_http_client import RequestsHttpClientBuilder
from plugin.payload.payload_generator import PayloadGenerator
from plugin.payload.payload_secret_generator import PayloadSecretGenerator
from plugin.payload.payload_utility import get_parsed_payload
from plugin.tcs_client import TcsClient
import detection_pb2
import network_pb2
import network_service_pb2
import plugin_representation_pb2
import plugin_service_pb2
import reconnaissance_pb2
import vulnerability_pb2
# Short local aliases for the proto message/enum types used by the tests.
_NetworkEndpoint = network_pb2.NetworkEndpoint
_AddressFamily = network_pb2.AddressFamily
_NetworkService = network_service_pb2.NetworkService
_PluginInfo = plugin_representation_pb2.PluginInfo
_TargetInfo = reconnaissance_pb2.TargetInfo

# Descriptors of the PluginService gRPC service and of its two RPCs; the
# grpc_testing server below is keyed by the service descriptor and invoked
# per method descriptor.
_ServiceDescriptor = plugin_service_pb2.DESCRIPTOR.services_by_name[
    'PluginService'
]
_RunMethod, _ListPluginsMethod = (
    _ServiceDescriptor.methods_by_name[method_name]
    for method_name in ('Run', 'ListPlugins')
)

# Worker count passed to the servicer under test.
MAX_WORKERS = 1
class PluginServiceTest(absltest.TestCase):
  """Exercises PluginServiceServicer through an in-process grpc_testing server."""

  def setUp(self):
    super().setUp()
    # HTTP client and payload generator handed to the fake detectors.
    self.request_client = RequestsHttpClientBuilder().build()
    psg = PayloadSecretGenerator()
    callback_client = TcsClient(
        '127.0.0.1', 8000, 'http://127.0.0.1:8000/test', self.request_client
    )
    self.payload_generator = PayloadGenerator(
        psg, get_parsed_payload(), callback_client
    )
    # The only plugin registered with the servicer under test.
    self.test_plugin = FakeVulnDetector(
        self.request_client, self.payload_generator
    )
    # grpc_testing runs RPCs in-process against a fake clock rather than a
    # real network server.
    self._time = grpc_testing.strict_fake_time(time.time())
    self._server = grpc_testing.server_from_dictionary(
        {
            _ServiceDescriptor: plugin_service.PluginServiceServicer(
                py_plugins=[self.test_plugin], max_workers=MAX_WORKERS
            ),
        },
        self._time,
    )
    self._channel = grpc_testing.channel(
        plugin_service_pb2.DESCRIPTOR.services_by_name.values(), self._time
    )

  def tearDown(self):
    self._channel.close()
    super().tearDown()

  def test_run_plugins_registered_returns_valid_response(self):
    # A separate instance with the same definition as the registered plugin,
    # used to build the request and the expected report.
    plugin_to_test = FakeVulnDetector(
        self.request_client, self.payload_generator
    )
    endpoint = _build_network_endpoint('1.1.1.1', 80)
    service = _NetworkService(
        network_endpoint=endpoint,
        transport_protocol=network_pb2.TCP,
        service_name='http',
    )
    target = _TargetInfo(network_endpoints=[endpoint])
    services = [service]
    request = plugin_service_pb2.RunRequest(
        target=target,
        plugins=[
            plugin_service_pb2.MatchedPlugin(
                services=services, plugin=plugin_to_test.GetPluginDefinition()
            )
        ],
    )

    rpc = self._server.invoke_unary_unary(_RunMethod, (), request, None)
    response, _, _, _ = rpc.termination()

    self.assertLen(response.reports.detection_reports, 1)
    self.assertEqual(
        plugin_to_test._BuildFakeDetectionReport(
            target=target, network_service=services[0]
        ),
        response.reports.detection_reports[0],
    )

  def test_run_no_plugins_registered_returns_empty_response(self):
    endpoint = _build_network_endpoint('1.1.1.1', 80)
    target = _TargetInfo(network_endpoints=[endpoint])
    request = plugin_service_pb2.RunRequest(target=target, plugins=[])

    rpc = self._server.invoke_unary_unary(_RunMethod, (), request, None)
    response, _, _, _ = rpc.termination()

    self.assertEmpty(response.reports.detection_reports)

  def test_list_plugins_plugins_registered_returns_valid_response(self):
    # ListPlugins messages come from the same generated module as RunRequest
    # and the service descriptor, rather than from the servicer module.
    request = plugin_service_pb2.ListPluginsRequest()

    rpc = self._server.invoke_unary_unary(_ListPluginsMethod, (), request, None)
    response, _, _, _ = rpc.termination()

    self.assertEqual(
        plugin_service_pb2.ListPluginsResponse(
            plugins=[self.test_plugin.GetPluginDefinition()]
        ),
        response,
    )
def _build_network_endpoint(ip: str, port: int) -> _NetworkEndpoint:
  """Builds a NetworkEndpoint proto for an IP address and port.

  Args:
    ip: an IPv4 or IPv6 address literal.
    port: the port number.

  Returns:
    An IP_PORT NetworkEndpoint carrying the address, its family and the port.

  Raises:
    ValueError: if `ip` is not a valid IPv4 or IPv6 address.
  """
  return _NetworkEndpoint(
      # The endpoint carries a port, so IP_PORT is the matching type.
      type=_NetworkEndpoint.IP_PORT,
      ip_address=network_pb2.IpAddress(
          # Record the address itself, not only its family, so endpoints for
          # different IPs are distinguishable.
          address=ip,
          address_family=_get_address_family(ip),
      ),
      port=network_pb2.Port(port_number=port),
  )
def _get_address_family(ip: str) -> _AddressFamily:
  """Returns the proto AddressFamily matching the IP version of `ip`.

  Args:
    ip: an IPv4 or IPv6 address literal.

  Returns:
    `_AddressFamily.IPV4` for IPv4 addresses, `_AddressFamily.IPV6` otherwise.

  Raises:
    ValueError: raised by `ipaddress.ip_address` if `ip` is not a valid
      IPv4 or IPv6 address.
  """
  # ipaddress.ip_address returns only IPv4Address (version 4) or IPv6Address
  # (version 6) and raises ValueError for anything else, so no separate
  # "unknown family" branch is reachable.
  if ipaddress.ip_address(ip).version == 4:
    return _AddressFamily.IPV4
  return _AddressFamily.IPV6
class FakeVulnDetector(tsunami_plugin.VulnDetector):
  """Fake vulnerability detector for testing only.

  Always reports one fixed CRITICAL advisory against each matched service.
  """

  def __init__(
      self,
      http_client: HttpClient,
      payload_generator: PayloadGenerator,
  ):
    # Kept to mirror a real detector's constructor; none of this class's
    # methods read them.
    self.http_client = http_client
    self.payload_generator = payload_generator

  def GetAdvisories(self) -> list[vulnerability_pb2.Vulnerability]:
    """Returns the advisories for this plugin."""
    return [
        vulnerability_pb2.Vulnerability(
            main_id=vulnerability_pb2.VulnerabilityId(
                publisher='GOOGLE', value='FakeVuln1'
            ),
            severity=vulnerability_pb2.CRITICAL,
            title='FakeTitle1',
            description='FakeDescription1',
        ),
    ]

  def GetPluginDefinition(self) -> tsunami_plugin.PluginDefinition:
    """Returns the static definition identifying this fake plugin."""
    return tsunami_plugin.PluginDefinition(
        info=_PluginInfo(
            type=_PluginInfo.VULN_DETECTION,
            name='fake',
            version='v0.1',
            description='fake description',
            author='fake author',
        ),
        target_service_name=plugin_representation_pb2.TargetServiceName(
            value=['fake service']
        ),
        target_software=plugin_representation_pb2.TargetSoftware(
            name='fake software'
        ),
        for_web_service=False,
    )

  def Detect(
      self,
      target: reconnaissance_pb2.TargetInfo,
      matched_services: list[network_service_pb2.NetworkService],
  ) -> detection_pb2.DetectionReportList:
    """Reports the fake vulnerability on every matched service.

    Args:
      target: the scanned target.
      matched_services: services this plugin was matched against.

    Returns:
      One detection report per matched service; an empty list (rather than
      an IndexError) when no services were matched.
    """
    return detection_pb2.DetectionReportList(
        detection_reports=[
            self._BuildFakeDetectionReport(target, service)
            for service in matched_services
        ]
    )

  def _BuildFakeDetectionReport(self, target, network_service):
    """Builds the deterministic verified report for `network_service`."""
    return detection_pb2.DetectionReport(
        target_info=target,
        network_service=network_service,
        # Timestamp.nanos must lie in [0, 999999999]; express the fixed
        # 1.23456789 s instant as whole seconds plus the nanosecond remainder.
        detection_timestamp=timestamp_pb2.Timestamp(
            seconds=1, nanos=234567890
        ),
        detection_status=detection_pb2.VULNERABILITY_VERIFIED,
        vulnerability=self.GetAdvisories()[0],
    )
# TODO(b/239628051): Add a failed VulnDetector class to test failed cases.
# Run the absltest suite when this file is executed directly.
if __name__ == '__main__':
  absltest.main()