c# – 如何为PLINQ重新实现Where()运算符?

我想重写一个LINQ扩展来利用并行性.无论如何,我不知道从哪里开始.

作为一个教学示例,我想知道如何重写Where()运算符的实现,但是这可以在ParallelQuery上工作.

public static ParallelQuery<TSource> Where<TSource>(
   this ParallelQuery<TSource> source, 
   Func<TSource, bool> predicate)
{
    //implementation
}

有可能写:

someList.AsParallel().Where(...)

写一个串行执行的地方是微不足道的:

public static IEnumerable<TSource> Where<TSource>( 
    this IEnumerable<TSource> source, 
    Func<TSource, bool> predicate) 
{ 
    foreach (TSource item in source) 
    { 
        if (predicate(item)) 
        { 
            yield return item; 
        } 
    } 
}

我想到的是简单地将谓词包装在Parallel.ForEach()周围(并将结果推送到List / Array中),但我不认为这是要走的路.

我不知道写它是否微不足道(因此它可以作为SO答案)或者它非常复杂.如果是这样,从哪里开始提供一些提示也很好.可能有几种方法可以实现它,并且由于特定的优化可能会变得非常复杂,但是一个简单的实现可以正常工作(这意味着它提供了正确的结果并且比上面的非多线程实现更快)

正如Scott Chamberlain所建议的,这里是我要重写的LINQ方法的实现:

public static IEnumerable<TSource> WhereContains<TSource, TKey>(
     this IEnumerable<TSource> source, 
     IEnumerable<TKey> values,
     Func<TSource, TKey> keySelector)
{
    HashSet<TKey> elements = new HashSet<TKey>(values);

    foreach (TSource item in source)
    {
        if (elements.Contains(keySelector(item)))
        {
            yield return item;
        }
    }
}

最佳答案 你不能自己创建ParallelQuery< T>不幸的是,由于ParallelQuery< T>而基于类的事实.是公开的,它没有任何公共建设者.

您可以做的是使用现有的PLINQ基础设施来做您想做的事情.你真正想要做的就是做一个包含作为谓词的地方……就这样做.

public static ParallelQuery<TSource> WhereContains<TSource, TKey>(
    this ParallelQuery<TSource> source,
    IEnumerable<TKey> values,
    Func<TSource, TKey> keySelector)
{
    HashSet<TKey> elements = new HashSet<TKey>(values);

    return source.Where(item => elements.Contains(keySelector(item)));
}

这将并行执行Where子句,并且(虽然未记录)只要您没有执行任何写操作,包含就是线程安全的,并且因为您正在创建本地HashSet来执行查找,所以您不必担心写入发生.

这是一个示例项目,它向控制台打印出它正在处理的线程和项目,您可以看到它正在使用多个线程.

class Program
{
    static void Main(string[] args)
    {
        List<int> items = new List<int>(Enumerable.Range(0,100));

        int[] values = {5, 12, 25, 17, 0};

        Console.WriteLine("thread: {0}", Environment.CurrentManagedThreadId);

        var result = items.AsParallel().WhereContains(values, x=>x).ToList();

        Console.WriteLine("Done");
        Console.ReadLine();
    }
}

static class Extensions
{
    public static ParallelQuery<TSource> WhereContains<TSource, TKey>(
        this ParallelQuery<TSource> source,
        IEnumerable<TKey> values,
        Func<TSource, TKey> keySelector)
    {
        HashSet<TKey> elements = new HashSet<TKey>(values);

        return source.Where(item =>
        {
            Console.WriteLine("item:{0} thread: {1}", item, Environment.CurrentManagedThreadId);
            return elements.Contains(keySelector(item));
        });
    }
}
点赞