1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
use proc_macro2::TokenStream;
use quote::{format_ident, quote, ToTokens};
use syn::{parse::Parser, punctuated::Punctuated, Expr, Index, Token};

/// The `stream_select!` macro.
///
/// Parses `input` as a comma-separated list of expressions (a trailing comma
/// is accepted), one per stream, and expands to a block expression that:
///
/// * declares a local tuple struct `StreamSelect` with one `Option<_>` slot
///   per argument, so a finished stream can be retired by setting its slot to
///   `None`;
/// * declares a local enum `StreamEnum` with one variant per argument plus a
///   `None` variant, so that mutable references to the differently-typed
///   streams can be stored together in one array;
/// * implements `Stream` for both types, requiring every argument to be an
///   `Unpin` stream with the same `Item` type;
/// * evaluates to `StreamSelect(Some(arg0), Some(arg1), ...)`.
///
/// # Errors
///
/// Returns the `syn::Error` produced by the parser when `input` is not a
/// comma-separated list of expressions. Passing fewer than two arguments is
/// not reported through `Err`: the function returns `Ok` with tokens that
/// invoke `compile_error!`.
pub(crate) fn stream_select(input: TokenStream) -> Result<TokenStream, syn::Error> {
    let args = Punctuated::<Expr, Token![,]>::parse_terminated.parse2(input)?;
    if args.len() < 2 {
        return Ok(quote! {
           compile_error!("stream select macro needs at least two arguments.")
        });
    }
    // One identifier family per role, all indexed by argument position:
    //   _N   -> generic type parameter / `StreamEnum` variant for stream N
    //   __N  -> `&mut Option<_N>` binding to slot N inside `poll_next`
    //   ___N -> "stream N reported end-of-stream during this poll" flag
    let generic_idents = (0..args.len()).map(|i| format_ident!("_{}", i)).collect::<Vec<_>>();
    let field_idents = (0..args.len()).map(|i| format_ident!("__{}", i)).collect::<Vec<_>>();
    let field_idents_2 = (0..args.len()).map(|i| format_ident!("___{}", i)).collect::<Vec<_>>();
    // Tuple-field accessors (`self.0`, `self.1`, ...) used by `size_hint`.
    let field_indices = (0..args.len()).map(Index::from).collect::<Vec<_>>();
    let args = args.iter().map(|e| e.to_token_stream());

    Ok(quote! {
        {
            #[derive(Debug)]
            struct StreamSelect<#(#generic_idents),*> (#(Option<#generic_idents>),*);

            enum StreamEnum<#(#generic_idents),*> {
                #(
                    #generic_idents(#generic_idents)
                ),*,
                None,
            }

            impl<ITEM, #(#generic_idents),*> __futures_crate::stream::Stream for StreamEnum<#(#generic_idents),*>
            where #(#generic_idents: __futures_crate::stream::Stream<Item=ITEM> + ::std::marker::Unpin,)*
            {
                type Item = ITEM;

                fn poll_next(self: ::std::pin::Pin<&mut Self>, cx: &mut __futures_crate::task::Context<'_>) -> __futures_crate::task::Poll<Option<Self::Item>> {
                    match self.get_mut() {
                        #(
                            Self::#generic_idents(#generic_idents) => ::std::pin::Pin::new(#generic_idents).poll_next(cx)
                        ),*,
                        Self::None => panic!("StreamEnum::None should never be polled!"),
                    }
                }
            }

            impl<ITEM, #(#generic_idents),*> __futures_crate::stream::Stream for StreamSelect<#(#generic_idents),*>
            where #(#generic_idents: __futures_crate::stream::Stream<Item=ITEM> + ::std::marker::Unpin,)*
            {
                type Item = ITEM;

                fn poll_next(self: ::std::pin::Pin<&mut Self>, cx: &mut __futures_crate::task::Context<'_>) -> __futures_crate::task::Poll<Option<Self::Item>> {
                    let Self(#(ref mut #field_idents),*) = self.get_mut();
                    #(
                        let mut #field_idents_2 = false;
                    )*
                    let mut any_pending = false;
                    // Item produced during this poll, if any. It is returned only
                    // after the finished-stream bookkeeping below has run.
                    let mut ready_item = None;
                    {
                        // Slots that are already `None` become `StreamEnum::None`
                        // placeholders so the array keeps a fixed length.
                        let mut stream_array = [#(#field_idents.as_mut().map(|f| StreamEnum::#generic_idents(f)).unwrap_or(StreamEnum::None)),*];
                        __futures_crate::async_await::shuffle(&mut stream_array);

                        for mut s in stream_array {
                            if let StreamEnum::None = s {
                                continue;
                            }
                            match __futures_crate::stream::Stream::poll_next(::std::pin::Pin::new(&mut s), cx) {
                                __futures_crate::task::Poll::Ready(Some(item)) => {
                                    // Do not return from here: streams that reported
                                    // end-of-stream earlier in this pass must still be
                                    // cleared, otherwise they would be polled again
                                    // after completion on the next call.
                                    ready_item = Some(item);
                                    break;
                                },
                                __futures_crate::task::Poll::Pending => {
                                    any_pending = true;
                                },
                                __futures_crate::task::Poll::Ready(None) => {
                                    match s {
                                        #(
                                            StreamEnum::#generic_idents(_) => { #field_idents_2 = true; }
                                        ),*,
                                        StreamEnum::None => panic!("StreamEnum::None should never be polled!"),
                                    }
                                },
                            }
                        }
                    }
                    // The array's borrows of the slots have ended, so finished
                    // streams can now be dropped.
                    #(
                        if #field_idents_2 {
                            *#field_idents = None;
                        }
                    )*
                    if let Some(item) = ready_item {
                        __futures_crate::task::Poll::Ready(Some(item))
                    } else if any_pending {
                        __futures_crate::task::Poll::Pending
                    } else {
                        __futures_crate::task::Poll::Ready(None)
                    }
                }

                fn size_hint(&self) -> (usize, Option<usize>) {
                    let mut s: (usize, Option<usize>) = (0, Some(0));
                    #(
                        if let Some(new_hint) = self.#field_indices.as_ref().map(|s| s.size_hint()) {
                            // Lower bound saturates; an upper bound that overflows
                            // `usize` is reported as unknown (`None`) instead of
                            // panicking or wrapping.
                            s.0 = s.0.saturating_add(new_hint.0);
                            s.1 = s.1.and_then(|a| new_hint.1.and_then(|b| a.checked_add(b)));
                        }
                    )*
                    s
                }
            }

            StreamSelect(#(Some(#args)),*)

        }
    })
}