# SPDX-FileCopyrightText: AISEC Pentesting Team
#
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations
import asyncio
import struct
from dataclasses import dataclass
from enum import IntEnum, unique
from pydantic import BaseModel
from gallia.log import get_logger
from gallia.transports.base import BaseTransport, TargetURI
@unique
class ProtocolVersions(IntEnum):
    """Values for the protocol version byte of the DoIP generic header.

    The member names refer to the ISO 13400-2 edition they identify.
    """

    ISO_13400_2_2010 = 0x01
    ISO_13400_2_2012 = 0x02
@unique
class RoutingActivationRequestTypes(IntEnum):
    """Activation type byte carried in a routing activation request."""

    Default = 0x00
    WWH_OBD = 0x01
    CentralSecurity = 0xE0
@unique
class RoutingActivationResponseCodes(IntEnum):
    """Response code byte carried in a routing activation response.

    Only ``Success`` is treated as a successful activation by
    ``DoIPConnection``; every other code aborts the connection attempt.
    """

    # Rejections.
    UnknownSourceAddress = 0x00
    NoRessources = 0x01
    InvalidConnectionEntry = 0x02
    AlreadyActived = 0x03
    AuthenticationMissing = 0x04
    ConfirmationRejected = 0x05
    UnsupportedActivationType = 0x06

    # Acceptances.
    Success = 0x10
    SuccessConfirmationRequired = 0x11
@unique
class PayloadTypes(IntEnum):
    """16-bit payload type identifiers used in the DoIP generic header."""

    NegativeAcknowledge = 0x0000

    # Vehicle identification.
    VehicleIdentificationRequestMessage = 0x0002
    VehicleIdentificationRequestMessageWithEID = 0x0003
    VehicleIdentificationRequestMessageWithVIN = 0x0004

    # Routing activation and connection keep-alive.
    RoutingActivationRequest = 0x0005
    RoutingActivationResponse = 0x0006
    AliveCheckRequest = 0x0007
    AliveCheckResponse = 0x0008

    # Entity status.
    DoIPEntityStatusRequest = 0x4001
    DoIPEntityStatusResponse = 0x4002

    # Diagnostic messages and their acknowledgements.
    DiagnosticMessage = 0x8001
    DiagnosticMessagePositiveAcknowledgement = 0x8002
    DiagnosticMessageNegativeAcknowledgement = 0x8003
@unique
class DiagnosticMessagePositiveAckCodes(IntEnum):
    """ACK code byte of a positive diagnostic message acknowledgement."""

    Success = 0x00
@unique
class DiagnosticMessageNegativeAckCodes(IntEnum):
    """NACK code byte of a negative diagnostic message acknowledgement."""

    InvalidSourceAddress = 0x02
    UnknownTargetAddress = 0x03
    DiagnosticMessageTooLarge = 0x04
    OutOfMemory = 0x05
    TargetUnreachable = 0x06
    UnknownNetwork = 0x07
    TransportProtocolError = 0x08
@unique
class GenericHeaderNACKCodes(IntEnum):
    """NACK codes of a generic header negative acknowledge message."""

    IncorrectPatternFormat = 0x01
    UnknownPayloadType = 0x02
    MessageTooLarge = 0x03
    OutOfMemory = 0x04
    InvalidPayloadLength = 0x05
class TimingAndCommunicationParameters(IntEnum):
    """DoIP timing and repetition constants.

    The timeouts appear to be in milliseconds: ``DoIPConnection`` divides
    them by 1000 before handing them to ``asyncio.wait_for``.

    This enum is deliberately not ``@unique``. Several members share a
    value, so the later ones are *aliases* of the first member defined with
    that value (e.g. ``AnnounceInterval is AnnounceWait``). The numeric
    value is still correct; only ``.name``/``repr`` show the first name.
    """

    CtrlTimeout = 2000
    AnnounceWait = 500
    AnnounceInterval = 500
    AnnounceNum = 3
    DiagnosticMessageMessageAckTimeout = 50
    DiagnosticMessageMessageTimeout = 2000
    TCPGeneralInactivityTimeout = 5000
    TCPInitalInactivityTimeout = 2000
    TCPAliveCheckTimeout = 500
    ProcessingTimeout = 2000
    VecicleDiscoveryTimeout = 5000
@dataclass
class GenericHeader:
    """The fixed 8-byte header that precedes every DoIP payload.

    Wire layout (network byte order): protocol version (1 byte), bitwise
    inverse of the protocol version (1 byte), payload type (2 bytes) and
    payload length (4 bytes).
    """

    ProtocolVersion: ProtocolVersions
    PayloadType: PayloadTypes
    PayloadLength: int
    # Not serialised by pack(); unpack() always sets it to b"".
    PayloadTypeSpecificMessageContent: bytes

    def pack(self) -> bytes:
        """Serialise the header into its 8-byte wire representation."""
        version = self.ProtocolVersion
        inverse_version = version ^ 0xFF
        return struct.pack(
            "!BBHL",
            version,
            inverse_version,
            self.PayloadType,
            self.PayloadLength,
        )

    @classmethod
    def unpack(cls, data: bytes) -> GenericHeader:
        """Parse an 8-byte header.

        Raises:
            struct.error: if ``data`` is not exactly 8 bytes long.
            ValueError: if the inverse version byte does not match the
                version byte, or the payload type is not a known
                ``PayloadTypes`` member.
        """
        version, inverse_version, raw_type, length = struct.unpack("!BBHL", data)

        # The second byte must be the bitwise complement of the first.
        if version ^ 0xFF != inverse_version:
            raise ValueError("inverse protocol_version is invalid")

        return cls(version, PayloadTypes(raw_type), length, b"")
@dataclass
class GenericHeaderNegativeAcknowledge:
    """Payload of a generic header negative acknowledge message."""

    # Reason code explaining why the header was rejected.
    GenericHeaderNACKCode: GenericHeaderNACKCodes
@dataclass
class RoutingActivationRequest:
    """Payload of a routing activation request (sent by the tester)."""

    SourceAddress: int
    ActivationType: RoutingActivationRequestTypes
    Reserved: int = 0x00000000  # Not used, default value.
    # OEMReserved uint32

    def pack(self) -> bytes:
        """Serialise to 7 bytes: source address (2), type (1), reserved (4)."""
        fields = (self.SourceAddress, self.ActivationType, self.Reserved)
        return struct.pack("!HBI", *fields)
@dataclass
class RoutingActivationResponse:
    """Payload of a routing activation response (sent by the gateway)."""

    SourceAddress: int
    TargetAddress: int
    # Kept as the raw integer from the wire; callers map it to
    # RoutingActivationResponseCodes and handle unknown values themselves.
    RoutingActivationResponseCode: RoutingActivationResponseCodes
    Reserved: int = 0x00000000  # Not used, default value.
    # OEMReserved uint32

    @classmethod
    def unpack(cls, data: bytes) -> RoutingActivationResponse:
        """Parse a routing activation response payload.

        The mandatory part is 9 bytes: two 16-bit addresses, the response
        code (1 byte) and the reserved field (4 bytes). Trailing bytes,
        such as the optional 4-byte OEM specific field, are accepted and
        ignored instead of failing the whole parse.

        Raises:
            struct.error: if ``data`` is shorter than 9 bytes.
            ValueError: if the reserved field is not zero.
        """
        # unpack_from() reads only the leading 9 bytes, so a response that
        # carries the optional OEM field no longer raises struct.error.
        (
            source_address,
            target_address,
            routing_activation_response_code,
            reserved,
        ) = struct.unpack_from("!HHBI", data)

        if reserved != 0x00000000:
            raise ValueError("reserved field contains data")

        return cls(
            source_address,
            target_address,
            routing_activation_response_code,
            reserved,
        )
@dataclass
class DiagnosticMessage:
    """Payload of a diagnostic message: two 16-bit addresses plus user data."""

    SourceAddress: int
    TargetAddress: int
    UserData: bytes

    def pack(self) -> bytes:
        """Serialise as source address, target address, then the raw user data."""
        addresses = struct.pack("!HH", self.SourceAddress, self.TargetAddress)
        return addresses + self.UserData

    @classmethod
    def unpack(cls, data: bytes) -> DiagnosticMessage:
        """Parse a diagnostic message; everything after byte 4 is user data.

        Raises:
            struct.error: if ``data`` is shorter than 4 bytes.
        """
        source_address, target_address = struct.unpack_from("!HH", data)
        return cls(source_address, target_address, data[4:])
@dataclass
class DiagnosticMessageAcknowledgement:
    """Common layout of positive and negative diagnostic acknowledgements."""

    SourceAddress: int
    TargetAddress: int
    ACKCode: int
    # Copy of (part of) the diagnostic message being acknowledged; may be empty.
    PreviousDiagnosticMessageData: bytes

    def pack(self) -> bytes:
        """Serialise as addresses (2 + 2), ACK code (1), then the echoed data."""
        fixed_part = struct.pack(
            "!HHB", self.SourceAddress, self.TargetAddress, self.ACKCode
        )
        return fixed_part + self.PreviousDiagnosticMessageData
class DiagnosticMessagePositiveAcknowledgement(DiagnosticMessageAcknowledgement):
    """Positive acknowledgement of a previously sent diagnostic message."""

    ACKCode: DiagnosticMessagePositiveAckCodes

    @classmethod
    def unpack(cls, data: bytes) -> DiagnosticMessagePositiveAcknowledgement:
        """Parse a positive ACK; bytes after the 5-byte fixed part are echoed data.

        Raises:
            struct.error: if ``data`` is shorter than 5 bytes.
            ValueError: if the ACK code is not a known positive ACK code.
        """
        source_address, target_address, raw_code = struct.unpack_from("!HHB", data)
        code = DiagnosticMessagePositiveAckCodes(raw_code)
        return cls(source_address, target_address, code, data[5:])
class DiagnosticMessageNegativeAcknowledgement(DiagnosticMessageAcknowledgement):
    """Negative acknowledgement of a previously sent diagnostic message."""

    ACKCode: DiagnosticMessageNegativeAckCodes

    @classmethod
    def unpack(cls, data: bytes) -> DiagnosticMessageNegativeAcknowledgement:
        """Parse a negative ACK; bytes after the 5-byte fixed part are echoed data.

        Raises:
            struct.error: if ``data`` is shorter than 5 bytes.
            ValueError: if the NACK code is not a known negative ACK code.
        """
        source_address, target_address, raw_code = struct.unpack_from("!HHB", data)
        code = DiagnosticMessageNegativeAckCodes(raw_code)
        return cls(source_address, target_address, code, data[5:])
@dataclass
class AliveCheckRequest:
    """Alive check request; it has no fields of its own."""
@dataclass
class AliveCheckResponse:
    """Alive check response carrying the responder's 16-bit source address."""

    SourceAddress: int

    def pack(self) -> bytes:
        """Serialise the source address as 2 bytes in network byte order."""
        source_address = self.SourceAddress
        return struct.pack("!H", source_address)
# Payload types that DoIPConnection._read_frame() can produce, i.e. the
# messages expected to be sent by the DoIP gateway.
DoIPInData = (
RoutingActivationResponse
| DiagnosticMessage
| DiagnosticMessagePositiveAcknowledgement
| DiagnosticMessageNegativeAcknowledgement
| AliveCheckRequest
)
# Payload types accepted by DoIPConnection.write_request_raw(), i.e. the
# messages expected to be sent by us. Each of them provides a pack() method.
DoIPOutData = RoutingActivationRequest | DiagnosticMessage | AliveCheckResponse
# A parsed frame: the generic header together with its decoded payload.
DoIPFrame = tuple[
GenericHeader,
DoIPInData | DoIPOutData,
]
# A frame whose payload is known to be a diagnostic message.
DoIPDiagFrame = tuple[GenericHeader, DiagnosticMessage]
class DoIPConnection:
def __init__(
self,
reader: asyncio.StreamReader,
writer: asyncio.StreamWriter,
src_addr: int,
target_addr: int,
):
self.logger = get_logger("doip")
self.reader = reader
self.writer = writer
self.src_addr = src_addr
self.target_addr = target_addr
self._read_queue: asyncio.Queue[DoIPFrame] = asyncio.Queue()
self._read_task = asyncio.create_task(self._read_worker())
self._closed = False
self._mutex = asyncio.Lock()
@classmethod
async def connect(
cls,
host: str,
port: int,
src_addr: int,
target_addr: int,
) -> DoIPConnection:
reader, writer = await asyncio.open_connection(host, port)
return cls(reader, writer, src_addr, target_addr)
async def _read_frame(self) -> DoIPFrame:
# Header is fixed size 8 byte.
hdr_buf = await self.reader.readexactly(8)
hdr = GenericHeader.unpack(hdr_buf)
payload_buf = await self.reader.readexactly(hdr.PayloadLength)
payload: DoIPInData
match hdr.PayloadType:
case PayloadTypes.RoutingActivationResponse:
payload = RoutingActivationResponse.unpack(payload_buf)
case PayloadTypes.DiagnosticMessagePositiveAcknowledgement:
payload = DiagnosticMessagePositiveAcknowledgement.unpack(payload_buf)
case PayloadTypes.DiagnosticMessageNegativeAcknowledgement:
payload = DiagnosticMessageNegativeAcknowledgement.unpack(payload_buf)
case PayloadTypes.DiagnosticMessage:
payload = DiagnosticMessage.unpack(payload_buf)
case PayloadTypes.AliveCheckRequest:
payload = AliveCheckRequest()
case _:
raise BrokenPipeError(
f"unexpected DoIP message: {hdr} {payload_buf.hex()}"
)
return hdr, payload
async def _read_worker(self) -> None:
try:
while True:
hdr, data = await self._read_frame()
if hdr.PayloadType == PayloadTypes.DiagnosticMessage and isinstance(
data, AliveCheckRequest
):
await self.write_alive_check_response()
continue
await self._read_queue.put((hdr, data))
except asyncio.CancelledError:
self.logger.debug("read worker cancelled")
async def read_frame_unsafe(self) -> DoIPFrame:
# Avoid waiting on the queue forever when
# the connection has been terminated.
if self._closed:
raise ConnectionError()
return await self._read_queue.get()
async def read_frame(self) -> DoIPFrame:
async with self._mutex:
return await self.read_frame_unsafe()
async def read_diag_request_raw(self) -> DoIPDiagFrame:
while True:
hdr, payload = await self.read_frame()
if not isinstance(payload, DiagnosticMessage):
raise BrokenPipeError(f"unexpected DoIP message: {hdr} {payload}")
if payload.SourceAddress != self.target_addr:
self.logger.warning(
f"unexpected DoIP src address: {payload.SourceAddress:#04x}"
)
continue
if payload.TargetAddress != self.src_addr:
self.logger.warning(
f"unexpected DoIP target address: {payload.TargetAddress:#04x}"
)
continue
return hdr, payload
async def read_diag_request(self) -> bytes:
_, payload = await self.read_diag_request_raw()
return payload.UserData
async def _read_ack(self, prev_data: bytes) -> None:
hdr, payload = await self.read_frame_unsafe()
if isinstance(payload, DiagnosticMessageNegativeAcknowledgement):
raise BrokenPipeError(f"request denied: {hdr} {payload}")
if not isinstance(payload, DiagnosticMessagePositiveAcknowledgement):
raise BrokenPipeError(
f"unexpected DoIP message: {hdr} {payload}, expected positive ACK"
)
if payload.SourceAddress != self.target_addr:
self.logger.warning(
f"ack: unexpected src_addr: {payload.SourceAddress:#04x}"
)
if payload.TargetAddress != self.src_addr:
self.logger.warning(
f"ack: unexpected dst_addr: {payload.TargetAddress:#04x}"
)
if (
len(payload.PreviousDiagnosticMessageData) > 0
and prev_data != payload.PreviousDiagnosticMessageData
):
self.logger.warning("ack: previous data differs from request")
self.logger.warning(
f"ack: got: {payload.PreviousDiagnosticMessageData.hex()} expected {prev_data.hex()}"
)
async def _read_routing_activation_response(self) -> None:
hdr, payload = await self.read_frame_unsafe()
if hdr.PayloadType != PayloadTypes.RoutingActivationResponse or not isinstance(
payload, RoutingActivationResponse
):
raise BrokenPipeError(f"unexpected DoIP message: {hdr} {payload}")
if (
payload.RoutingActivationResponseCode
!= RoutingActivationResponseCodes.Success
):
try:
code = RoutingActivationResponseCodes(
payload.RoutingActivationResponseCode
)
except ValueError as e:
raise ConnectionAbortedError(
f"unknown routing_activation_response_code: {payload.RoutingActivationResponseCode}"
) from e
raise ConnectionAbortedError(f"routing activation denied: {code}")
async def write_request_raw(self, hdr: GenericHeader, payload: DoIPOutData) -> None:
async with self._mutex:
buf = b""
buf += hdr.pack()
buf += payload.pack()
self.writer.write(buf)
await self.writer.drain()
self.logger.trace(f"hdr: {hdr}, payload: {payload}")
try:
match payload:
case DiagnosticMessage():
# Now an ACK message is expected.
await asyncio.wait_for(
self._read_ack(payload.UserData),
TimingAndCommunicationParameters.DiagnosticMessageMessageAckTimeout
/ 1000,
)
case RoutingActivationRequest():
await asyncio.wait_for(
self._read_routing_activation_response(),
TimingAndCommunicationParameters.DiagnosticMessageMessageAckTimeout
/ 1000,
)
except asyncio.TimeoutError as e:
await self.close()
raise BrokenPipeError() from e
async def write_diag_request(self, data: bytes) -> None:
hdr = GenericHeader(
ProtocolVersion=ProtocolVersions.ISO_13400_2_2012,
PayloadType=PayloadTypes.DiagnosticMessage,
PayloadLength=len(data) + 4,
PayloadTypeSpecificMessageContent=b"",
)
payload = DiagnosticMessage(
SourceAddress=self.src_addr,
TargetAddress=self.target_addr,
UserData=data,
)
await self.write_request_raw(hdr, payload)
async def write_routing_activation_request(
self,
activation_type: RoutingActivationRequestTypes,
) -> None:
hdr = GenericHeader(
ProtocolVersion=ProtocolVersions.ISO_13400_2_2012,
PayloadType=PayloadTypes.RoutingActivationRequest,
PayloadLength=7,
PayloadTypeSpecificMessageContent=b"",
)
payload = RoutingActivationRequest(
SourceAddress=self.src_addr,
ActivationType=activation_type,
Reserved=0x00,
)
await self.write_request_raw(hdr, payload)
async def write_alive_check_response(self) -> None:
hdr = GenericHeader(
ProtocolVersion=ProtocolVersions.ISO_13400_2_2012,
PayloadType=PayloadTypes.AliveCheckResponse,
PayloadLength=2,
PayloadTypeSpecificMessageContent=b"",
)
payload = AliveCheckResponse(
SourceAddress=self.src_addr,
)
await self.write_request_raw(hdr, payload)
async def close(self) -> None:
self._read_task.cancel()
self.writer.close()
await self.writer.wait_closed()
class DoIPConfig(BaseModel):
    """Connection settings for the DoIP transport.

    ``DoIPTransport.connect`` builds this model from the flattened query
    string of the target URI.
    """

    # Logical address used as source address in outgoing messages.
    src_addr: int
    # Logical address of the diagnostic target.
    target_addr: int
    # Raw RoutingActivationRequestTypes value; defaults to WWH_OBD.
    activation_type: int = RoutingActivationRequestTypes.WWH_OBD.value
class DoIPTransport(BaseTransport, scheme="doip"):
    """Transport that exchanges diagnostic payloads over a DoIP connection."""

    def __init__(
        self, target: TargetURI, port: int, config: DoIPConfig, conn: DoIPConnection
    ):
        """Wrap an already activated ``DoIPConnection``.

        Use :meth:`connect` to create instances.
        """
        super().__init__(target)
        self.port = port
        self.config = config
        self._conn = conn

    @staticmethod
    async def _connect(
        hostname: str,
        port: int,
        src_addr: int,
        target_addr: int,
        activation_type: int,
    ) -> DoIPConnection:
        """Open the TCP connection and perform routing activation.

        Raises:
            ValueError: if ``activation_type`` is not a known
                ``RoutingActivationRequestTypes`` value.
        """
        conn = await DoIPConnection.connect(
            hostname,
            port,
            src_addr,
            target_addr,
        )
        await conn.write_routing_activation_request(
            RoutingActivationRequestTypes(activation_type)
        )
        return conn

    @classmethod
    async def connect(
        cls, target: str | TargetURI, timeout: float | None = None
    ) -> DoIPTransport:
        """Create a transport from a ``doip`` target URI.

        The URI query string supplies the ``DoIPConfig`` fields. ``timeout``
        (seconds, ``None`` for no limit) bounds connection setup including
        routing activation.

        Raises:
            ValueError: if the URI contains no hostname.
            asyncio.TimeoutError: if setup does not finish within ``timeout``.
        """
        t = target if isinstance(target, TargetURI) else TargetURI(target)
        cls.check_scheme(t)

        if t.hostname is None:
            raise ValueError("no hostname specified")

        # NOTE(review): the default is 6801 rather than the DoIP standard
        # TCP port 13400 — confirm this is intentional.
        port = t.port if t.port is not None else 6801
        config = DoIPConfig(**t.qs_flat)
        conn = await asyncio.wait_for(
            cls._connect(
                t.hostname,
                port,
                config.src_addr,
                config.target_addr,
                config.activation_type,
            ),
            timeout,
        )
        return cls(t, port, config, conn)

    async def close(self) -> None:
        """Mark the transport closed and close the DoIP connection."""
        self.is_closed = True
        await self._conn.close()

    async def read(
        self,
        timeout: float | None = None,
        tags: list[str] | None = None,
    ) -> bytes:
        """Return the user data of the next diagnostic message.

        Raises:
            asyncio.TimeoutError: if nothing arrives within ``timeout``.
        """
        data = await asyncio.wait_for(self._conn.read_diag_request(), timeout)

        t = tags + ["read"] if tags is not None else ["read"]
        self.logger.trace(data.hex(), extra={"tags": t})
        return data

    async def write(
        self,
        data: bytes,
        timeout: float | None = None,
        tags: list[str] | None = None,
    ) -> int:
        """Send ``data`` as a diagnostic message and return its length.

        Raises:
            asyncio.TimeoutError: if sending, including the wait for the
                ACK, does not finish within ``timeout``.
        """
        await asyncio.wait_for(self._conn.write_diag_request(data), timeout)

        # Outgoing data is tagged "write"; it used to be tagged "read",
        # which made it indistinguishable from incoming data in the trace.
        t = tags + ["write"] if tags is not None else ["write"]
        self.logger.trace(data.hex(), extra={"tags": t})
        return len(data)