1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
// Copyright (c) The Diem Core Contributors
// SPDX-License-Identifier: Apache-2.0

use crate::utils::error_notes::ErrorNotes;
use anyhow::{bail, Result};
use async_trait::async_trait;
use bytes::{Bytes, BytesMut};
use std::convert::TryInto;
use tokio::io::{AsyncRead, AsyncReadExt};

/// Reads length-prefixed records from a byte stream.
///
/// As implemented for `AsyncRead` types in this module, a record is a 4-byte
/// big-endian `u32` size header followed by exactly that many payload bytes.
#[async_trait]
pub trait ReadRecordBytes {
    /// Reads the next record.
    ///
    /// Yields `Ok(None)` when the stream is already exhausted before the size
    /// header. An error is returned for a truncated header or payload.
    async fn read_record_bytes(&mut self) -> Result<Option<Bytes>>;

    /// Reads into the (initially empty) `target` until it is filled to its
    /// capacity, or leaves it empty if the stream is already at EOF.
    async fn read_full_buf_or_none(&mut self, target: &mut BytesMut) -> Result<()>;
}

#[async_trait]
impl<R: AsyncRead + Send + Unpin> ReadRecordBytes for R {
    /// Fills `buf` from `self` until `buf.len() == buf.capacity()`.
    ///
    /// The capacity of `buf` is the number of bytes to read. Returns `Ok(())`
    /// once the buffer is full. It also returns `Ok(())`, leaving `buf` empty,
    /// when the reader is at EOF before any byte was read. A zero-capacity
    /// buffer is considered already full and nothing is read.
    ///
    /// # Errors
    /// Fails if the underlying read fails, or if EOF is reached after some but
    /// not all of the expected bytes were read.
    ///
    /// # Panics
    /// Panics if `buf` is not empty on entry.
    async fn read_full_buf_or_none(&mut self, buf: &mut BytesMut) -> Result<()> {
        assert_eq!(buf.len(), 0);
        let n_expected = buf.capacity();
        // Nothing to read for a zero-capacity buffer. Return before calling
        // `read_buf`: `BytesMut`'s `BufMut` impl may grow the buffer when it is
        // at capacity (it does in `bytes` 1.x), which would otherwise consume
        // stream data beyond the intended length.
        if n_expected == 0 {
            return Ok(());
        }

        loop {
            let n_read = self.read_buf(buf).await.err_notes("")?;
            let n_read_total = buf.len();
            if n_read_total == n_expected {
                return Ok(());
            }
            if n_read == 0 {
                // EOF: acceptable only if nothing at all was read.
                if n_read_total == 0 {
                    return Ok(());
                }
                bail!(
                    "Hit EOF before filling the whole buffer, read {}, expected {}",
                    n_read_total,
                    n_expected
                );
            }
        }
    }

    /// Reads one length-prefixed record from `self`.
    ///
    /// A record is a 4-byte big-endian `u32` size followed by that many bytes.
    /// Returns `Ok(None)` on EOF before the size header, and
    /// `Ok(Some(Bytes::new()))` for a zero-size record.
    ///
    /// # Errors
    /// Fails on I/O errors, on a truncated size header, or when the stream ends
    /// before the declared number of payload bytes has been read.
    async fn read_record_bytes(&mut self) -> Result<Option<Bytes>> {
        // Read the 4-byte size header; an empty result means clean EOF.
        let mut size_buf = BytesMut::with_capacity(4);
        self.read_full_buf_or_none(&mut size_buf).await?;
        if size_buf.is_empty() {
            return Ok(None);
        }

        let record_size = u32::from_be_bytes(size_buf.as_ref().try_into()?) as usize;
        // Empty record: no payload follows the header.
        if record_size == 0 {
            return Ok(Some(Bytes::new()));
        }

        // Read the payload. EOF here is always an error because the header
        // promised `record_size` bytes.
        // NOTE(review): `record_size` comes straight from the stream and is
        // allocated up front, so a corrupted header can request up to 4 GiB.
        let mut record_buf = BytesMut::with_capacity(record_size);
        self.read_full_buf_or_none(&mut record_buf).await?;
        if record_buf.is_empty() {
            bail!("Hit EOF when reading record.")
        }

        Ok(Some(record_buf.freeze()))
    }
}

#[cfg(test)]
mod tests {
    use crate::utils::read_record_bytes::ReadRecordBytes;
    use tokio::runtime::Runtime;

    /// Frames `payload` as one record: big-endian `u32` length, then the bytes.
    fn framed(payload: &[u8]) -> Vec<u8> {
        let mut out = (payload.len() as u32).to_be_bytes().to_vec();
        out.extend_from_slice(payload);
        out
    }

    /// Reads a single record out of an in-memory byte slice.
    async fn read_one(mut input: &[u8]) -> anyhow::Result<Option<bytes::Bytes>> {
        input.read_record_bytes().await
    }

    #[test]
    fn test_read_record_bytes() {
        Runtime::new().unwrap().block_on(async {
            let payload = b"abc";

            // A well-formed record round-trips.
            let record = read_one(&framed(payload)).await.unwrap();
            assert_eq!(record.unwrap(), &payload[..]);

            // EOF before any header bytes: no record at all.
            assert!(read_one(b"").await.unwrap().is_none());

            // A zero-size header yields an empty record.
            let empty = read_one(&0u32.to_be_bytes()).await.unwrap();
            assert_eq!(empty.unwrap(), &[][..]);

            // Header present, payload missing entirely.
            assert!(read_one(&1u32.to_be_bytes()).await.is_err());

            // Header itself cut short (3 of 4 bytes).
            let truncated_header = &10u32.to_be_bytes()[..3];
            assert!(read_one(truncated_header).await.is_err());

            // Payload shorter than the declared size.
            let mut short_payload = 10u32.to_be_bytes().to_vec();
            short_payload.push(0u8);
            assert!(read_one(&short_payload).await.is_err());
        })
    }
}